Spaces:
Sleeping
Sleeping
| # mistral_hf_wrapper.py | |
| import os | |
| import requests | |
class MistralInference:
    """Small client for a text-generation HTTP endpoint.

    The endpoint URL and bearer token are read from the ``HF_MISTRAL_URL``
    and ``HF_TOKEN`` environment variables when an instance is created.
    """

    def __init__(self):
        """Load the endpoint URL and token from the environment.

        Raises:
            ValueError: If ``HF_MISTRAL_URL`` or ``HF_TOKEN`` is unset or empty.
        """
        self.api_url = os.getenv("HF_MISTRAL_URL")
        self.api_token = os.getenv("HF_TOKEN")
        if not self.api_url or not self.api_token:
            raise ValueError("Missing HF_MISTRAL_URL or HF_TOKEN environment variables")

    def run(self, prompt: str, *, timeout: float = 120.0) -> str:
        """POST *prompt* to the configured endpoint and return the generated text.

        Args:
            prompt: Text sent as the ``inputs`` field of the JSON payload.
            timeout: Seconds to wait for the HTTP request before giving up.
                Without a timeout ``requests`` waits forever on a stalled
                connection, so a finite default is always applied.

        Returns:
            The ``generated_text`` value found in the JSON response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            requests.Timeout: If no response arrives within *timeout* seconds.
            ValueError: If the response JSON has no ``generated_text`` entry
                in one of the recognised shapes.
        """
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.7,
                "return_full_text": False,
            },
        }
        response = requests.post(
            self.api_url,
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return self._extract_generated_text(response.json())

    @staticmethod
    def _extract_generated_text(output) -> str:
        """Return ``generated_text`` from a decoded JSON response.

        Two shapes are accepted: a dict carrying ``generated_text``, or a
        list whose first element is such a dict. Anything else -- including
        an empty list or a list of non-dict items -- raises ``ValueError``
        rather than leaking ``IndexError``/``TypeError`` to the caller.
        """
        candidate = output
        if isinstance(output, list):
            # Guard the empty list before indexing.
            candidate = output[0] if output else None
        # Require a dict so that `in` is a key lookup, never a substring test.
        if isinstance(candidate, dict) and "generated_text" in candidate:
            return candidate["generated_text"]
        raise ValueError("Unexpected response format from Mistral endpoint")