Update app.py
Browse files
app.py
CHANGED
|
@@ -12,12 +12,11 @@ import yaml
|
|
| 12 |
import pandas as pd
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
-
from smolagents import CodeAgent, tool
|
| 16 |
|
| 17 |
# -------------------------
|
| 18 |
# Minimal tools
|
| 19 |
# -------------------------
|
| 20 |
-
|
| 21 |
_allowed_ops = {
|
| 22 |
ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
|
| 23 |
ast.Div: operator.truediv, ast.Pow: operator.pow, ast.USub: operator.neg,
|
|
@@ -87,12 +86,16 @@ except Exception:
|
|
| 87 |
prompt_templates = None
|
| 88 |
|
| 89 |
# -------------------------
|
| 90 |
-
# CodeAgent minimal
|
| 91 |
# -------------------------
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
code_agent = CodeAgent(
|
| 95 |
-
model=
|
| 96 |
tools=[calculator, get_current_time_in_timezone],
|
| 97 |
max_steps=6,
|
| 98 |
verbosity_level=0,
|
|
@@ -117,13 +120,13 @@ class GaiaAgentMinimal:
|
|
| 117 |
try:
|
| 118 |
q = question.strip()
|
| 119 |
|
| 120 |
-
#
|
| 121 |
if self._is_calc(q):
|
| 122 |
m = re.search(r'([0-9\.\s\+\-\*\/\^\%\(\)]+)', q)
|
| 123 |
expr = m.group(1) if m else q
|
| 124 |
return calculator(expr)
|
| 125 |
|
| 126 |
-
#
|
| 127 |
if self._is_time(q):
|
| 128 |
if "paris" in q.lower() or "france" in q.lower():
|
| 129 |
tz = "Europe/Paris"
|
|
@@ -131,7 +134,7 @@ class GaiaAgentMinimal:
|
|
| 131 |
tz = "UTC"
|
| 132 |
return get_current_time_in_timezone(tz)
|
| 133 |
|
| 134 |
-
#
|
| 135 |
resp = self.code_agent.run(q)
|
| 136 |
if isinstance(resp, dict):
|
| 137 |
for key in ("final_answer", "answer", "result", "output"):
|
|
@@ -142,7 +145,7 @@ class GaiaAgentMinimal:
|
|
| 142 |
except Exception as e:
|
| 143 |
return json.dumps({"error": f"Agent internal error: {e}"})
|
| 144 |
|
| 145 |
-
# instantiate
|
| 146 |
gaia_agent = GaiaAgentMinimal(code_agent)
|
| 147 |
|
| 148 |
# -------------------------
|
|
|
|
| 12 |
import pandas as pd
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
+
from smolagents import CodeAgent, HfApiModel, tool
|
| 16 |
|
| 17 |
# -------------------------
|
| 18 |
# Minimal tools
|
| 19 |
# -------------------------
|
|
|
|
| 20 |
_allowed_ops = {
|
| 21 |
ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
|
| 22 |
ast.Div: operator.truediv, ast.Pow: operator.pow, ast.USub: operator.neg,
|
|
|
|
| 86 |
prompt_templates = None
|
| 87 |
|
| 88 |
# -------------------------
|
| 89 |
+
# HfApiModel + CodeAgent minimal
|
| 90 |
# -------------------------
|
| 91 |
+
model = HfApiModel(
|
| 92 |
+
model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
|
| 93 |
+
max_tokens=1024,
|
| 94 |
+
temperature=0.0
|
| 95 |
+
)
|
| 96 |
|
| 97 |
code_agent = CodeAgent(
|
| 98 |
+
model=model,
|
| 99 |
tools=[calculator, get_current_time_in_timezone],
|
| 100 |
max_steps=6,
|
| 101 |
verbosity_level=0,
|
|
|
|
| 120 |
try:
|
| 121 |
q = question.strip()
|
| 122 |
|
| 123 |
+
# Calculator queries
|
| 124 |
if self._is_calc(q):
|
| 125 |
m = re.search(r'([0-9\.\s\+\-\*\/\^\%\(\)]+)', q)
|
| 126 |
expr = m.group(1) if m else q
|
| 127 |
return calculator(expr)
|
| 128 |
|
| 129 |
+
# Time queries
|
| 130 |
if self._is_time(q):
|
| 131 |
if "paris" in q.lower() or "france" in q.lower():
|
| 132 |
tz = "Europe/Paris"
|
|
|
|
| 134 |
tz = "UTC"
|
| 135 |
return get_current_time_in_timezone(tz)
|
| 136 |
|
| 137 |
+
# fallback LLM
|
| 138 |
resp = self.code_agent.run(q)
|
| 139 |
if isinstance(resp, dict):
|
| 140 |
for key in ("final_answer", "answer", "result", "output"):
|
|
|
|
| 145 |
except Exception as e:
|
| 146 |
return json.dumps({"error": f"Agent internal error: {e}"})
|
| 147 |
|
| 148 |
+
# instantiate GAIA agent
|
| 149 |
gaia_agent = GaiaAgentMinimal(code_agent)
|
| 150 |
|
| 151 |
# -------------------------
|