FD900 committed on
Commit
0a31dab
·
verified ·
1 Parent(s): dd99f47

Update mistral_hf_wrapper.py

Browse files
Files changed (1) hide show
  1. mistral_hf_wrapper.py +16 -17
mistral_hf_wrapper.py CHANGED
@@ -1,21 +1,20 @@
1
  import os
2
  import requests
3
 
4
# Endpoint URL and access token are read once, at import time, from the
# environment.
API_URL = os.getenv("HF_MISTRAL_ENDPOINT")
API_TOKEN = os.getenv("HF_TOKEN")

# Shared request headers. The Authorization header is only added when a token
# is configured, so a missing HF_TOKEN does not send the literal "Bearer None".
headers = {"Content-Type": "application/json"}
if API_TOKEN:
    headers["Authorization"] = f"Bearer {API_TOKEN}"


def _extract_generated_text(data):
    """Return the ``generated_text`` field from a decoded JSON response.

    Accepts either a bare object ``{"generated_text": ...}`` or the same
    object wrapped in a list (the first element is used).

    Raises:
        ValueError: If *data* has neither of those shapes.
    """
    if isinstance(data, list) and data:
        data = data[0]
    if isinstance(data, dict) and "generated_text" in data:
        return data["generated_text"]
    raise ValueError(f"Unexpected response from Mistral endpoint: {data!r}")


def query_mistral(
    system_prompt: str,
    user_prompt: str,
    *,
    timeout: float = 120.0,
) -> str:
    """Query the Mistral model hosted on Hugging Face.

    The two prompts are stripped, joined with a blank line and wrapped in
    ``<s>[INST] ... [/INST]`` before being POSTed as ``{"inputs": prompt}``
    to ``API_URL``.

    Args:
        system_prompt: Instructions placed first inside the ``[INST]`` block.
        user_prompt: The user message placed after the system prompt.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``generated_text`` field of the response, stripped of
        surrounding whitespace.

    Raises:
        ValueError: If ``HF_MISTRAL_ENDPOINT`` was not set, or the response
            JSON does not contain ``generated_text``.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    # Fail early with a clear message instead of an opaque error from requests.
    if not API_URL:
        raise ValueError(
            "HF_MISTRAL_ENDPOINT is not set; cannot query the Mistral endpoint."
        )

    prompt = f"<s>[INST] {system_prompt.strip()}\n\n{user_prompt.strip()} [/INST]"
    response = requests.post(
        API_URL,
        headers=headers,
        json={"inputs": prompt},
        timeout=timeout,  # without a timeout the call can block indefinitely
    )
    response.raise_for_status()
    return _extract_generated_text(response.json()).strip()
 
1
  import os
2
  import requests
3
 
4
class MistralInference:
    """Small client for a Mistral text-generation endpoint.

    Configuration is read from the environment when the instance is created:
    ``HF_MISTRAL_URL`` supplies the endpoint URL and ``HF_TOKEN`` the bearer
    token.
    """

    def __init__(self):
        self.api_url = os.getenv("HF_MISTRAL_URL")
        self.api_token = os.getenv("HF_TOKEN")

    def run(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 512,
        timeout: float = 120.0,
    ) -> str:
        """Send *prompt* to the endpoint and return the generated text.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            max_new_tokens: Value sent as ``parameters.max_new_tokens``.
            timeout: Seconds to wait for the HTTP request before giving up.

        Returns:
            The ``generated_text`` field of the response, unmodified.

        Raises:
            ValueError: If no endpoint URL is configured, or the response
                JSON does not contain ``generated_text``.
            requests.HTTPError: If the endpoint answers with an error status.
        """
        # Fail early with a clear message instead of an opaque error from
        # requests when the environment variable was never set.
        if not self.api_url:
            raise ValueError(
                "HF_MISTRAL_URL is not set; cannot query the Mistral endpoint."
            )

        headers = {"Content-Type": "application/json"}
        # Only authenticate when a token exists; otherwise the header would
        # carry the literal string "Bearer None".
        if self.api_token:
            headers["Authorization"] = f"Bearer {self.api_token}"

        payload = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": max_new_tokens},
        }
        response = requests.post(
            self.api_url,
            headers=headers,
            json=payload,
            timeout=timeout,  # without a timeout the call can block indefinitely
        )
        response.raise_for_status()
        return self._extract_generated_text(response.json())

    @staticmethod
    def _extract_generated_text(data) -> str:
        """Return ``generated_text`` from a decoded JSON response.

        Accepts either a list whose first element is the result object, or
        the result object itself.

        Raises:
            ValueError: If *data* has neither of those shapes.
        """
        if isinstance(data, list) and data:
            data = data[0]
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"]
        raise ValueError(f"Unexpected response from Mistral endpoint: {data!r}")