安如衫 commited on
Commit
44ff471
·
1 Parent(s): e7f4f55

feat: code agent with web search

Browse files
app.py CHANGED
@@ -1,23 +1,143 @@
1
  import os
 
2
  import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
 
 
 
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
15
  print("BasicAgent initialized.")
16
- def __call__(self, question: str) -> str:
17
- print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
@@ -79,13 +199,15 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
79
  if not task_id or question_text is None:
80
  print(f"Skipping item with missing task_id or question: {item}")
81
  continue
 
 
82
  try:
83
- submitted_answer = agent(question_text)
84
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
85
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
86
  except Exception as e:
87
  print(f"Error running agent on task {task_id}: {e}")
88
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
89
 
90
  if not answers_payload:
91
  print("Agent did not produce any answers to submit.")
 
1
  import os
2
+ from traceback import print_tb
3
  import gradio as gr
4
  import requests
5
  import inspect
6
  import pandas as pd
7
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, InferenceClientModel, OpenAIServerModel, ToolCallingAgent, VisitWebpageTool
8
+ from dotenv import load_dotenv
9
+ from utils import detect_file_category
10
+ from tools import transcribe_audio
11
+ from PIL import Image
12
+
13
+ load_dotenv()
14
 
15
  # (Keep Constants as is)
16
  # --- Constants ---
17
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
18
+ DASHSCOPE_API_BASE = os.getenv("DASHSCOPE_API_BASE")
19
+ DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
20
 
21
  # --- Basic Agent Definition ---
22
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
23
  class BasicAgent:
24
  def __init__(self):
25
  print("BasicAgent initialized.")
26
+
27
+ def __call__(self, question: str, file_path: str) -> str:
28
+ if not file_path or "ERROR" in file_path:
29
+ file_path = None
30
+
31
+ if file_path:
32
+ category = detect_file_category(file_path)
33
+ else:
34
+ category = "none"
35
+
36
+ if category == "none":
37
+ agent = CodeAgent(tools=[DuckDuckGoSearchTool(), VisitWebpageTool()], model=OpenAIServerModel(
38
+ model_id="qwen3-coder-flash",
39
+ api_base=DASHSCOPE_API_BASE,
40
+ api_key=DASHSCOPE_API_KEY,
41
+ ))
42
+ return agent.run(question)
43
+
44
+ if category == "audio":
45
+ agent = CodeAgent(
46
+ tools=[DuckDuckGoSearchTool(), VisitWebpageTool(), transcribe_audio],
47
+ model=OpenAIServerModel(
48
+ model_id="qwen3-coder-flash",
49
+ api_base=DASHSCOPE_API_BASE,
50
+ api_key=DASHSCOPE_API_KEY,
51
+ ))
52
+ return agent.run(question + f"\n\nfile_path:{file_path}")
53
+
54
+ if category == "image":
55
+ agent = CodeAgent(
56
+ model=OpenAIServerModel(
57
+ model_id="qwen3-vl-flash",
58
+ api_base=DASHSCOPE_API_BASE,
59
+ api_key=DASHSCOPE_API_KEY,
60
+ ),
61
+ max_steps=20,
62
+ verbosity_level=2
63
+ )
64
+ # agent = CodeAgent(
65
+ # tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
66
+ # model=OpenAIServerModel(
67
+ # model_id="qwen3-coder-flash",
68
+ # api_base=DASHSCOPE_API_BASE,
69
+ # api_key=DASHSCOPE_API_KEY,
70
+ # ),
71
+ # managed_agents=[image_agent])
72
+ return agent.run(question, images=[Image.open(file_path).convert("RGB")])
73
+
74
+ agent = CodeAgent(
75
+ additional_authorized_imports=["pandas"],
76
+ tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
77
+ model=OpenAIServerModel(
78
+ model_id="qwen3-coder-flash",
79
+ api_base=DASHSCOPE_API_BASE,
80
+ api_key=DASHSCOPE_API_KEY,
81
+ ))
82
+ return agent.run(question + f"\n\nfile_path:{file_path}")
83
+
84
+ # 新增:下载与 task_id 关联的文件的辅助函数
85
+ import re
86
+
87
+ def download_task_file(api_url: str, task_id: str, output_dir: str = "downloads") -> str:
88
+ files_url = f"{api_url}/files/{task_id}"
89
+ try:
90
+ os.makedirs(output_dir, exist_ok=True)
91
+
92
+ # 快速预检:如果 downloads 里已存在以 task_id 命名的文件则直接返回
93
+ try:
94
+ for name in os.listdir(output_dir):
95
+ base, _ext = os.path.splitext(name)
96
+ candidate = os.path.join(output_dir, name)
97
+ if base == task_id and os.path.isfile(candidate):
98
+ print(f"File for task {task_id} already exists: {candidate}")
99
+ return candidate
100
+ except FileNotFoundError:
101
+ pass
102
+
103
+ with requests.get(files_url, stream=True, timeout=30) as r:
104
+ r.raise_for_status()
105
+ filename = None
106
+ cd = r.headers.get("content-disposition")
107
+ if cd:
108
+ m = re.search('filename="?([^";]+)"?', cd)
109
+ if m:
110
+ filename = m.group(1)
111
+ if not filename:
112
+ filename = r.headers.get("x-filename")
113
+ if not filename:
114
+ filename = f"{task_id}.download"
115
+ dest_path = os.path.join(output_dir, filename)
116
+
117
+ # 二次检查:若目标文件已存在则跳过重新下载
118
+ if os.path.exists(dest_path):
119
+ print(f"File for task {task_id} already exists: {dest_path}")
120
+ return dest_path
121
+
122
+ with open(dest_path, "wb") as f:
123
+ for chunk in r.iter_content(chunk_size=8192):
124
+ if chunk:
125
+ f.write(chunk)
126
+ print(f"Downloaded file for task {task_id} to: {dest_path}")
127
+ return dest_path
128
+ except requests.exceptions.HTTPError as e:
129
+ status = getattr(e.response, 'status_code', 'unknown')
130
+ print(f"File download HTTP error for task {task_id}: {e}")
131
+ return f"ERROR: HTTP {status} for task {task_id}"
132
+ except requests.exceptions.Timeout:
133
+ print(f"File download timed out for task {task_id}")
134
+ return f"ERROR: Timeout downloading task {task_id}"
135
+ except requests.exceptions.RequestException as e:
136
+ print(f"File download network error for task {task_id}: {e}")
137
+ return f"ERROR: Network error downloading task {task_id}: {e}"
138
+ except Exception as e:
139
+ print(f"Unexpected error downloading file for task {task_id}: {e}")
140
+ return f"ERROR: Unexpected error downloading task {task_id}: {e}"
141
 
142
  def run_and_submit_all( profile: gr.OAuthProfile | None):
143
  """
 
199
  if not task_id or question_text is None:
200
  print(f"Skipping item with missing task_id or question: {item}")
201
  continue
202
+ # 新增:下载与该 task_id 关联的文件
203
+ downloaded_path = download_task_file(api_url, task_id)
204
  try:
205
+ submitted_answer = agent(question_text, downloaded_path)
206
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
207
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer, "Downloaded File": downloaded_path})
208
  except Exception as e:
209
  print(f"Error running agent on task {task_id}: {e}")
210
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}", "Downloaded File": downloaded_path})
211
 
212
  if not answers_payload:
213
  print("Agent did not produce any answers to submit.")
downloads/7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx ADDED
Binary file (5.29 kB). View file
 
downloads/cca530fc-4052-43b2-b130-b30968d8aa44.png ADDED
downloads/f918266a-b3e0-4914-865d-4faa564f1aef.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import randint
2
+ import time
3
+
4
+ class UhOh(Exception):
5
+ pass
6
+
7
+ class Hmm:
8
+ def __init__(self):
9
+ self.value = randint(-100, 100)
10
+
11
+ def Yeah(self):
12
+ if self.value == 0:
13
+ return True
14
+ else:
15
+ raise UhOh()
16
+
17
+ def Okay():
18
+ while True:
19
+ yield Hmm()
20
+
21
+ def keep_trying(go, first_try=True):
22
+ maybe = next(go)
23
+ try:
24
+ if maybe.Yeah():
25
+ return maybe.value
26
+ except UhOh:
27
+ if first_try:
28
+ print("Working...")
29
+ print("Please wait patiently...")
30
+ time.sleep(0.1)
31
+ return keep_trying(go, first_try=False)
32
+
33
+ if __name__ == "__main__":
34
+ go = Okay()
35
+ print(f"{keep_trying(go)}")
requirements.txt CHANGED
@@ -1,2 +1,7 @@
1
  gradio
2
- requests
 
 
 
 
 
 
1
  gradio
2
+ requests
3
+ smolagents[all]
4
+ faster-whisper
5
+ filetype
6
+ torch
7
+ Pillow
tools.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import tool
2
+ import os
3
+
4
+ @tool
5
+ def transcribe_audio(file_path: str) -> str:
6
+ """
7
+ Transcribes an audio file using faster-whisper.
8
+
9
+ Args:
10
+ file_path: The path to the audio file.
11
+
12
+ Returns:
13
+ The transcribed text, or an error message if transcription fails.
14
+ """
15
+ # Use faster-whisper if available
16
+ try:
17
+ from faster_whisper import WhisperModel
18
+ import torch
19
+
20
+ if torch.cuda.is_available():
21
+ device = "cuda"
22
+ else:
23
+ device = "cpu"
24
+
25
+ model = WhisperModel("base", device=device)
26
+ segments, info = model.transcribe(file_path)
27
+ text_parts = []
28
+ for seg in segments:
29
+ try:
30
+ text_parts.append(seg.text)
31
+ except Exception:
32
+ pass
33
+ text = " ".join(text_parts).strip()
34
+ return text or "[ASR result is empty]"
35
+ except Exception as e:
36
+ return f"[ASR is not available] Please install `faster-whisper` to enable audio recognition. Error: {e}"
utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import mimetypes
3
+
4
+ def detect_file_category(file_path: str) -> str:
5
+ if not file_path or not os.path.exists(file_path):
6
+ return "none"
7
+ mime = None
8
+ try:
9
+ import filetype # optional; if not installed, fallback to mimetypes
10
+ kind = filetype.guess(file_path)
11
+ mime = kind.mime if kind else None
12
+ except Exception:
13
+ mime = None
14
+ if not mime:
15
+ mime, _ = mimetypes.guess_type(file_path)
16
+ ext = os.path.splitext(file_path)[1].lower()
17
+ if mime:
18
+ if mime.startswith("image/"):
19
+ return "image"
20
+ if mime.startswith("audio/"):
21
+ return "audio"
22
+ if mime.startswith("video/"):
23
+ return "video"
24
+ if mime in (
25
+ "application/vnd.ms-excel",
26
+ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
27
+ "text/csv",
28
+ ):
29
+ return "spreadsheet"
30
+ if mime == "application/pdf":
31
+ return "document"
32
+ if mime.startswith("text/"):
33
+ if ext in (".py", ".md", ".txt", ".csv"):
34
+ return "text"
35
+ # extension fallback
36
+ if ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tiff"):
37
+ return "image"
38
+ if ext in (".mp3", ".wav", ".m4a", ".flac", ".ogg"):
39
+ return "audio"
40
+ if ext in (".mp4", ".mov", ".mkv", ".webm"):
41
+ return "video"
42
+ if ext in (".xls", ".xlsx", ".csv"):
43
+ return "spreadsheet"
44
+ return "unknown"