Final_Assignment_Template / mistral_hf_wrapper.py
FD900's picture
Create mistral_hf_wrapper.py
877add2 verified
raw
history blame
1.02 kB
import requests
import os
class MistralInference:
    """Minimal client for a Hugging Face text-generation inference endpoint.

    Sends prompts to ``{api_url}/generate`` with a bearer token and returns
    the generated text.
    """

    def __init__(self, api_url: str, api_token: str, timeout: float = 60.0):
        """Configure the endpoint URL, auth headers, and request timeout.

        Args:
            api_url: Base URL of the inference endpoint; a trailing slash
                is stripped so ``/generate`` can be appended safely.
            api_token: Bearer token for the Authorization header.
            timeout: Per-request timeout in seconds (new, backward-compatible
                parameter; prevents a stalled endpoint from hanging forever).
        """
        self.api_url = api_url.rstrip("/")
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, temperature: float = 0.7, max_tokens: int = 512) -> str:
        """Generate a completion for *prompt* via the remote endpoint.

        Args:
            prompt: Input text sent as ``inputs``.
            temperature: Sampling temperature forwarded to the model.
            max_tokens: Maximum number of new tokens to generate.

        Returns:
            The generated text (prompt excluded, via ``return_full_text: False``).

        Raises:
            RuntimeError: If the endpoint responds with a non-200 status.
            requests.Timeout: If the request exceeds the configured timeout.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": temperature,
                "max_new_tokens": max_tokens,
                "return_full_text": False,
            },
        }
        # FIX: pass timeout= — without it requests waits indefinitely on a
        # hung connection (the requests docs recommend always setting one).
        response = requests.post(
            f"{self.api_url}/generate",
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(f"Request failed: {response.status_code} - {response.text}")
        data = response.json()
        # The HF Inference API may return either a dict {"generated_text": ...}
        # or a list of such dicts; handle both shapes.
        return data["generated_text"] if "generated_text" in data else data[0]["generated_text"]