meghkaa commited on
Commit
7d7f1b9
·
verified ·
1 Parent(s): 81917a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -5
app.py CHANGED
@@ -3,6 +3,9 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -10,14 +13,81 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
 
13
  class BasicAgent:
14
  def __init__(self):
15
- print("BasicAgent initialized.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def __call__(self, question: str) -> str:
17
- print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from transformers import pipeline
7
+ import re
8
+ import math
9
 
10
  # (Keep Constants as is)
11
  # --- Constants ---
 
13
 
14
  # --- Basic Agent Definition ---
15
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
16
+
17
  class BasicAgent:
18
  def __init__(self):
19
+ print("Loading stronger instruction model...")
20
+ self.generator = pipeline(
21
+ "text-generation",
22
+ model="mistralai/Mistral-7B-Instruct-v0.2",
23
+ max_new_tokens=256,
24
+ do_sample=False,
25
+ )
26
+ print("Model loaded.")
27
+
28
+ # ---- Simple math tool ----
29
+ def try_math(self, question: str):
30
+ try:
31
+ # Detect simple arithmetic expressions
32
+ expression = re.findall(r"[\d\.\+\-\*\/\(\) ]+", question)
33
+ if expression:
34
+ candidate = expression[0]
35
+ result = eval(candidate)
36
+ return str(result)
37
+ except:
38
+ pass
39
+ return None
40
+
41
+ # ---- Clean output for EXACT MATCH ----
42
+ def clean_answer(self, text: str) -> str:
43
+ text = text.strip()
44
+
45
+ if "Answer:" in text:
46
+ text = text.split("Answer:")[-1]
47
+
48
+ text = text.split("\n")[0]
49
+ text = text.strip()
50
+
51
+ # Remove trailing punctuation
52
+ text = re.sub(r"[\.]$", "", text)
53
+
54
+ return text.strip()
55
+
56
+ def ask_model(self, question: str):
57
+ prompt = f"""
58
+ You are solving a benchmark evaluation problem.
59
+
60
+ Think step by step internally.
61
+ But output ONLY the final answer.
62
+ Do NOT explain.
63
+ Do NOT add extra words.
64
+
65
+ Question:
66
+ {question}
67
+
68
+ Final Answer:
69
+ """
70
+ output = self.generator(prompt)[0]["generated_text"]
71
+ answer = output.replace(prompt, "")
72
+ return self.clean_answer(answer)
73
+
74
  def __call__(self, question: str) -> str:
75
+ print(f"Processing question: {question[:60]}...")
76
+
77
+ math_result = self.try_math(question)
78
+ if math_result:
79
+ print("Used math tool.")
80
+ return math_result
81
+
82
+ answer = self.ask_model(question)
83
+
84
+ if len(answer.split()) > 6:
85
+ print("Retrying due to long answer...")
86
+ answer = self.ask_model(question)
87
+
88
+ print(f"Final Answer: {answer}")
89
+ return answer
90
+
91
 
92
  def run_and_submit_all( profile: gr.OAuthProfile | None):
93
  """