Nav772 commited on
Commit
1cc35a6
·
verified ·
1 Parent(s): 13fdc21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -10,35 +10,43 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
- from transformers import pipeline
 
14
 
15
  class BasicAgent:
16
  def __init__(self):
17
- print("Loading Mistral-7B-Instruct-v0.1 using pipeline...")
18
-
19
- self.pipe = pipeline(
20
- "text-generation",
21
- model="mistralai/Mistral-7B-Instruct-v0.1",
22
- device=-1 # CPU only
23
- )
24
 
25
  def __call__(self, question: str) -> str:
26
- print(f"Received question: {question[:50]}...")
27
-
28
  prompt = f"<s>[INST] {question.strip()} [/INST]"
29
 
30
  try:
31
- output = self.pipe(
32
- prompt,
33
- max_new_tokens=256,
34
- temperature=0.7,
35
- top_p=0.95
36
  )
37
- full_response = output[0]["generated_text"]
38
- return full_response.split("[/INST]")[-1].strip()
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
- print(f"❌ Inference Error: {e}")
41
- return f"❌ Model Error: {str(e)}"
42
 
43
  def run_and_submit_all( profile: gr.OAuthProfile | None):
44
  """
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
+ import requests
14
+ import os
15
 
16
  class BasicAgent:
17
  def __init__(self):
18
+ print("Mistral Agent using Inference API initialized.")
19
+ self.token = os.getenv("HF_NEW_API_TOKEN")
20
+ self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
21
+ self.headers = {
22
+ "Authorization": f"Bearer {self.token}",
23
+ "Content-Type": "application/json"
24
+ }
25
 
26
  def __call__(self, question: str) -> str:
27
+ print(f"Sending question to API: {question[:50]}...")
 
28
  prompt = f"<s>[INST] {question.strip()} [/INST]"
29
 
30
  try:
31
+ response = requests.post(
32
+ self.api_url,
33
+ headers=self.headers,
34
+ json={"inputs": prompt, "parameters": {"max_new_tokens": 256, "temperature": 0.7}},
35
+ timeout=60
36
  )
37
+ response.raise_for_status()
38
+ output = response.json()
39
+
40
+ # Handle potential format differences
41
+ if isinstance(output, list) and "generated_text" in output[0]:
42
+ return output[0]["generated_text"].split("[/INST]")[-1].strip()
43
+ else:
44
+ print(f"⚠️ Unexpected response: {output}")
45
+ return "⚠️ Mistral returned an unexpected format."
46
+
47
  except Exception as e:
48
+ print(f"❌ Error during API call: {e}")
49
+ return f"❌ API Error: {str(e)}"
50
 
51
  def run_and_submit_all( profile: gr.OAuthProfile | None):
52
  """