# Source: Final_Assignment_Template / mistral_hf_wrapper.py (Hugging Face Hub)
# Last commit 4d9f4ff (verified) by FD900: "Update mistral_hf_wrapper.py" — 1.21 kB
# mistral_hf_wrapper.py
import os
import requests
class MistralInference:
    """Minimal client for a Hugging Face text-generation endpoint (Mistral).

    The endpoint URL and bearer token are read from the ``HF_MISTRAL_URL`` and
    ``HF_TOKEN`` environment variables when the instance is created.
    """

    def __init__(self, *, timeout: float = 120.0):
        """Load endpoint configuration from the environment.

        Args:
            timeout: Seconds ``requests`` waits to connect and between bytes
                of the response before raising a timeout error. Without a
                timeout a stalled endpoint would block the caller indefinitely.

        Raises:
            ValueError: If ``HF_MISTRAL_URL`` or ``HF_TOKEN`` is unset or empty.
        """
        self.api_url = os.getenv("HF_MISTRAL_URL")
        self.api_token = os.getenv("HF_TOKEN")
        self.timeout = timeout
        if not self.api_url or not self.api_token:
            raise ValueError("Missing HF_MISTRAL_URL or HF_TOKEN environment variables")

    def run(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Send ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: Text sent as the ``inputs`` field of the request body.
            max_new_tokens: Value for the ``max_new_tokens`` generation
                parameter (default 512).
            temperature: Value for the ``temperature`` generation parameter
                (default 0.7).

        Returns:
            The ``generated_text`` field of the endpoint's JSON response.

        Raises:
            requests.HTTPError: If the endpoint answers with a non-2xx status.
            requests.Timeout: If the endpoint does not respond within
                ``self.timeout`` seconds.
            ValueError: If the response body has no ``generated_text`` field.
        """
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.api_url,
            headers=headers,
            json=self._build_payload(prompt, max_new_tokens, temperature),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._extract_generated_text(response.json())

    @staticmethod
    def _build_payload(prompt: str, max_new_tokens: int, temperature: float) -> dict:
        """Build the JSON request body for a text-generation call."""
        return {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                # Request only the continuation, not the prompt echoed back.
                "return_full_text": False,
            },
        }

    @staticmethod
    def _extract_generated_text(output) -> str:
        """Return ``generated_text`` from a decoded endpoint response.

        Accepts either a non-empty list whose first item is a dict holding
        ``generated_text`` (only the first item is used) or a bare dict
        holding ``generated_text``.

        Raises:
            ValueError: For any other shape, including an empty list or a list
                whose first item is not a dict. The message includes a
                truncated repr of the payload to aid debugging.
        """
        # Guard the index so an empty list reaches the ValueError below
        # instead of raising IndexError.
        item = output[0] if isinstance(output, list) and output else output
        if isinstance(item, dict) and "generated_text" in item:
            return item["generated_text"]
        raise ValueError(
            f"Unexpected response format from Mistral endpoint: {output!r:.200}"
        )