Nav772 committed on
Commit
6105dbe
·
verified ·
1 Parent(s): 1cc35a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -27
app.py CHANGED
@@ -10,43 +10,49 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
- import requests
14
- import os
15
 
16
class BasicAgent:
    """Thin client for the Mistral-7B-Instruct hosted Inference API.

    Sends the question as a single-turn [INST] prompt and returns the
    generated answer text, or a human-readable error string on failure.
    """

    def __init__(self):
        print("Mistral Agent using Inference API initialized.")
        # Token may be None if the env var is unset — the request will then
        # fail with 401 and be reported by __call__ rather than crash here.
        self.token = os.getenv("HF_NEW_API_TOKEN")
        self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

    def __call__(self, question: str) -> str:
        """Send *question* to the hosted model and return the answer text."""
        print(f"Sending question to API: {question[:50]}...")
        # Mistral-instruct chat template: wrap the user turn in [INST] tags.
        prompt = f"<s>[INST] {question.strip()} [/INST]"

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json={"inputs": prompt, "parameters": {"max_new_tokens": 256, "temperature": 0.7}},
                timeout=60,
            )
            response.raise_for_status()
            output = response.json()

            # Fix: guard against an empty list or a non-dict first element
            # before indexing — the original `output[0]` could raise on
            # error payloads such as {"error": ...} or [] and the failure
            # was then misreported as a generic API exception.
            if (
                isinstance(output, list)
                and output
                and isinstance(output[0], dict)
                and "generated_text" in output[0]
            ):
                # The API echoes the prompt; keep only text after [/INST].
                return output[0]["generated_text"].split("[/INST]")[-1].strip()

            print(f"⚠️ Unexpected response: {output}")
            return "⚠️ Mistral returned an unexpected format."

        except Exception as e:
            print(f"❌ Error during API call: {e}")
            return f"❌ API Error: {str(e)}"
50
 
51
  def run_and_submit_all( profile: gr.OAuthProfile | None):
52
  """
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ import torch
15
 
16
class BasicAgent:
    """Runs Mistral-7B-Instruct locally on CPU via a manual generate() call.

    Loads tokenizer + model once at construction; __call__ formats the
    question as a single-turn [INST] prompt, samples up to 256 new tokens,
    and returns the decoded answer (or an error string on failure).
    """

    def __init__(self):
        print("Loading Mistral with manual generate()...")

        model_id = "mistralai/Mistral-7B-Instruct-v0.1"

        # Gated model → an HF access token is needed for download.
        # NOTE(review): this relies on `os` being imported at module level
        # elsewhere in app.py — confirm, since this diff removed `import os`.
        token = os.getenv("HF_NEW_API_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, token=token)

        # CPU-only inference.
        self.model.to("cpu")
        self.model.eval()

    def __call__(self, question: str) -> str:
        """Generate an answer for *question* and return it as plain text."""
        # Mistral-instruct chat template: wrap the user turn in [INST] tags.
        prompt = f"<s>[INST] {question.strip()} [/INST]"

        try:
            # Fix: keep the full tokenizer output (input_ids AND
            # attention_mask) and forward both to generate(). The original
            # passed input_ids alone, which triggers the "attention mask is
            # not set" warning and can change sampling behaviour.
            inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                )

            # Fix: decode only the newly generated token slice instead of
            # splitting the full decode on "[/INST]" — the split is fragile
            # if the model happens to emit that tag inside its answer.
            prompt_len = inputs["input_ids"].shape[1]
            answer = self.tokenizer.decode(
                generated_ids[0][prompt_len:], skip_special_tokens=True
            ).strip()
            return answer

        except Exception as e:
            print(f"❌ Error during generation: {e}")
            return f"❌ Model Error: {str(e)}"
56
 
57
  def run_and_submit_all( profile: gr.OAuthProfile | None):
58
  """