victor-johnson committed on
Commit
1518270
·
verified ·
1 Parent(s): 81917a3

add agent logic

Browse files
Files changed (1) hide show
  1. app.py +83 -5
app.py CHANGED
@@ -10,14 +10,92 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
 
 
 
 
 
13
class BasicAgent:
    """Trivial placeholder agent: ignores the question and returns a canned reply."""

    def __init__(self):
        # No state to set up; just announce construction.
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Log a preview of *question* and return the fixed default answer."""
        preview = question[:50]
        print(f"Agent received question (first 50 chars): {preview}...")
        fixed_answer = "This is a default answer."
        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
+ # --- Enhanced Agent (replace the old BasicAgent with this) ---
14
+ import re
15
+ import json
16
+ import textwrap
17
+
18
class BasicAgent:
    """
    Calls a small LLM on the Hugging Face Inference API and post-processes the
    output to return a concise, exact-match-friendly answer.

    Reads the HF API token from the HF Space secret ``HF_TOKEN``. Any failure
    (network, HTTP error, unexpected response shape) yields "N/A" rather than
    raising, so the submission loop never crashes on a single question.
    """

    def __init__(self):
        # Text-generation-inference endpoint; swap in another Hub model if desired.
        self.model_url = (
            "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
        )
        self.headers = {
            "Authorization": f"Bearer {os.getenv('HF_TOKEN', '')}",
            "Content-Type": "application/json",
        }
        if not os.getenv("HF_TOKEN"):
            print("⚠️ HF_TOKEN not set — requests to the Inference API will fail.")
        print("🚀 Enhanced BasicAgent initialized.")

    def _clean(self, raw: str) -> str:
        """
        Post-process the model output to a short, exact value:
        - Keep only the final non-empty line (handles prompt echo).
        - Remove common answer prefixes ("Final answer:", "Answer -", ...).
        - Strip quotes/backticks and trailing punctuation.
        - Truncate to 200 characters.
        """
        txt = raw.strip()

        # If the model echoed the prompt, take the last non-empty line.
        lines = [l.strip() for l in txt.splitlines() if l.strip()]
        if lines:
            txt = lines[-1]

        # Remove common prefixes.
        txt = re.sub(r"^(final answer|answer|prediction)\s*[:\-]\s*", "", txt, flags=re.I)

        # Strip code fences/quotes and trailing punctuation.
        txt = txt.strip("`'\" \t\n\r")
        txt = re.sub(r"[ \t]*[.;,:-]+$", "", txt)

        # Keep it short (exact-match tasks usually want a short string/number).
        return txt[:200]

    def __call__(self, question: str) -> str:
        """Answer *question* via the Inference API; returns "N/A" on any failure."""
        print(f"🧠 Agent received question: {question[:120]}...")
        # A very simple, direct prompt to reduce rambling.
        # BUGFIX: dedent the template BEFORE interpolating the question —
        # dedent(f"...") computes the common indent over the interpolated text,
        # so a multi-line question with an unindented line used to defeat it
        # and leave the whole prompt indented.
        template = textwrap.dedent("""
            You must answer the question with a single, concise value
            (number, word, date, short phrase) and nothing else.

            Question: {q}
            Final answer:
        """).strip()
        prompt = template.format(q=question)

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 96,
                "temperature": 0.2,
                "return_full_text": False,
            },
            "options": {"wait_for_model": True},
        }

        try:
            # json= lets requests serialize the payload itself (equivalent to
            # data=json.dumps(payload) but without the manual round-trip).
            r = requests.post(self.model_url, headers=self.headers, json=payload, timeout=60)
            r.raise_for_status()
            data = r.json()
            # Inference API (text-generation) usually returns [{"generated_text": ...}].
            if isinstance(data, list) and data and "generated_text" in data[0]:
                generated = data[0]["generated_text"]
            else:
                # Fallback: stringify whatever came back.
                generated = str(data)
            answer = self._clean(generated)
            print(f"✅ Cleaned answer: {answer}")
            # Ensure we never submit an empty string (server rejects).
            return answer if answer else "N/A"
        except Exception as e:  # boundary handler: map any failure to "N/A"
            print(f"❌ Error generating answer: {e}")
            return "N/A"
98
+
99
 
100
  def run_and_submit_all( profile: gr.OAuthProfile | None):
101
  """