Nav772 committed on
Commit
a72e2e3
·
verified ·
1 Parent(s): 8e2229a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -54
app.py CHANGED
@@ -11,65 +11,30 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
  # ---------- MODIFICATIONS BEGIN ----------
14
- from transformers import AutoTokenizer, AutoModelForCausalLM
15
- import torch
16
- import os
17
- import re
18
 
19
class BasicAgent:
    """Hybrid agent that routes a question to one of three handlers.

    Routing order: media questions get a fixed refusal, a single known
    logic pattern gets a canned answer, and everything else is sent to a
    locally loaded Mistral-7B-Instruct model on CPU.
    """

    # Substrings that flag a question as media- or logic-flavoured.
    _MEDIA_MARKERS = ("youtube", ".mp3", "image", "video", "attached", ".wav")
    _LOGIC_MARKERS = ("|", "*", "subset", "commutative", "table", "=")

    def __init__(self):
        """Load the Mistral tokenizer and model onto CPU in eval mode."""
        print("Hybrid Agent with Mistral Model Initialized")

        model_id = "mistralai/Mistral-7B-Instruct-v0.1"
        hf_token = os.getenv("HF_NEW_API_TOKEN")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
        self.model.to("cpu")
        self.model.eval()

    def classify(self, question: str) -> str:
        """Return 'media', 'logic', or 'mistral' for *question*.

        Matching is case-insensitive substring search; media markers are
        checked before logic markers, mirroring the routing priority.
        """
        lowered = question.lower()
        for marker in self._MEDIA_MARKERS:
            if marker in lowered:
                return "media"
        for marker in self._LOGIC_MARKERS:
            if marker in lowered:
                return "logic"
        return "mistral"

    def handle_media(self, q: str) -> str:
        """Media/file questions are out of scope; return a fixed refusal."""
        return "I'm unable to process audio, video, or file-based questions."

    def handle_logic(self, q: str) -> str:
        """Answer the one known logic pattern; otherwise admit defeat."""
        lowered = q.lower()
        is_known_pattern = "not commutative" in lowered and "subset" in lowered
        if is_known_pattern:
            return "a,b,c"
        return "I couldn't solve this logic-based question."

    def handle_mistral(self, question: str) -> str:
        """Generate an answer with the local Mistral model.

        Wraps the question in Mistral's [INST] chat template and returns
        only the text after the final [/INST] tag, stripped.
        """
        prompt = f"<s>[INST] {question.strip()} [/INST]"
        encoded = self.tokenizer(prompt, return_tensors="pt").to("cpu")

        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )

        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.split("[/INST]")[-1].strip()

    def __call__(self, question: str) -> str:
        """Classify *question* and dispatch it to the matching handler."""
        qtype = self.classify(question)
        print(f"Classified as: {qtype}")

        handlers = {"media": self.handle_media, "logic": self.handle_logic}
        return handlers.get(qtype, self.handle_mistral)(question)
72
- # ---------- MODIFICATIONS END ----------
73
 
74
  def run_and_submit_all( profile: gr.OAuthProfile | None):
75
  """
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
  # ---------- MODIFICATIONS BEGIN ----------
14
+ from transformers import pipeline
 
 
 
15
 
16
class BasicAgent:
    """Minimal local agent that answers questions with google/flan-t5-base.

    The model runs on CPU via a transformers ``text2text-generation``
    pipeline. Inference errors are caught at this boundary and returned
    as a string so the surrounding scoring loop never crashes.
    """

    def __init__(self):
        print("FLAN-T5-BASE Local Agent initialized.")

        # Load once at startup; flan-t5-base is small enough for CPU use.
        self.pipeline = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            tokenizer="google/flan-t5-base",
            device=-1,  # -1 forces CPU
        )

    def __call__(self, question: str) -> str:
        """Return the model's answer to *question*, or an error string.

        The answer is the pipeline's generated text, stripped of leading
        and trailing whitespace.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        try:
            prompt = f"Answer the following question:\n{question.strip()}"
            # FIX: the previous version passed temperature=0.5, but the
            # pipeline defaults to greedy decoding (do_sample=False), so
            # the temperature was silently ignored. Dropped it to keep
            # decoding explicit and deterministic.
            result = self.pipeline(prompt, max_new_tokens=128)
            answer = result[0]["generated_text"]
            return answer.strip()
        except Exception as e:
            # Boundary catch: report the failure instead of crashing the
            # submission loop; the error text becomes the answer.
            print(f"❌ Error during model inference: {e}")
            return f"❌ Model Error: {str(e)}"
38
 
39
  def run_and_submit_all( profile: gr.OAuthProfile | None):
40
  """