RalphThings commited on
Commit
e944048
·
verified ·
1 Parent(s): b70ff0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -171
app.py CHANGED
@@ -1,12 +1,14 @@
1
  from transformers import pipeline
2
  import os
3
  import re
4
- import json
5
  import torch
6
  import gradio as gr
7
  import requests
8
  import inspect
9
  import pandas as pd
 
 
 
10
  from youtube_transcript_api import YouTubeTranscriptApi
11
  import chess, chess.engine
12
  from bs4 import BeautifulSoup
@@ -18,189 +20,218 @@ from SPARQLWrapper import SPARQLWrapper, JSON
18
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
  HF_TOKEN = os.getenv("HF_TOKEN", None)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # --- Basic Agent Definition ---
22
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
23
  class BasicAgent:
24
- WIKI_API = "https://en.wikipedia.org/w/api.php"
25
- VEGETABLE_SET = {
26
- "bell pepper","broccoli","celery","green beans",
27
- "lettuce","zucchini","sweet potatoes"
28
- }
29
-
30
  def __init__(self):
31
  # initialize HF inference pipeline once
32
  if HF_TOKEN is None:
33
  raise ValueError("HF_TOKEN not set in environment")
34
- self.generator = pipeline("text-generation", model="EleutherAI/gpt-neo-125M")
 
 
35
  # The GAIA system prompt (no "FINAL ANSWER:" at the end)
36
- self.system_prompt = (
37
- "You are a concise AI assistant. "
38
- "Answer in as few words as possible—a number, a few words, or a comma-separated list. "
39
- "No commentary, prefixes, or units.\n\n"
40
- )
41
  print("BasicAgent initialized with LLM.")
42
- # Stockfish location—adjust path if needed
43
- self.stockfish_path = "/usr/bin/stockfish"
44
-
45
- # --- Tool 1: Wikipedia raw wikitext fetch ---
46
- def wiki_get_page(self, title: str) -> str:
47
- params = {
48
- "action": "query","format": "json",
49
- "prop": "revisions","rvprop": "content","rvslots": "*",
50
- "titles": title
51
- }
52
- r = requests.get(self.WIKI_API, params=params, timeout=10)
53
- pages = r.json()["query"]["pages"]
54
- page = next(iter(pages.values()))
55
- return page["revisions"][0]["slots"]["main"]["*"]
56
-
57
- # --- Tool 2: YouTube transcript ---
58
- def youtube_transcript(self, video_id: str) -> str:
59
- transcript = YouTubeTranscriptApi().fetch_transcript(video_id)
60
- return " ".join(t["text"] for t in transcript)
61
-
62
- # --- Tool 3: reverse text ---
63
- def reverse_text(self, text: str) -> str:
64
- return text[::-1]
65
-
66
- # --- Tool 4: Chess best move via Stockfish ---
67
- def chess_best_move(self, fen: str, time_limit: float = 0.1) -> str:
68
- board = chess.Board(fen)
69
- engine = chess.engine.SimpleEngine.popen_uci(self.stockfish_path)
70
- result = engine.play(board, chess.engine.Limit(time=time_limit))
71
- engine.quit()
72
- return result.move.uci()
73
-
74
- # --- Tool 5: Table non-commutativity ---
75
- def find_non_commutative(self, table: dict) -> list:
76
- elems = set(x for x,_ in table.keys())
77
- bad = set()
78
- for x in elems:
79
- for y in elems:
80
- if table[(x,y)] != table[(y,x)]:
81
- bad.update([x,y])
82
- return sorted(bad)
83
-
84
- # --- Tool 6: LibreTexts scraping (generic) ---
85
- def libretext_extract(self, url: str, selector: str) -> str:
86
- r = requests.get(url, timeout=10)
87
- soup = BeautifulSoup(r.text, "html.parser")
88
- return soup.select_one(selector).get_text(strip=True)
89
-
90
- # --- Tool 7: Grocery vegetable classifier ---
91
- def classify_vegetables(self, items: list[str]) -> list[str]:
92
- vegs = [i for i in items if i in self.VEGETABLE_SET]
93
- return sorted(vegs)
94
-
95
- # --- Tool 8: Audio transcription via AssemblyAI ---
96
- def transcribe_audio(self, audio_url: str) -> str:
97
- transcriber = aai.Transcriber()
98
- result = transcriber.transcribe(audio_url)
99
- return result.text
100
-
101
- # --- Tool 9: Actor role lookup (stub—for you to flesh out) ---
102
- def actor_role(self, title: str, role_name: str, target_series: str) -> str:
103
- # TODO: implement via OMDb/IMDbPy
104
- return "UNKNOWN"
105
-
106
- # --- Tool 10: Sandbox code execution ---
107
- def execute_code(self, code: str) -> str:
108
- local_ns = {}
109
- exec(code, {"__builtins__": {}}, local_ns)
110
- # assume user sets 'output' variable
111
- return str(local_ns.get("output", ""))
112
-
113
- # --- Tool 11: Baseball stats via statsapi ---
114
- def yankee_at_bats_most_walks(self, year: int) -> int:
115
- leaders = statsapi.team_leaders("walks", season=year, team=147) # Yankees=147
116
- pid = leaders[0]["id"]
117
- stats = statsapi.player_stats(pid, "hitting", "season", season=year)
118
- return stats["batting"][0]["atBats"]
119
-
120
- # --- Tool 12: Olympics data scraping ---
121
- def least_athletes_olympics(self, year: int) -> str:
122
- url = f"https://en.wikipedia.org/wiki/{year}_Summer_Olympics"
123
- r = requests.get(url); soup = BeautifulSoup(r.text,"html.parser")
124
- # naive: look for first table with nation counts...
125
- table = soup.find("table","wikitable")
126
- rows = table.find_all("tr")[1:]
127
- data = [(r.find_all("td")[0].get_text(strip=True),
128
- int(r.find_all("td")[1].get_text(strip=True)))
129
- for r in rows]
130
- min_val = min(c for _,c in data)
131
- candidates = sorted([code for code,count in data if count==min_val])
132
- return candidates[0]
133
-
134
- # --- Tool 13: Wikidata SPARQL for NASA awards ---
135
- def get_nasa_award_number(self, qid: str) -> str:
136
- sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
137
- sparql.setQuery(f"""
138
- SELECT ?award WHERE {{
139
- wd:{qid} wdt:P496 ?award.
140
- }}
141
- """)
142
- sparql.setReturnFormat(JSON)
143
- res = sparql.query().convert()
144
- return res["results"]["bindings"][0]["award"]["value"]
145
 
146
  # --- Core dispatcher/fallback ---
147
  def __call__(self, question: str) -> str:
148
- q = question.strip()
149
-
150
- # 1) studio albums by Mercedes Sosa 2000–2009
151
- if "Mercedes Sosa" in q and "studio albums" in q:
152
- text = self.wiki_get_page("Mercedes Sosa discography")
153
- years = re.findall(r"\b(20\d\d)\b", text)
154
- # count entries between 2000 and 2009
155
- return str(sum(1 for y in years if 2000 <= int(y) <= 2009))
156
-
157
- # 2) YouTube species count
158
- m = re.search(r"youtube\.com/watch\?v=([A-Za-z0-9_\-]+)", q)
159
- if m and "bird species" in q:
160
- transcript = self.youtube_transcript(m.group(1))
161
- nums = [int(n) for n in re.findall(r"(\d+)\s+species", transcript)]
162
- return str(max(nums) if nums else 0)
163
-
164
- # 3) reversed-text puzzles
165
- if q.startswith((".",'"')) and "dnatsrednu" in q:
166
- inner = q.strip('"').strip()[::-1]
167
- # extract the core sentence
168
- return inner
169
-
170
- # 4) chess win move (FEN)
171
- if "Review the chess position" in q:
172
- # user would have attached FEN in question_data["files"], but here we default example
173
- fen = "..." # TODO: extract from files
174
- return self.chess_best_move(fen)
175
-
176
- # 5) operation table non-commutativity
177
- if "counter-examples" in q:
178
- # assume question_data carries a JSON-able table under item["table"]
179
- table = json.loads(question_data.get("table_json","{}"))
180
- bad = self.find_non_commutative(table)
181
- return ",".join(bad)
182
-
183
- # 6) grocery list vegetables
184
- if "grocery list" in q and "vegetables" in q:
185
- items = re.findall(r"\b[\w\s]+(?=,|$)", q)
186
- vegs = self.classify_vegetables([i.strip() for i in items])
187
- return ",".join(vegs)
188
-
189
- # 7) transcript-based page numbers or ingredients
190
- if q.lower().startswith("i was out sick") or "strawberry pie.mp3" in q:
191
- # use URL or path from item["files"]
192
- audio_url = question_data.get("audio_url")
193
- text = self.transcribe_audio(audio_url)
194
- # depends: page numbers or ingredients
195
- nums = sorted(set(re.findall(r"\b(\d+)\b", text)), key=int)
196
- return ",".join(nums)
197
-
198
- # ... extend further for other tools ...
199
-
200
- # fallback to LLM
201
  prompt = f"{self.system_prompt}Q: {q}\nA:"
202
- out = self.generator(prompt, max_new_tokens=16, return_full_text=False)
203
- return out[0]["generated_text"].strip()
 
 
 
204
 
205
  def run_and_submit_all( profile: gr.OAuthProfile | None):
206
  """
 
1
  from transformers import pipeline
2
  import os
3
  import re
 
4
  import torch
5
  import gradio as gr
6
  import requests
7
  import inspect
8
  import pandas as pd
9
+ from langchain_huggingface.llms import HuggingFacePipeline
10
+ from langchain_core.tools import tool
11
+ from langchain_core.agents import AgentExecutor, JsonOutputParser
12
  from youtube_transcript_api import YouTubeTranscriptApi
13
  import chess, chess.engine
14
  from bs4 import BeautifulSoup
 
20
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
21
  HF_TOKEN = os.getenv("HF_TOKEN", None)
22
 
23
+ @tool(
24
+ name="wiki_get_page",
25
+ description="Fetch raw wikitext for a given Wikipedia page title",
26
+ inputs={"title": "string"},
27
+ output_type="string",
28
+ )
29
+ def wiki_get_page(title: str) -> str:
30
+ API = "https://en.wikipedia.org/w/api.php"
31
+ params = {"action": "query", "format": "json", "prop": "revisions", "rvprop": "content", "rvslots": "*", "titles": title}
32
+ data = requests.get(API, params=params, timeout=10).json()
33
+ page = next(iter(data["query"]["pages"].values()))
34
+ return page["revisions"][0]["slots"]["main"]["*"]
35
+
36
+ @tool(
37
+ name="youtube_transcript",
38
+ description="Retrieve transcript for a YouTube video ID",
39
+ inputs={"video_id": "string"},
40
+ output_type="string",
41
+ )
42
+ def youtube_transcript(video_id: str) -> str:
43
+ transcript = YouTubeTranscriptApi().fetch_transcript(video_id)
44
+ return " ".join(t["text"] for t in transcript)
45
+
46
+ @tool(
47
+ name="reverse_text",
48
+ description="Reverse the input string",
49
+ inputs={"text": "string"},
50
+ output_type="string",
51
+ )
52
+ def reverse_text(text: str) -> str:
53
+ return text[::-1]
54
+
55
+ @tool(
56
+ name="chess_best_move",
57
+ description="Return best move in UCI notation for given FEN",
58
+ inputs={"fen": "string", "time_limit": "float"},
59
+ output_type="string",
60
+ )
61
+ def chess_best_move(fen: str, time_limit: float = 0.1) -> str:
62
+ board = chess.Board(fen)
63
+ engine = chess.engine.SimpleEngine.popen_uci("/usr/bin/stockfish")
64
+ result = engine.play(board, chess.engine.Limit(time=time_limit))
65
+ engine.quit()
66
+ return result.move.uci()
67
+
68
+ @tool(
69
+ name="find_non_commutative",
70
+ description="Find elements involved in non-commutativity from operation table",
71
+ inputs={"table": "dict"},
72
+ output_type="list[string]",
73
+ )
74
+ def find_non_commutative(table: dict) -> list:
75
+ elems = set(x for x,_ in table.keys())
76
+ bad = set()
77
+ for x in elems:
78
+ for y in elems:
79
+ if table[(x,y)] != table[(y,x)]:
80
+ bad.update([x,y])
81
+ return sorted(bad)
82
+
83
+ @tool(
84
+ name="libretext_extract",
85
+ description="Extract text from LibreTexts URL using CSS selector",
86
+ inputs={"url": "string", "selector": "string"},
87
+ output_type="string",
88
+ )
89
+ def libretext_extract(url: str, selector: str) -> str:
90
+ r = requests.get(url, timeout=10)
91
+ soup = BeautifulSoup(r.text, "html.parser")
92
+ return soup.select_one(selector).get_text(strip=True)
93
+
94
+ @tool(
95
+ name="classify_vegetables",
96
+ description="Return alphabetized list of vegetables from input list",
97
+ inputs={"items": "list[string]"},
98
+ output_type="list[string]",
99
+ )
100
+ def classify_vegetables(items: list) -> list:
101
+ VEGETABLE_SET = {"bell pepper","broccoli","celery","green beans","lettuce","zucchini","sweet potatoes"}
102
+ return sorted([i for i in items if i in VEGETABLE_SET])
103
+
104
+ @tool(
105
+ name="transcribe_audio",
106
+ description="Transcribe audio file or URL using AssemblyAI",
107
+ inputs={"audio_url": "string"},
108
+ output_type="string",
109
+ )
110
+ def transcribe_audio(audio_url: str) -> str:
111
+ transcriber = aai.Transcriber()
112
+ result = transcriber.transcribe(audio_url)
113
+ return result.text
114
+
115
+ @tool(
116
+ name="actor_role",
117
+ description="Lookup actor role via OMDb API (stub implementation)",
118
+ inputs={"title": "string", "role_name": "string", "target_series": "string"},
119
+ output_type="string",
120
+ )
121
+ def actor_role(title: str, role_name: str, target_series: str) -> str:
122
+ return "UNKNOWN"
123
+
124
+ @tool(
125
+ name="execute_code",
126
+ description="Execute Python code snippet and return 'output' variable",
127
+ inputs={"code": "string"},
128
+ output_type="string",
129
+ )
130
+ def execute_code(code: str) -> str:
131
+ local_ns = {}
132
+ exec(code, {"__builtins__": {}}, local_ns)
133
+ return str(local_ns.get("output", ""))
134
+
135
+ @tool(
136
+ name="yankee_at_bats_most_walks",
137
+ description="Return at bats for Yankee with most walks in given season",
138
+ inputs={"year": "int"},
139
+ output_type="int",
140
+ )
141
+ def yankee_at_bats_most_walks(year: int) -> int:
142
+ leaders = statsapi.team_leaders("walks", season=year, team=147)
143
+ pid = leaders[0]["id"]
144
+ stats = statsapi.player_stats(pid, "hitting", "season", season=year)
145
+ return stats["batting"][0]["atBats"]
146
+
147
+ @tool(
148
+ name="least_athletes_olympics",
149
+ description="Return IOC code of country with least athletes in given Olympics year",
150
+ inputs={"year": "int"},
151
+ output_type="string",
152
+ )
153
+ def least_athletes_olympics(year: int) -> str:
154
+ url = f"https://en.wikipedia.org/wiki/{year}_Summer_Olympics"
155
+ r = requests.get(url)
156
+ soup = BeautifulSoup(r.text,"html.parser")
157
+ table = soup.find("table","wikitable")
158
+ rows = table.find_all("tr")[1:]
159
+ data = [(r.find_all("td")[0].get_text(strip=True), int(r.find_all("td")[1].get_text(strip=True))) for r in rows]
160
+ min_val = min(c for _,c in data)
161
+ candidates = sorted([code for code,count in data if count==min_val])
162
+ return candidates[0]
163
+
164
+ @tool(
165
+ name="get_nasa_award_number",
166
+ description="Get NASA award number for a Wikidata QID",
167
+ inputs={"qid": "string"},
168
+ output_type="string",
169
+ )
170
+ def get_nasa_award_number(qid: str) -> str:
171
+ sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
172
+ sparql.setQuery(f'SELECT ?award WHERE {{ wd:{qid} wdt:P496 ?award. }}')
173
+ sparql.setReturnFormat(JSON)
174
+ res = sparql.query().convert()
175
+ return res["results"]["bindings"][0]["award"]["value"]
176
+
177
+ TOOLS = [
178
+ wiki_get_page,
179
+ youtube_transcript,
180
+ reverse_text,
181
+ chess_best_move,
182
+ find_non_commutative,
183
+ libretext_extract,
184
+ classify_vegetables,
185
+ transcribe_audio,
186
+ actor_role,
187
+ execute_code,
188
+ yankee_at_bats_most_walks,
189
+ least_athletes_olympics,
190
+ get_nasa_award_number,
191
+ ]
192
+
193
+ SYSTEM_MESSAGE = """You are a concise AI assistant with access to the following tools:
194
+ - wiki_get_page(title: string) → string
195
+ - youtube_transcript(video_id: string) → string
196
+ - reverse_text(text: string) → string
197
+ - chess_best_move(fen: string, time_limit: float) → string
198
+ - find_non_commutative(table: dict) → list[string]
199
+ - libretext_extract(url: string, selector: string) → string
200
+ - classify_vegetables(items: list[string]) → list[string]
201
+ - transcribe_audio(audio_url: string) → string
202
+ - actor_role(title: string, role_name: string, target_series: string) → string
203
+ - execute_code(code: string) → string
204
+ - yankee_at_bats_most_walks(year: int) → int
205
+ - least_athletes_olympics(year: int) → string
206
+ - get_nasa_award_number(qid: string) → string
207
+ When you need to use a tool, respond exactly with:
208
+ Action: <tool_name>(<arg_name>=<value>, ...)
209
+ Then wait for the tool’s output before continuing.
210
+ Once you have all the information, provide your final answer in as few words as possible, with no extra commentary or prefixes.
211
+ """
212
+
213
  # --- Basic Agent Definition ---
214
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
215
  class BasicAgent:
 
 
 
 
 
 
216
  def __init__(self):
217
  # initialize HF inference pipeline once
218
  if HF_TOKEN is None:
219
  raise ValueError("HF_TOKEN not set in environment")
220
+ self.generator = pipeline("text-generation", model="EleutherAI/gpt-neo-125M", max_new_tokens=16)
221
+ self.llm = HuggingFacePipeline.from_pipeline(self.generator)
222
+ self.llm = self.llm.bind_tools(TOOLS)
223
  # The GAIA system prompt (no "FINAL ANSWER:" at the end)
224
+ self.system_prompt = SYSTEM_MESSAGE
 
 
 
 
225
  print("BasicAgent initialized with LLM.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # --- Core dispatcher/fallback ---
228
  def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  prompt = f"{self.system_prompt}Q: {q}\nA:"
230
+ #out = self.generator(prompt, max_new_tokens=16, return_full_text=False)
231
+ #return out[0]["generated_text"].strip()
232
+ agent = AgentExecutor(agent=self.llm, tools=TOOLS, prompt=prompt, verbose=False, return_intermediate_steps=False)
233
+ result = agent.invoke({"input": question})
234
+ return JsonOutputParser().parse(result)
235
 
236
  def run_and_submit_all( profile: gr.OAuthProfile | None):
237
  """