FD900 commited on
Commit
4d9f4ff
·
verified ·
1 Parent(s): 0a31dab

Update mistral_hf_wrapper.py

Browse files
Files changed (1) hide show
  1. mistral_hf_wrapper.py +19 -2
mistral_hf_wrapper.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import requests
3
 
@@ -5,16 +7,31 @@ class MistralInference:
5
  def __init__(self):
6
  self.api_url = os.getenv("HF_MISTRAL_URL")
7
  self.api_token = os.getenv("HF_TOKEN")
 
 
8
 
9
  def run(self, prompt: str) -> str:
10
  headers = {
11
  "Authorization": f"Bearer {self.api_token}",
12
  "Content-Type": "application/json"
13
  }
 
14
  payload = {
15
  "inputs": prompt,
16
- "parameters": {"max_new_tokens": 512}
 
 
 
 
17
  }
 
18
  response = requests.post(self.api_url, headers=headers, json=payload)
19
  response.raise_for_status()
20
- return response.json()[0]["generated_text"]
 
 
 
 
 
 
 
 
1
+ # mistral_hf_wrapper.py
2
+
3
  import os
4
  import requests
5
 
 
7
  def __init__(self):
8
  self.api_url = os.getenv("HF_MISTRAL_URL")
9
  self.api_token = os.getenv("HF_TOKEN")
10
+ if not self.api_url or not self.api_token:
11
+ raise ValueError("Missing HF_MISTRAL_URL or HF_TOKEN environment variables")
12
 
13
  def run(self, prompt: str) -> str:
14
  headers = {
15
  "Authorization": f"Bearer {self.api_token}",
16
  "Content-Type": "application/json"
17
  }
18
+
19
  payload = {
20
  "inputs": prompt,
21
+ "parameters": {
22
+ "max_new_tokens": 512,
23
+ "temperature": 0.7,
24
+ "return_full_text": False
25
+ }
26
  }
27
+
28
  response = requests.post(self.api_url, headers=headers, json=payload)
29
  response.raise_for_status()
30
+
31
+ output = response.json()
32
+ if isinstance(output, list) and "generated_text" in output[0]:
33
+ return output[0]["generated_text"]
34
+ elif isinstance(output, dict) and "generated_text" in output:
35
+ return output["generated_text"]
36
+ else:
37
+ raise ValueError("Unexpected response format from Mistral endpoint")