Nav772 commited on
Commit
ee43ff0
·
verified ·
1 Parent(s): ee8cd34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -10,31 +10,35 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
- from transformers import pipeline
14
 
15
class BasicAgent:
    """Minimal local agent that answers free-form questions with FLAN-T5-small.

    Loads a seq2seq text2text-generation pipeline once at construction and
    answers each question via `__call__`. Inference errors are caught and
    returned as a string rather than raised, so a batch run never aborts.
    """

    def __init__(self):
        print("FLAN-T5-SMALL Local Agent initialized.")
        # text2text-generation is the correct task for encoder-decoder T5
        # models; device=-1 pins inference to CPU (no GPU assumed).
        self.pipeline = pipeline(
            "text2text-generation",
            model="google/flan-t5-small",
            tokenizer="google/flan-t5-small",
            device=-1,
        )

    def __call__(self, question: str) -> str:
        """Return the model's answer to *question*, or an error string on failure."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        try:
            prompt = f"Answer the following question:\n{question.strip()}"
            # BUG FIX: `temperature` only takes effect when sampling is on;
            # without do_sample=True transformers decodes greedily and the
            # temperature=0.5 setting was silently ignored.
            result = self.pipeline(
                prompt,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.5,
            )
            return result[0]["generated_text"].strip()
        except Exception as e:
            print(f"❌ Error during model inference: {e}")
            return f"❌ Model Error: {str(e)}"
37
 
 
38
  def run_and_submit_all( profile: gr.OAuthProfile | None):
39
  """
40
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
14
 
15
class BasicAgent:
    """Local agent that answers questions with Mistral-7B-Instruct-v0.1.

    Loads tokenizer + causal-LM weights once at construction and wraps them
    in a text-generation pipeline. Each call formats the question with the
    Mistral instruction template and strips the prompt echo from the output.
    Inference errors are caught and returned as a string rather than raised.
    """

    def __init__(self):
        print("Mistral Local Agent initialized.")

        model_id = "mistralai/Mistral-7B-Instruct-v0.1"

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        # NOTE(review): device=-1 runs a 7B model on CPU — expect very slow
        # inference and high RAM use; confirm this is intended for the Space.
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=-1,  # CPU
        )

    def __call__(self, question: str) -> str:
        """Return the model's answer to *question*, or an error string on failure."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        try:
            # Mistral-Instruct chat template: <s>[INST] ... [/INST]
            prompt = f"<s>[INST] {question.strip()} [/INST]"
            # BUG FIX: `temperature` only takes effect when sampling is on;
            # without do_sample=True transformers decodes greedily and the
            # temperature=0.7 setting was silently ignored.
            result = self.pipeline(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
            )
            # text-generation echoes the prompt; keep only what follows the
            # final [/INST] marker.
            return result[0]["generated_text"].split("[/INST]")[-1].strip()
        except Exception as e:
            print(f"❌ Error during model inference: {e}")
            return f"❌ Model Error: {str(e)}"
40
 
41
+
42
  def run_and_submit_all( profile: gr.OAuthProfile | None):
43
  """
44
  Fetches all questions, runs the BasicAgent on them, submits all answers,