jcleee's picture
Update app.py
9a2c5c8 verified
raw
history blame
5.3 kB
import os
import datetime
import requests
import pytz
import yaml
from typing import List
from langchain.text_splitter import CharacterTextSplitter
from tools.final_answer import FinalAnswerTool
from smolagents import CodeAgent, LiteLLMModel, DuckDuckGoSearchTool, HfApiModel, load_tool, tool
from Gradio_UI import GradioUI
# === TOOLS ===
@tool
def web_search(query: str) -> str:
"""Allows search through DuckDuckGo.
Args:
query: what you want to search
"""
search_tool = DuckDuckGoSearchTool()
results = search_tool(query)
return "\n".join(results)
@tool
def get_current_time_in_timezone(timezone: str) -> str:
"""Fetches the current local time in a specified timezone.
Args:
timezone: A string representing a valid timezone (e.g., 'America/New_York').
"""
try:
tz = pytz.timezone(timezone)
local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
return f"The current local time in {timezone} is: {local_time}"
except Exception as e:
return f"Error fetching time for timezone '{timezone}': {str(e)}"
@tool
def visit_webpage(url: str) -> str:
"""Fetches raw HTML content of a web page.
Args:
url: The url of the webpage.
"""
try:
response = requests.get(url, timeout=5)
return response.text[:5000] # Limit length
except Exception as e:
return f"[ERROR fetching {url}]: {str(e)}"
@tool
def text_splitter(text: str) -> List[str]:
"""Splits text into chunks using LangChain's CharacterTextSplitter.
Args:
text: A string of text to split.
"""
splitter = CharacterTextSplitter(chunk_size=450, chunk_overlap=10)
return splitter.split_text(text)
# === FINAL ANSWER TOOL ===
final_answer = FinalAnswerTool()
# === LOAD PROMPT TEMPLATES ===
with open("prompts.yaml", "r") as stream:
prompt_templates = yaml.safe_load(stream)
# === LOAD agent.json CONFIG ===
with open("agent.json", "r") as f:
agent_config = yaml.safe_load(f)
model_config = agent_config["model"]["data"]
# === BUILD MODEL ===
model = LiteLLMModel(
model_id="gemini/gemini-2.0-flash-lite",
api_key=os.getenv("GEMINI_API_KEY"),
temperature=0.5,
max_tokens=1024,
)
# If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud' (!!!)
# model = HfApiModel(
# #model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
# model_id="mistralai/Mistral-7B-Instruct-v0.2",
# token=os.getenv("HF_TOKEN"),
# max_tokens=2096,
# temperature=0.5,
# last_input_token_count=0,
# last_output_token_count=0,
# )
# === IMPORT TOOL FROM HUB ===
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
# === BUILD AGENT ===
agent = CodeAgent(
model=model,
tools=[
final_answer,
web_search,
get_current_time_in_timezone,
visit_webpage,
text_splitter,
image_generation_tool
],
max_steps=6,
verbosity_level=1,
grammar=None,
planning_interval=None,
name=None,
description=None,
prompt_templates=prompt_templates
)
# === LAUNCH UI ===
# GradioUI(agent).launch() # new change 6:36 - for evaluation API
import gradio as gr
# OLD VERSION
def run_agent(question):
try:
result = agent(question)
return [str(result)] # Must return a list with one string (like ["answer"])
except Exception as e:
return [f"Error: {e}"]
# # NEW VERSION
# def run_agent(question):
# try:
# # Get all steps from the agent run
# steps = list(agent.run(question, stream=False))
# # Look through all tool calls to find final_answer
# for step in steps:
# if hasattr(step, "tool_calls") and step.tool_calls:
# for call in step.tool_calls:
# if call.name == "final_answer":
# answer = call.arguments.get("answer", None)
# if answer:
# return [str(answer)]
# # Fallback: try tool_output or .final_answer if FinalAnswerStep
# for step in reversed(steps):
# if hasattr(step, "tool_output") and step.tool_output:
# return [str(step.tool_output)]
# if hasattr(step, "final_answer") and step.final_answer:
# return [str(step.final_answer)]
# return ["null"]
# except Exception as e:
# return [f"Error: {e}"]
# # NEWER VERSION (should just return the result)
# import re
# def run_agent(question):
# try:
# result = agent(question)
# # If result is a string and contains "### 1. Task outcome", extract that
# if isinstance(result, str):
# match = re.search(r"### 1\. Task outcome \(short version\):\s*(.+)", result)
# if match:
# return [match.group(1).strip()] # return just the short answer
# return [result.strip()]
# return [str(result)]
# except Exception as e:
# return [f"Error: {e}"]
demo = gr.Interface(fn=run_agent, inputs="text", outputs="text")
demo.launch()