Spaces (status: Runtime error)
Leonardo committed: Sync local Space with Hub
Files changed:
- .gitignore +3 -0
- README.md +5 -5
- app.py +559 -116
- flux_image.py +0 -0
- requirements.txt +15 -18
- scripts/cookies.py +3 -2
- scripts/finance_tools.py +987 -0
- scripts/flux_lora_tool.py +117 -120
- scripts/frontmatter_tool.py +0 -402
- scripts/gaia_scorer.py +0 -124
- scripts/mdconvert.py +90 -26
- scripts/reformulator.py +0 -86
- scripts/run_agents.py +0 -87
- scripts/text_cleaner_tool.py +15 -23
- scripts/text_inspector_tool.py +51 -7
- scripts/text_web_browser.py +119 -35
- scripts/time_tools.py +139 -0
- scripts/visual_qa.py +33 -15
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__
+logs/
+data
README.md CHANGED
@@ -1,14 +1,14 @@
 ---
 title: ODR
 emoji: 🏆
-colorFrom:
-colorTo:
+colorFrom: yellow
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.14.0
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: OpenAI's Deep Research, but open
+short_description: OpenAI's Deep Research, but open
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,44 +1,131 @@
 #!/usr/bin/env python
 # coding=utf-8
 # Copyright 2024 The Footscray Coding Collective. All rights reserved.
+"""
+Financial Research Agent: Advanced Market Analysis and Data Access
+
+This script implements a comprehensive financial research agent capable of performing market analysis,
+retrieving financial data, and providing interactive research capabilities through either a GUI or
+command-line interface.
+
+The agent leverages the Smolagents framework to create an autonomous system that can:
+1. Access and analyze real-time market data through Alpha Vantage API integration
+2. Process financial documents and extract relevant information
+3. Perform web searches and analyze webpage content
+4. Create visualizations of financial data
+5. Generate comprehensive financial analysis reports
+6. Handle user uploads of various document types
+
+Key Components:
+-------------
+- ModelManager: Handles loading and configuration of various LLM models
+- ToolRegistry: Manages initialization and organization of tools available to the agent
+- GradioUI: Provides a user-friendly interface with responsive design for desktop/mobile
+- A robust set of financial tools for retrieving stock data, financial statements, and market sentiment
+- Web browsing capabilities with text extraction and analysis
+- Document processing for PDFs, spreadsheets, and other common file formats
+- Visualization tools for creating charts and graphs from financial data
+
+Usage:
+-----
+Run in UI mode (default):
+    python app.py
+
+Run in headless mode with a specific query:
+    python app.py --mode headless --query "Analyze Tesla's financial performance for 2023"
+
+Configuration:
+------------
+The script uses environment variables for API keys and other configuration settings.
+Required environment variables:
+- ALPHA_VANTAGE_API_KEY: For accessing financial data APIs
+- HF_TOKEN: For accessing Hugging Face models (optional)
+
+The agent also maintains detailed logs in the logs/ directory for debugging and auditing.
+
+Dependencies:
+-----------
+- smolagents: Core framework for agent capabilities
+- gradio: For the web interface
+- Alpha Vantage API integration: For financial data
+- Various data processing libraries: For handling and analyzing financial information
+
+Technical Notes:
+--------------
+- The agent runs with a configurable number of maximum steps (default: 20)
+- Planning occurs at regular intervals (default: every 4 steps)
+- The agent has access to a curated list of authorized Python imports for security
+- All file uploads are validated for type and size before processing
+
+Created by the Footscray Coding Collective
+Copyright 2024, All rights reserved
+"""
+import contextlib
+import datetime
+import logging
 import mimetypes
 import os
 import re
 import shutil
-from typing import Optional
+from typing import Any, Dict, Generator, List, Optional, Tuple

+# Typer for CLI functionality
+import typer
+
+# Telemetry imports (optional)
+with contextlib.suppress(ImportError):
+    from openinference.instrumentation.smolagents import SmolagentsInstrumentor
+    from phoenix.otel import register
+
+    # Initialize telemetry for observability and tracing
+    register()
+    SmolagentsInstrumentor().instrument()
+
+# third-party
 import gradio as gr
+import pytz
 from dotenv import load_dotenv
 from huggingface_hub import login
+from rich.console import Console
+from rich.logging import RichHandler
+from smolagents import FinalAnswerTool  # smolagents
+from smolagents import (CodeAgent, GoogleSearchTool, HfApiModel, LiteLLMModel,
+                        OpenAIServerModel, Tool, TransformersModel)
+from smolagents.agent_types import AgentText
+from smolagents.gradio_ui import (handle_agent_output_types,
+                                  pull_messages_from_step)
+
+# local
+from scripts.finance_tools import (DataVisualizationTool,
+                                   FinancialCalculatorTool, TrendAnalysisTool,
+                                   get_balance_sheet_data, get_cash_flow_data,
+                                   get_company_overview_data,
+                                   get_earnings_data,
+                                   get_income_statement_data,
+                                   get_market_news_sentiment,
+                                   get_stock_quote_data, get_time_series_daily,
+                                   search_symbols)
 from scripts.flux_lora_tool import FluxLoRATool
 from scripts.text_cleaner_tool import TextCleanerTool
 from scripts.text_inspector_tool import TextInspectorTool
-from scripts.text_web_browser import (
-    PageDownTool,
-    PageUpTool,
-    SimpleTextBrowser,
-    VisitTool,
-)
+from scripts.text_web_browser import (ArchiveSearchTool, DownloadTool,
+                                      FinderTool, FindNextTool, PageDownTool,
+                                      PageUpTool, SimpleTextBrowser, VisitTool)
+from scripts.time_tools import get_temporal_context
 from scripts.visual_qa import visualizer
-from smolagents import (
-    Tool,
-    TransformersModel,
-)
-from smolagents.agent_types import AgentAudio, AgentImage, AgentText
-from smolagents.gradio_ui import handle_agent_output_types, pull_messages_from_step
+
+# Initialize console and app
+console = Console()
+app = typer.Typer(
+    help="Financial Research Agent - Access market data and analysis through a CLI or UI",
+    add_completion=False,
+)

 # ------------------------ Configuration and Setup ------------------------
 # Constants and configurations
 AUTHORIZED_IMPORTS = [
     "requests",  # Web requests (fetching data from the internet)
+    "pytz",  # Timezone handling
     "zipfile",  # Working with ZIP archives
     "pandas",  # Data manipulation and analysis (DataFrames)
     "numpy",  # Numerical computing (arrays, linear algebra)
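The optional telemetry block in the new imports above leans on a subtle property of contextlib.suppress: if either openinference or phoenix is missing, the ImportError aborts the entire with-body, so register() and SmolagentsInstrumentor().instrument() are silently skipped too. A minimal, runnable sketch of the same pattern, with a made-up module name standing in for the real telemetry packages:

import contextlib

telemetry_enabled = False
with contextlib.suppress(ImportError):
    import some_optional_backend  # hypothetical module; the import fails here

    some_optional_backend.enable()  # skipped whenever the import above fails
    telemetry_enabled = True

print(f"telemetry enabled: {telemetry_enabled}")  # -> telemetry enabled: False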
@@ -48,7 +135,7 @@ AUTHORIZED_IMPORTS = [
     "pubchempy",  # Accessing PubChem chemical database
     "yaml",
     "xml",  # XML processing
-    "yahoo_finance",  # Fetching stock
+    "yahoo_finance",  # Fetching stock data
     "Bio",  # Bioinformatics tools (e.g., sequence analysis)
     "sklearn",  # Scikit-learn for machine learning
     "scipy",  # Scientific computing (stats, optimization)
@@ -74,7 +161,7 @@
     "time",  # Measuring time
     "tempfile",  # Creating temporary files and directories
     # Data Visualization (if needed) - Consider security implications carefully
-    "matplotlib",  # Plotting library
+    "matplotlib.plt",  # Plotting library
     "seaborn",  # Statistical data visualization (more advanced)
     # Web Scraping (more specific/controlled) - Consider ethical implications
     "lxml",  # Faster XML/HTML processing (alternative to bs4)
@@ -85,6 +172,7 @@
     "schedule",  # Allow the agent to schedule tasks
     "uuid",
     "base64",
+    "smolagents",  # smolagents package to be able to create smolagents tools
 ]

 USER_AGENT = (
@@ -93,7 +181,7 @@ USER_AGENT = (
 )
 BROWSER_CONFIG = {
     "viewport_size": 1024 * 5,
-    "downloads_folder": "downloads_folder",
+    "downloads_folder": "data/downloads_folder",
     "request_kwargs": {
         "headers": {"User-Agent": USER_AGENT},
         "timeout": 300,
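BROWSER_CONFIG carries a request_kwargs block of headers and a timeout. SimpleTextBrowser's internals live in scripts/text_web_browser.py, which this page does not show, so the following is only an assumed sketch of how such kwargs are typically splatted into the requests library; the URL and User-Agent string are placeholders:

import requests

request_kwargs = {
    "headers": {"User-Agent": "Mozilla/5.0 (placeholder)"},
    "timeout": 300,  # generous timeout, matching the config above
}
# Each page fetch would forward the stored kwargs like this:
response = requests.get("https://example.com", **request_kwargs)
print(response.status_code)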
@@ -103,7 +191,6 @@

 CUSTOM_ROLE_CONVERSIONS = {"tool-call": "assistant", "tool-response": "user"}

-
 ALLOWED_FILE_TYPES = [
     "application/pdf",
     "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -121,23 +208,108 @@
 ]


+# Set up logging configuration
+def setup_logging() -> Tuple[str, logging.Logger]:
+    """
+    Configure logging with structured output and file storage.
+
+    The function creates logs directory and timestamped log filename, sets up
+    logging with Rich integration and creates and returns logger.
+
+    Returns:
+        Tuple containing the log file path and configured logger
+    """
+    # Create logs directory
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    logs_dir = os.path.join(current_dir, "logs")
+    os.makedirs(logs_dir, exist_ok=True)
+
+    # Generate timestamped log filename
+    melbourne_timezone = pytz.timezone("Australia/Melbourne")
+    log_filename = f'smolagents_{datetime.datetime.now(melbourne_timezone).strftime("%Y%m%d_%H%M%S")}.log'
+    log_file = os.path.join(logs_dir, log_filename)
+
+    # Set up logging with Rich integration
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(levelname)s] - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[
+            RichHandler(rich_tracebacks=True, show_time=True),
+            logging.FileHandler(log_file),
+        ],
+    )
+
+    # Create and return logger
+    logger = logging.getLogger(__name__)
+    return log_file, logger
+
+
+LOG_FILE, logger = setup_logging()
+
+
+def setup_environment() -> None:
+    """Initialize environment variables and authentication.
+
+    This function ensures that required environment variables are set and
+    attempts to authenticate with Hugging Face and Alpha Vantage services.
+    """
     load_dotenv(override=True)
+
+    # Check Hugging Face token
     if os.getenv("HF_TOKEN"):  # Check if token is actually set
         login(os.getenv("HF_TOKEN"))
-        print("HF_TOKEN
+        console.print("HF_TOKEN loaded successfully")
     else:
-        print(
+        console.print(
+            "[yellow]HF_TOKEN not found in environment variables. "
+            "Some features may not work properly.[/yellow]"
+        )
+
+    # Check Alpha Vantage API key
+    try:
+        # Ensure Alpha Vantage API key is available
+        api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
+        if not api_key:
+            console.print(
+                "[yellow]⚠️ Warning: ALPHA_VANTAGE_API_KEY not found. "
+                "Finance tools may not work properly.[/yellow]"
+            )
+        else:
+            console.print("[green]✓ ALPHA_VANTAGE_API_KEY loaded successfully[/green]")
+    except Exception as e:
+        console.print(f"[red]Error checking ALPHA_VANTAGE_API_KEY: {e}[/red]")


 # ------------------------ Model and Tool Management ------------------------
 class ModelManager:
-    """Manages model loading and initialization.
+    """Manages model loading and initialization.
+
+    This class provides a static method to load the specified model with the
+    appropriate configuration. It supports the following inference types:
+    - hf_api: Use the Hugging Face API to load the model.
+    - hf_api_provider: Use the Hugging Face API to load the model with the
+      'together' provider.
+    - litellm: Load the LiteLLM model with the specified model ID.
+    - openai: Load the OpenAI model with the specified model ID and API key.
+    - transformers: Load the Hugging Face transformers model with the
+      specified model ID and configuration.
+    """

     @staticmethod
     def load_model(chosen_inference: str, model_id: str, key_manager=None):
-        """Load the specified model with appropriate configuration.
+        """Load the specified model with appropriate configuration.
+
+        Args:
+            chosen_inference (str): The inference type to use.
+            model_id (str): The model ID to load.
+            key_manager (Optional[KeyManager]): The key manager to use for
+                loading the model. Required for OpenAI models.
+
+        Raises:
+            ValueError: If the chosen inference type is invalid.
+            Exception: If an error occurs while loading the model.
+        """
         try:
             if chosen_inference == "hf_api":
                 return HfApiModel(model_id=model_id)
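setup_logging() above derives the log path from Melbourne wall-clock time so that concurrent runs get distinct files. The naming logic in isolation, using only pytz and the standard library exactly as the diff does:

import datetime
import pytz

# Timestamp in Australia/Melbourne local time, baked into the filename
melbourne = pytz.timezone("Australia/Melbourne")
stamp = datetime.datetime.now(melbourne).strftime("%Y%m%d_%H%M%S")
print(f"smolagents_{stamp}.log")  # e.g. smolagents_20240301_142501.log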
@@ -156,7 +328,7 @@ class ModelManager:
                 model_id=model_id, api_key=key_manager.get_key("openai_api_key")
             )

+            if chosen_inference == "transformers":
                 return TransformersModel(
                     model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                     device_map="auto",
@@ -167,41 +339,114 @@
             raise ValueError(f"Invalid inference type: {chosen_inference}")

         except Exception as e:
-            print(f"✗ Couldn't load model: {e}")
+            console.print(f"[red]✗ Couldn't load model: {e}[/red]")
             raise


+# ------------------------ Tool Registration ------------------------
 class ToolRegistry:
-    """Manages tool initialization and organization."""
+    """Manages tool initialization and organization using Zhou Protocol priorities."""

     @staticmethod
+    def load_information_tools(model, text_limit=30000):
+        """
+        Initialize and return information analysis tools.
+
+        This method creates tools for analyzing text from documents, and other sources.
+        The information tools should be prioritized first in the agent's toolset.
+
+        Args:
+            model: Language model to use for analysis
+            text_limit: Maximum character length for text summaries
+
+        Returns:
+            List of information analysis tools
+        """
         return [
-            GoogleSearchTool(provider="serper"),
-            VisitTool(browser),
-            PageUpTool(browser),
-            PageDownTool(browser),
-            FinderTool(browser),
-            FindNextTool(browser),
-            ArchiveSearchTool(browser),
             TextInspectorTool(model, text_limit),
         ]

     @staticmethod
+    def load_utility_tools():
         """
-        Initialize and return
+        Initialize and return utility tools for text cleaning and normalization.
+
         Returns:
-            List of
+            List of utility tools
         """
         return [
             TextCleanerTool(),
         ]

+    @staticmethod
+    def load_time_tools():
+        """
+        Initialize and return time-related tools.
+
+        Returns:
+            List of time-related tools
+        """
+        return [get_temporal_context]
+
+    @staticmethod
+    def load_finance_tools():
+        """
+        Initialize and return financial analysis tools.
+
+        Returns:
+            List of financial tools in priority order
+        """
+        return [
+            # Analysis tools first (higher priority)
+            DataVisualizationTool(),
+            FinancialCalculatorTool(),
+            TrendAnalysisTool(),
+            # Data retrieval tools next
+            search_symbols,
+            get_stock_quote_data,
+            get_company_overview_data,
+            get_earnings_data,
+            get_income_statement_data,
+            get_balance_sheet_data,
+            get_cash_flow_data,
+            get_time_series_daily,
+            get_market_news_sentiment,
+        ]
+
+    @staticmethod
+    def load_web_tools(browser, text_limit=20000):
+        """
+        Initialize and return web interaction tools.
+
+        Args:
+            browser: Browser instance for web navigation
+            text_limit: Maximum character length for text processing
+
+        Returns:
+            List of web tools in priority order
+        """
+        return [
+            # Search tools first
+            GoogleSearchTool(provider="serper"),
+            # Navigation tools next
+            VisitTool(browser),
+            DownloadTool(browser),
+            # Page interaction tools last
+            PageUpTool(browser),
+            PageDownTool(browser),
+            FinderTool(browser),
+            FindNextTool(browser),
+            ArchiveSearchTool(browser),
+        ]
+
     @staticmethod
     def load_image_generation_tools():
-        """
+        """
+        Initialize and return image generation tools.
+
+        Returns:
+            Image generation tool or fallback
+        """
         try:
             return Tool.from_space(
                 space_id="xkerser/FLUX.1-dev",
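load_finance_tools() above returns a mix of Tool instances (DataVisualizationTool()) and bare functions (search_symbols and friends). That works because smolagents' @tool decorator turns a type-annotated, docstringed function into a Tool object. A made-up example in the same style; ticker_upper is illustrative only, not one of this commit's tools:

from smolagents import tool


@tool
def ticker_upper(symbol: str) -> str:
    """Return a stock ticker symbol in upper case.

    Args:
        symbol: The ticker symbol to normalize.
    """
    return symbol.upper()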
@@ -209,95 +454,219 @@
                 description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
             )
         except Exception as e:
-            print(
+            console.print(
+                f"[yellow]✗ Couldn't initialize image generation tool: {e}[/yellow]"
+            )
+            return FluxLoRATool()
+
+    @staticmethod
+    def load_final_answer_tool():
+        """
+        Return the final answer tool for providing conclusive responses.
+
+        Returns:
+            List containing the final answer tool
+        """
+        return [FinalAnswerTool()]


-def create_agent():
+def create_agent(model_id: str = "openrouter/google/gemini-2.0-flash-001"):
     """
+    Create a fresh agent instance with properly configured tools.
+
+    This function creates a CodeAgent with tools organized by the Zhou Protocol
+    priority system, ensuring the most relevant tools are considered first.
+
+    Args:
+        model_id: The ID of the model to use for the agent
+
     Returns:
+        A configured CodeAgent instance
+
     Raises:
-        ValueError: If tool validation fails
         RuntimeError: If agent creation fails
     """
     try:
-        # Initialize model
-        model =
-            custom_role_conversions=CUSTOM_ROLE_CONVERSIONS,
-            model_id="openrouter/google/gemini-2.0-flash-001",
-        )
+        # Initialize model with fallback system
+        model = _load_model_with_fallback(model_id)

         # Initialize tools
         text_limit = 30000
         browser = SimpleTextBrowser(**BROWSER_CONFIG)

-        # Collect all tools
+        # Collect all tools with proper Zhou Protocol prioritization
+        information_tools = ToolRegistry.load_information_tools(model, text_limit)
+        utility_tools = ToolRegistry.load_utility_tools()
+        finance_tools = ToolRegistry.load_finance_tools()
+        web_tools = ToolRegistry.load_web_tools(browser)
+        time_tools = ToolRegistry.load_time_tools()
         image_generator = ToolRegistry.load_image_generation_tools()
+        final_answer = ToolRegistry.load_final_answer_tool()
+
+        # Combine all tools with information tools prioritized first
+        all_tools = (
+            information_tools  # Critical information extraction (highest priority)
+            + utility_tools  # General utility functions
+            + finance_tools  # Financial analysis capabilities
+            + web_tools  # Web search and navigation
+            + time_tools  # Time context tools
+            + [visualizer]  # Image analysis
+            + [image_generator]  # Image generation
+            + final_answer  # Task completion (always last)
+        )

         # Validate tools before creating agent
-        for tool in all_tools:
-            if not isinstance(tool, Tool):
-                raise ValueError(
-                    f"Invalid tool type: {type(tool)}. "
-                    f"All tools must be instances of Tool class."
-                )
+        _validate_tools(all_tools)

         return CodeAgent(
             model=model,
             tools=all_tools,
-            max_steps=
+            max_steps=20,
             verbosity_level=2,
             additional_authorized_imports=AUTHORIZED_IMPORTS,
-            planning_interval=
+            planning_interval=4,
+            description="""
+            This agent assists with comprehensive research and financial analysis. It first analyzes
+            any provided documents or text, then leverages specialized financial tools and web search
+            capabilities to provide thorough insights.
+
+            QUERY COMPREHENSION FRAMEWORK
+            Before answering any complex question, apply the Zhou Comprehension Pattern:
+            1. **Initial Parse**: What is literally being asked?
+            2. **Intent Detection**: What is the user actually trying to accomplish?
+            3. **Knowledge Assessment**: What information is needed to address this properly?
+            4. **Tool Selection**: Which tools provide the most direct path to a solution?
+            5. **Execution Planning**: What sequence of operations will yield the best result?
+
+            CLARIFICATION CHECKLIST
+            When faced with ambiguous queries, the agent should systematically clarify:
+            * **Scope**: "How comprehensive should this analysis be?"
+            * **Format**: "What form would you like the results in?"
+            * **Technical Level**: "Should I explain technical details or focus on practical applications?"
+            * **Time Horizon**: "Are you interested in historical data, current status, or future projections?"
+            * **Priority**: "Which aspect of this question is most important to you?"
+            """.strip(),
         )
-    except
-        print(f"
+    except Exception as e:
+        console.print(f"[red]✗ Agent creation failed: {e}[/red]")
         raise RuntimeError(f"Agent creation failed: {e}")


+def _load_model_with_fallback(model_id: str) -> Any:
+    """
+    Attempt to load the specified model with fallbacks if it fails.
+
+    Args:
+        model_id: Primary model ID to try loading
+
+    Returns:
+        Loaded model instance
+
+    Raises:
+        RuntimeError: If all model loading attempts fail
+    """
+    # Fallback model chain from most capable to most reliable
+    fallback_models = [
+        model_id,  # Try the requested model first
+        "openrouter/anthropic/claude-3.7-sonnet",
+        "openai/gpt-4o-mini",
+        "anthropic/claude-3.7-sonnet",
+        "HuggingFaceTB/SmolLM2-1.7B-Instruct",  # Last resort local option
+    ]
+
+    last_error = None
+    for model in fallback_models:
+        try:
+            return LiteLLMModel(
+                custom_role_conversions=CUSTOM_ROLE_CONVERSIONS,
+                model_id=model,
+            )
+        except Exception as e:
+            last_error = e
+            console.print(f"[yellow]Failed to load model {model}: {e}[/yellow]")
+
+    # If we get here, all models failed
+    raise RuntimeError(f"All model loading attempts failed. Last error: {last_error}")
+
+
+def _validate_tools(tools):
+    """
+    Validate that all tools are proper Tool instances.
+
+    Args:
+        tools: List of tools to validate
+
+    Raises:
+        ValueError: If any tool is not a Tool instance
+    """
+    for tool in tools:
+        if not isinstance(tool, Tool):
+            raise ValueError(
+                f"Invalid tool type: {type(tool)}. "
+                f"All tools must be instances of Tool class."
+            )
+
+
+# ------------------------ Gradio UI Components ------------------------
+
+
 def stream_to_gradio(
     agent,
     task: str,
     reset_agent_memory: bool = False,
     additional_args: Optional[dict] = None,
 ):
-    """
+    """Streams agent responses with improved status indicators."""
+    try:
+        # Initial processing indicator
+        yield gr.ChatMessage(role="assistant", content="⏳ Processing your request...")
+
+        # Track what we've yielded to replace the processing indicator
+        first_message_yielded = False
+
+        for step_log in agent.run(
+            task, stream=True, reset=reset_agent_memory, additional_args=additional_args
+        ):
+            # The key fix: pull_messages_from_step is a generator function that yields messages
+            # We need to iterate through each yielded message
+            for message in pull_messages_from_step(step_log):
+                if not first_message_yielded:
+                    # Replace the initial "Processing" message
+                    first_message_yielded = True
+                    message.content = message.content.replace(
+                        "⏳ Processing your request...", ""
+                    )

+                # Check what type of operation is being performed based on the metadata or content
+                # Instead of trying to access a 'status' attribute that doesn't exist
+                content_lower = (
+                    message.content.lower() if hasattr(message, "content") else ""
+                )

+                if "document analysis" in content_lower:
+                    message.content = f"📄 **Document Analysis:** {message.content}"
+                elif "search" in content_lower:
+                    message.content = f"🔍 **Search:** {message.content}"
+
+                yield message
+
+        # Final answer with enhanced formatting
+        final_answer = handle_agent_output_types(step_log)
+
+        if isinstance(final_answer, AgentText):
             yield gr.ChatMessage(
                 role="assistant",
-                content=
-            )  # Send as Gradio-compatible file object
-        else:
-            yield gr.ChatMessage(
-                role="assistant", content=f"**Final answer:** {str(final_answer)}"
+                content=f"✅ **Final Answer:**\n\n{final_answer.to_string()}",
+            )
+        else:
+            yield gr.ChatMessage(
+                role="assistant", content=f"✅ **Final Answer:** {str(final_answer)}"
             )
+
+    except Exception as e:
+        yield gr.ChatMessage(
+            role="assistant",
+            content=f"❌ **Error:** {str(e)}\n\nPlease try again with a different query.",
+        )
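_load_model_with_fallback() above is an instance of a generic try-in-order pattern: attempt each candidate, remember the last error, and fail loudly only once every option is exhausted. Stripped of the LiteLLM specifics, the core logic looks like this; first_that_loads and the toy loader are illustrative, not part of the commit:

def first_that_loads(candidates, load):
    """Return load(c) for the first candidate that does not raise."""
    last_error = None
    for candidate in candidates:
        try:
            return load(candidate)
        except Exception as e:  # deliberately broad, mirroring the diff
            last_error = e
    raise RuntimeError(f"All candidates failed. Last error: {last_error}")


def load(name):
    # Toy loader: only the last-resort option succeeds.
    if name != "local":
        raise ValueError(f"{name} unavailable")
    return f"<model {name}>"


print(first_that_loads(["primary", "backup", "local"], load))  # <model local>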
@@ -313,20 +682,37 @@ class GradioUI:
         if not os.path.exists(file_upload_folder):
             os.mkdir(file_upload_folder)

-    def interact_with_agent(
+    def interact_with_agent(
+        self,
+        prompt: str,
+        messages: List[gr.ChatMessage],
+        session_state: Dict[str, Any],
+    ) -> Generator[List[gr.ChatMessage], None, None]:
+        """Main interaction handler with the agent.
+
+        Args:
+            prompt: The user's input prompt
+            messages: The list of messages so far (including the user's prompt)
+            session_state: The current state of the user's session
+
+        Yields:
+            A list of messages after each step (including the user's prompt)
+        """

         # Get or create session-specific agent
         if "agent" not in session_state:
+            model_id = session_state.get(
+                "model_id", "openrouter/google/gemini-2.0-flash-001"
+            )
+            session_state["agent"] = create_agent(model_id)

         # Adding monitoring
         try:
             # Log the existence of agent memory
             has_memory = hasattr(session_state["agent"], "memory")
-            print(f"Agent has memory: {has_memory}")
+            console.print(f"Agent has memory: {has_memory}")
             if has_memory:
-                print(f"Memory type: {type(session_state['agent'].memory)}")
+                console.print(f"Memory type: {type(session_state['agent'].memory)}")

             messages.append(gr.ChatMessage(role="user", content=prompt))
             yield messages
@@ -339,7 +725,7 @@
             yield messages  # Yield messages one last time

         except Exception as e:
-            print(f"Error in interaction: {str(e)}")
+            console.print(f"[red]Error in interaction: {str(e)}[/red]")
             raise

     def upload_file(
@@ -448,7 +834,7 @@
         @gr.render()
         def layout(request: gr.Request):
             device = self.detect_device(request)
-            print(f"device - {device}")
+            console.print(f"device - {device}")
             # Render layout with sidebar
             if device == "Desktop":
                 return self._create_desktop_layout()
@@ -464,7 +850,7 @@
         with gr.Sidebar():
             gr.Markdown(
                 """#OpenDeepResearch - 3theSmolagents!
-                Model_id:
+                Model_id: deepseek/deepseek-r1"""
             )
             with gr.Group():
                 gr.Markdown("**What's on your mind mate?**", container=True)
@@ -635,18 +1021,75 @@ class GradioUI:
         )


-# ------------------------
+# ------------------------ CLI Command ------------------------
+@app.command()
+def run(
+    mode: str = typer.Option(
+        "ui",
+        "--mode",
+        "-m",
+        help="Operating mode: 'ui' for Gradio interface or 'headless' for CLI mode",
+    ),
+    model_id: str = typer.Option(
+        "openrouter/google/gemini-2.0-flash-001",
+        "--model",
+        help="Model ID to use for the agent",
+    ),
+    query: Optional[str] = typer.Option(
+        None, "--query", "-q", help="Query to execute (required in headless mode)"
+    ),
+):
+    """
+    Run the financial research agent in either UI or headless mode.
+
+    In UI mode, launches a Gradio interface for interactive use.
+    In headless mode, processes a single query and outputs the result to the console.
+    """
+    # Setup environment variables
     setup_environment()

-    #
+    # Validate inputs for headless mode
+    if mode == "headless" and not query:
+        console.print("[red]Error: query parameter is required in headless mode[/red]")
+        raise typer.Exit(code=1)
+
+    # Create agent with specified model ID
+    console.print(f"[bold]Initializing agent with model:[/bold] {model_id}")
+
+    # Execute in appropriate mode
+    if mode == "ui":
+        console.print(
+            "[bold green]Starting UI mode with Gradio interface...[/bold green]"
+        )
+
+        # Ensure downloads folder exists
+        os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
+
+        # Launch UI
+        GradioUI(file_upload_folder="data/uploaded_files").launch()
+
+    elif mode == "headless":
+        console.print(f"[bold]Processing query in headless mode:[/bold] {query}")
+
+        # Create agent for headless mode
+        agent = create_agent(model_id)

+        # Show a simple spinner during processing
+        with console.status("[bold green]Processing query...[/bold green]"):
+            result = agent.run(query)
+
+        # Display the results
+        console.print("\n[bold green]Results:[/bold green]")
+        console.print(result)
+
+    else:
+        console.print(
+            f"[red]Error: Invalid mode '{mode}'. Use 'ui' or 'headless'[/red]"
+        )
+        raise typer.Exit(code=1)


+# ------------------------ Main Entry Point ------------------------
 if __name__ == "__main__":
+    # Use the typer app as the entry point
+    app()
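The new CLI entry point wires typer.Option declarations to run()'s parameters, so --mode/-m, --model, and --query/-q become command-line flags. As a quick sanity check of that wiring, here is a minimal self-contained Typer app in the same shape; mini.py is a hypothetical file name, not part of this commit. Invoke it e.g. as `python mini.py --mode headless -q "hello"`:

from typing import Optional

import typer

app = typer.Typer(add_completion=False)


@app.command()
def run(
    mode: str = typer.Option("ui", "--mode", "-m", help="'ui' or 'headless'"),
    query: Optional[str] = typer.Option(None, "--query", "-q"),
):
    # Same validation rule as the diff: headless mode requires a query.
    if mode == "headless" and not query:
        raise typer.Exit(code=1)
    typer.echo(f"mode={mode} query={query}")


if __name__ == "__main__":
    app()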
flux_image.py DELETED
File without changes
requirements.txt CHANGED
@@ -1,13 +1,9 @@
+smolagents[litellm, telemetry]
 anthropic>=0.37.1
 beautifulsoup4>=4.12.3
-Bio
-chess
-clean-text[gpl]
 datasets>=2.21.0
 google_search_results>=2.4.2
 huggingface_hub>=0.23.4
-llama-index
-llama-index-embeddings-huggingface
 mammoth>=1.8.0
 markdownify>=0.13.1
 numexpr>=2.10.1
@@ -19,25 +15,26 @@ pathvalidate>=3.2.1
 pdfminer>=20191125
 pdfminer.six>=20240706
 Pillow>=11.0.0
-pubchempy
 puremagic>=1.28
-
-PyPDF2
+pypdf>=5.1.0
 python-dotenv>=1.0.1
 python_pptx>=1.0.2
-python-pptx
 Requests>=2.32.3
-scikit-learn
-scikit-learn
-scipy
 serpapi>=0.1.5
-
-SpeechRecognition
-sympy
+tqdm>=4.66.4
 torch>=2.2.2
 torchvision>=0.17.2
-tqdm>=4.66.4
-tqdm
 transformers>=4.46.0
-xlrd
 youtube_transcript_api>=0.6.2
+chess
+sympy
+pubchempy
+Bio
+scikit-learn
+scipy
+pydub
+PyPDF2
+python-pptx
+torch
+xlrd
+SpeechRecognition
scripts/cookies.py CHANGED
@@ -1,6 +1,5 @@
 from requests.cookies import RequestsCookieJar

-
 COOKIES_LIST = [
     {
         "domain": ".youtube.com",
@@ -712,4 +711,6 @@ COOKIES = RequestsCookieJar()

 # Add cookies to the jar
 for cookie in COOKIES_LIST:
-    COOKIES.set(
+    COOKIES.set(
+        cookie["name"], cookie["value"], domain=cookie["domain"], path=cookie["path"]
+    )
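The rebuilt loop above finishes populating the COOKIES jar. A RequestsCookieJar like this is normally consumed by handing it straight to requests, so every call carries the session cookies. A minimal sketch; the URL is a placeholder and importing from scripts.cookies assumes the repository layout shown in this commit:

import requests

from scripts.cookies import COOKIES

# All cookies from COOKIES_LIST travel with this request.
response = requests.get("https://www.youtube.com", cookies=COOKIES)
print(response.status_code)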
scripts/finance_tools.py ADDED
@@ -0,0 +1,987 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2024 The Footscray Coding Collective. All rights reserved.
|
| 4 |
+
"""
|
| 5 |
+
Financial Data and Analysis Tools
|
| 6 |
+
--------------------------------------
|
| 7 |
+
A comprehensive suite of tools for retrieving financial market data through the Alpha Vantage API.
|
| 8 |
+
These tools enable accessing real-time stock quotes, company fundamentals, financial statements,
|
| 9 |
+
price history, market news, and sentiment analysis with proper error handling and caching.
|
| 10 |
+
|
| 11 |
+
The Alpha Vantage tools follow the Zhou Protocol for financial data retrieval:
|
| 12 |
+
- Singleton pattern for API client management
|
| 13 |
+
- Comprehensive error handling with failed request tracking
|
| 14 |
+
- In-memory request caching to minimize API usage
|
| 15 |
+
- Detailed docstrings with usage examples
|
| 16 |
+
|
| 17 |
+
Key Financial Tools:
|
| 18 |
+
- search_symbols: Find ticker symbols for companies by keywords
|
| 19 |
+
- get_stock_quote_data: Real-time stock quote information
|
| 20 |
+
- get_company_overview_data: Company profiles and fundamentals
|
| 21 |
+
- get_earnings_data: Quarterly and annual earnings information
|
| 22 |
+
- get_income_statement_data: Income statement analysis
|
| 23 |
+
- get_balance_sheet_data: Balance sheet information
|
| 24 |
+
- get_cash_flow_data: Cash flow statement analysis
|
| 25 |
+
- get_time_series_daily: Historical price and volume data
|
| 26 |
+
- get_market_news_sentiment: News and sentiment analysis
|
| 27 |
+
|
| 28 |
+
Financial Analysis Tools:
|
| 29 |
+
- FinancialCalculatorTool: Calculate financial metrics (growth rates, margins, CAGR)
|
| 30 |
+
- DataVisualizationTool: Generate visual representations of financial data
|
| 31 |
+
- TrendAnalysisTool: Perform year-over-year trend analysis on financial metrics
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import io
|
| 35 |
+
import logging
|
| 36 |
+
import os
|
| 37 |
+
import traceback
|
| 38 |
+
from typing import Any, Dict, Optional, Set
|
| 39 |
+
|
| 40 |
+
# Third-party imports in alphabetical order with dotenv first
|
| 41 |
+
try:
|
| 42 |
+
from dotenv import load_dotenv
|
| 43 |
+
|
| 44 |
+
load_dotenv()
|
| 45 |
+
except ImportError:
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
import matplotlib.pyplot as plt # Plot the chart
|
| 49 |
+
import pandas as pd # Store dataframe
|
| 50 |
+
import requests
|
| 51 |
+
from smolagents import Tool, tool
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class AlphaVantageClient:
    """Centralized client for Alpha Vantage API requests with caching and error handling."""

    def __init__(self):
        """Initialize the client with empty caches."""
        self._api_key: Optional[str] = None
        self._failed_requests: Set[str] = set()
        self._data_cache: Dict[str, Dict[str, Any]] = {}

    def get_api_key(self) -> str:
        """
        Get the Alpha Vantage API key from the environment or cache.

        Returns:
            API key string or error message
        """
        if self._api_key:
            return self._api_key

        api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
        if not api_key:
            return "Error: No API key found. Set ALPHA_VANTAGE_API_KEY in your environment."

        self._api_key = api_key
        return api_key

    def make_request(self, function: str, symbol: str, **params: Any) -> Dict[str, Any]:
        """
        Make a request to the Alpha Vantage API with error handling and caching.

        Args:
            function (str): API function name
            symbol (str): Stock symbol
            **params (Any): Additional parameters for the request, excluding 'function' and 'symbol'

        Returns:
            Dict[str, Any]: Raw JSON response data
        """
        # Validate params
        if "function" in params or "symbol" in params:
            raise ValueError("function and symbol should not be included in params")

        # Generate cache key
        cache_key = f"{function}:{symbol}:{hash(frozenset(params.items()))}"

        # Return cached data if available
        if cache_key in self._data_cache:
            return self._data_cache[cache_key]

        # Check if this request has failed before
        if cache_key in self._failed_requests:
            return {
                "Error": f"Previously failed request for {symbol} with function {function}"
            }

        # Get API key
        api_key = self.get_api_key()
        if api_key.startswith("Error:"):
            return {"Error Message": api_key}

        # Build request URL and parameters
        url = "https://www.alphavantage.co/query"
        request_params = {
            "function": function,
            "symbol": symbol,
            "apikey": api_key,
            **params,
        }

        try:
            # Make request with a timeout for responsiveness
            response = requests.get(url, params=request_params, timeout=10)
            response.raise_for_status()
            data = response.json()

            # Check for API errors
            if "Error Message" in data or "Information" in data or not data:
                self._failed_requests.add(cache_key)
                return data

            # Cache successful response
            self._data_cache[cache_key] = data
            return data

        except requests.RequestException as e:
            error_data = {"Error Message": f"API request failed: {str(e)}"}
            self._failed_requests.add(cache_key)
            return error_data
        except ValueError as e:
            error_data = {"Error Message": f"Failed to parse response: {str(e)}"}
            self._failed_requests.add(cache_key)
            return error_data

    def clear_cache(
        self, function: Optional[str] = None, symbol: Optional[str] = None
    ) -> None:
        """
        Clear the data cache, optionally filtering by function and/or symbol.

        Args:
            function: Optional function name to filter cache entries
            symbol: Optional symbol to filter cache entries
        """
        if not function and not symbol:
            self._data_cache.clear()
            return

        keys_to_remove = []
        for key in self._data_cache:
            parts = key.split(":")
            if function and parts[0] != function:
                continue
            if symbol and parts[1] != symbol:
                continue
            keys_to_remove.append(key)

        for key in keys_to_remove:
            del self._data_cache[key]


# Create a singleton instance of the client
_client = AlphaVantageClient()

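
# Illustrative sketch (not part of the original commit): how the singleton
# client's in-memory cache behaves. On a successful request, two identical
# calls hit the network only once; the second returns the same cached dict.
# Assumes a valid ALPHA_VANTAGE_API_KEY is set in the environment.
def _demo_client_caching() -> None:
    first = _client.make_request("GLOBAL_QUOTE", "IBM")   # network request
    second = _client.make_request("GLOBAL_QUOTE", "IBM")  # served from _data_cache
    assert first is second  # identical cached object on a successful request
    _client.clear_cache(symbol="IBM")  # drop only the IBM cache entries
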
@tool
def get_stock_quote_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw real-time stock quote information from Alpha Vantage.

    This tool fetches current market data for a specified stock ticker,
    returning the raw data for custom processing and analysis.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - Global Quote object with price, volume, and trading information
        - Error information if the request failed

    Example:
        ```python
        # Get raw quote data
        data = get_stock_quote_data("MSFT")

        # Extract price
        if "Global Quote" in data:
            quote = data["Global Quote"]
            price = float(quote.get("05. price", 0))
            change = float(quote.get("09. change", 0))
            print(f"MSFT: ${price:.2f} ({change:+.2f})")
        ```
    """
    return _client.make_request("GLOBAL_QUOTE", symbol)


@tool
def get_company_overview_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw company information and metrics from Alpha Vantage.

    This tool provides comprehensive information about a company, returning
    raw data for custom analysis and presentation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - Company profile (name, sector, industry)
        - Financial metrics (market cap, P/E ratio, etc.)
        - Performance indicators (ROE, ROA, etc.)
        - Company description
        - Error information if the request failed

    Example:
        ```python
        # Get company data
        data = get_company_overview_data("AAPL")

        # Create custom analysis
        if "Sector" in data:
            sector = data.get("Sector")
            market_cap = float(data.get("MarketCapitalization", 0))
            pe_ratio = float(data.get("PERatio", 0))

            print(f"AAPL is in the {sector} sector")
            print(f"Market Cap: ${market_cap/1e9:.2f}B")
            print(f"P/E Ratio: {pe_ratio:.2f}")
        ```
    """
    return _client.make_request("OVERVIEW", symbol)


@tool
def get_earnings_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw earnings data for a company from Alpha Vantage.

    This tool fetches quarterly and annual earnings data, returning
    raw information for custom analysis and trend evaluation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - quarterlyEarnings array with fiscal dates, reported EPS, and surprises
        - annualEarnings array with yearly EPS figures
        - Error information if the request failed

    Example:
        ```python
        # Get earnings data
        data = get_earnings_data("MSFT")

        # Analyze earnings surprises
        if "quarterlyEarnings" in data:
            quarterly = data["quarterlyEarnings"]

            # Calculate average earnings surprise percentage
            surprises = [float(q.get("surprisePercentage", 0)) for q in quarterly[:4]]
            avg_surprise = sum(surprises) / len(surprises)

            print(f"Average earnings surprise (last 4Q): {avg_surprise:.2f}%")

            # Find biggest positive surprise
            max_surprise = max(surprises)
            print(f"Largest positive surprise: {max_surprise:.2f}%")
        ```
    """
    return _client.make_request("EARNINGS", symbol)


@tool
def get_income_statement_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw income statement data for a company from Alpha Vantage.

    This tool fetches annual and quarterly income statements, returning
    raw financial data for custom analysis and profit trend evaluation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly income statements
        - quarterlyReports array with quarterly income statements
        - Error information if the request failed

    Example:
        ```python
        # Get income statement data
        data = get_income_statement_data("AAPL")

        # Analyze profitability trends
        if "annualReports" in data and len(data["annualReports"]) >= 3:
            reports = data["annualReports"][:3]  # Last 3 years

            # Extract revenue and profit
            revenues = [float(r.get("totalRevenue", 0)) for r in reports]
            net_incomes = [float(r.get("netIncome", 0)) for r in reports]

            # Calculate profit margins
            margins = [ni/rev*100 if rev else 0 for ni, rev in zip(net_incomes, revenues)]

            for i, margin in enumerate(margins):
                year = reports[i].get("fiscalDateEnding", "Unknown")
                print(f"{year}: Profit margin = {margin:.2f}%")
        ```
    """
    return _client.make_request("INCOME_STATEMENT", symbol)


@tool
def get_balance_sheet_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw balance sheet data for a company from Alpha Vantage.

    This tool fetches annual and quarterly balance sheets, returning
    raw financial data for custom analysis of a company's financial position.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly balance sheets
        - quarterlyReports array with quarterly balance sheets
        - Error information if the request failed

    Example:
        ```python
        # Get balance sheet data
        data = get_balance_sheet_data("MSFT")

        # Calculate debt-to-equity ratio
        if "annualReports" in data and data["annualReports"]:
            latest = data["annualReports"][0]

            total_debt = float(latest.get("shortTermDebt", 0)) + float(latest.get("longTermDebt", 0))
            equity = float(latest.get("totalShareholderEquity", 0))

            if equity:
                debt_to_equity = total_debt / equity
                print(f"Debt-to-Equity Ratio: {debt_to_equity:.2f}")

            # Calculate current ratio
            current_assets = float(latest.get("totalCurrentAssets", 0))
            current_liabilities = float(latest.get("totalCurrentLiabilities", 0))

            if current_liabilities:
                current_ratio = current_assets / current_liabilities
                print(f"Current Ratio: {current_ratio:.2f}")
        ```
    """
    return _client.make_request("BALANCE_SHEET", symbol)


@tool
def get_cash_flow_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw cash flow statement data for a company from Alpha Vantage.

    This tool fetches annual and quarterly cash flow statements, returning
    raw financial data for analyzing a company's cash generation and usage.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly cash flow statements
        - quarterlyReports array with quarterly cash flow statements
        - Error information if the request failed

    Example:
        ```python
        # Get cash flow data
        data = get_cash_flow_data("AMZN")

        # Analyze free cash flow
        if "annualReports" in data and data["annualReports"]:
            reports = data["annualReports"][:3]  # Last 3 years

            for report in reports:
                year = report.get("fiscalDateEnding", "Unknown")
                operating_cf = float(report.get("operatingCashflow", 0))
                capex = float(report.get("capitalExpenditures", 0))

                # Free cash flow = Operating cash flow - Capital expenditures
                free_cf = operating_cf - abs(capex)

                print(f"{year}: Free Cash Flow = ${free_cf/1e9:.2f}B")
        ```
    """
    return _client.make_request("CASH_FLOW", symbol)


@tool
def get_time_series_daily(symbol: str, outputsize: str = "compact") -> Dict[str, Any]:
    """
    Retrieve daily time series stock price data from Alpha Vantage.

    This tool fetches historical daily OHLCV (Open, High, Low, Close, Volume) data
    for specified ticker symbols, supporting both compact (100 data points) and
    full (20+ years) history.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')
        outputsize: Data size, either 'compact' (last 100 points) or 'full' (20+ years)

    Returns:
        Raw JSON data containing:
        - "Meta Data" object with information about the data series
        - "Time Series (Daily)" object with date-keyed OHLCV data points
        - Error information if the request failed

    Example:
        ```python
        # Get daily prices (compact = last 100 days)
        data = get_time_series_daily("TSLA")

        # Calculate moving averages
        if "Time Series (Daily)" in data:
            time_series = data["Time Series (Daily)"]
            dates = sorted(time_series.keys())

            # Extract closing prices
            prices = [float(time_series[date]["4. close"]) for date in dates]

            # Calculate 20-day moving average
            if len(prices) >= 20:
                ma_20 = sum(prices[-20:]) / 20
                print(f"20-day Moving Average: ${ma_20:.2f}")

                # Get latest price
                latest_price = prices[-1]
                print(f"Latest price: ${latest_price:.2f}")

                # Compare to moving average
                diff_pct = (latest_price / ma_20 - 1) * 100
                print(f"Price is {diff_pct:+.2f}% from 20-day MA")
        ```
    """
    return _client.make_request("TIME_SERIES_DAILY", symbol, outputsize=outputsize)

# Ensure that the default value IS specified
@tool
def search_symbols(keywords: str) -> Dict[str, Any]:
    """
    [FINANCIAL DISCOVERY] Search for stock symbols matching the provided keywords.

    WHEN TO USE: ALWAYS use this tool FIRST when you don't know the exact stock symbol for a company.

    This tool helps find relevant ticker symbols when you don't know the exact symbol,
    matching companies by name, description, or partial symbols.

    Args:
        keywords: Search term (e.g., 'microsoft', 'tech', 'MSFT')

    Returns:
        Raw JSON data containing:
        - bestMatches array with matching companies (symbol, name, type, region)
        - Error information if the request failed

    Example:
        ```python
        # Search for companies related to "electric vehicles"
        results = search_symbols("electric vehicles")

        # Print matched symbols and names
        if "bestMatches" in results:
            matches = results["bestMatches"]

            print(f"Found {len(matches)} matches:")
            for match in matches:
                symbol = match.get("1. symbol", "")
                name = match.get("2. name", "")
                market = match.get("4. region", "")

                print(f"{symbol} - {name} ({market})")
        ```
    """
    return _client.make_request("SYMBOL_SEARCH", "", keywords=keywords)

@tool
def clear_api_cache() -> str:
    """
    Clear all cached API data to force fresh requests.

    Returns:
        Confirmation message
    """
    # Use the client's public helper rather than reaching into its private cache.
    _client.clear_cache()
    return "API cache cleared successfully."

@tool
def get_market_news_sentiment(
    tickers: Optional[str] = None,
    topics: Optional[str] = None,
    time_from: Optional[str] = None,
    time_to: Optional[str] = None,
    sort: str = "LATEST",
    limit: int = 50,
) -> Dict[str, Any]:
    """
    Retrieve market news and sentiment data from Alpha Vantage.

    This tool fetches live and historical market news with sentiment analysis from premier
    news outlets worldwide, covering stocks, cryptocurrencies, forex, and various market topics.

    Args:
        tickers: Optional comma-separated list of symbols (e.g., 'AAPL,MSFT' or 'COIN,CRYPTO:BTC,FOREX:USD')
        topics: Optional comma-separated list of news topics (e.g., 'technology,ipo')
            Available topics: blockchain, earnings, ipo, mergers_and_acquisitions, financial_markets,
            economy_fiscal, economy_monetary, economy_macro, energy_transportation, finance,
            life_sciences, manufacturing, real_estate, retail_wholesale, technology
        time_from: Optional start time in YYYYMMDDTHHMM format (e.g., '20220410T0130')
        time_to: Optional end time in YYYYMMDDTHHMM format
        sort: Sorting order - 'LATEST' (default), 'EARLIEST', or 'RELEVANCE'
        limit: Maximum number of results to return (default: 50, max: 1000)

    Returns:
        Raw JSON data containing:
        - feed: Array of news articles with title, summary, url, time_published, authors, and more
        - sentiment scores for each article (if available)
        - Error information if the request failed

    Example:
        ```python
        # Get latest news about Apple
        apple_news = get_market_news_sentiment(tickers="AAPL")

        # Get news articles at the intersection of technology and IPOs
        tech_ipo_news = get_market_news_sentiment(topics="technology,ipo")

        # Get Bitcoin news from a specific time period
        btc_news = get_market_news_sentiment(
            tickers="CRYPTO:BTC",
            time_from="20230101T0000",
            time_to="20230201T0000"
        )

        # Process the sentiment data
        if "feed" in apple_news:
            for article in apple_news["feed"]:
                title = article.get("title", "No title")
                sentiment = article.get("overall_sentiment_score", "N/A")
                print(f"Article: {title} | Sentiment: {sentiment}")
        ```
    """
    # NOTE: the function name is passed positionally to make_request below,
    # which rejects a 'function' key inside **params, so params starts empty.
    params: Dict[str, Any] = {}

    # Add optional parameters
    if tickers:
        params["tickers"] = tickers
    if topics:
        params["topics"] = topics
    if time_from:
        params["time_from"] = time_from
    if time_to:
        params["time_to"] = time_to
    if sort:
        params["sort"] = sort
    if limit:
        params["limit"] = limit

    return _client.make_request("NEWS_SENTIMENT", "", **params)

"""Example functions to be used in the tools and called by the agent"""
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class FinancialCalculatorTool(Tool):
    """
    Performs various financial calculations, given structured data from a table.
    Useful for calculating growth rates, financial ratios, and other key metrics.
    The tool can directly perform calculations on the data for numerical answers.
    """

    name = "financial_calculator"
    description = """
    Performs various financial calculations, given structured data from a table.
    Useful for calculating growth rates, financial ratios, and other key metrics.
    The tool can directly perform calculations on the data for numerical answers.

    Input:
    - `data` (str): A string representing table data (e.g., CSV, markdown table).
    - `calculation_type` (str): The type of calculation to perform, such as 'growth_rate', 'profit_margin', 'debt_to_equity'.
    - `year1`, `year2`, `metric` (str): Parameters for 'growth_rate', e.g., "2020", "2021", "Revenue".
    - `year`, `revenue`, `netIncome` (str): Parameters for 'profit_margin', e.g., "2023", "10000", "1000".
    - `year`, `totalDebt`, `totalEquity` (str): Parameters for 'debt_to_equity', e.g., "2023", "5000", "10000".
    - `startYear`, `endYear`, `metric` (str): Parameters for 'CAGR', e.g., "2020", "2025", "Revenue".

    Output:
    - `calculation_result` (str): The result of the financial calculation as a string, to two decimal places.
    This ensures the agent can understand and utilize the output effectively.
    """

    inputs = {
        "data": {
            "type": "string",
            "description": "A string representing table data. Must be in CSV format with a header row.",
        },
        "calculation_type": {
            "type": "string",
            "description": "The type of calculation to perform. Must be one of the following exactly: 'growth_rate', 'profit_margin', 'debt_to_equity', 'CAGR'.",
        },
        "year1": {
            "type": "string",
            "description": "Year 1 for growth rate calculation, as a string.",
            "nullable": True,
        },
        "metric": {
            "type": "string",
            "description": "Valid CSV header to compare, for growth. MUST correspond to the appropriate header in the dataset.",
            "nullable": True,
        },
        "year2": {
            "type": "string",
            "description": "Year 2 for growth rate calculation, as a string. Make sure it is a valid CSV header.",
            "nullable": True,
        },
        "revenue": {
            "type": "string",
            "description": "Revenue for the fiscal year profit calculation (as a string).",
            "nullable": True,
        },
        "netIncome": {
            "type": "string",
            "description": "Valid net income for the fiscal year profit margin calculation, in string format.",
            "nullable": True,
        },
        "endYear": {
            "type": "string",
            "description": "End year string for the CAGR calculation.",
            "nullable": True,
        },
        "year": {
            "type": "string",
            "description": "Valid year.",
            "nullable": True,
        },
        "startYear": {
            "type": "string",
            "description": "Start year string for the CAGR calculation.",
            "nullable": True,
        },
        "totalAssets": {
            "type": "string",
            "description": "The total assets data in string format.",
            "nullable": True,
        },
        "totalDebt": {
            "type": "string",
            "description": "The total debt data in string format.",
            "nullable": True,
        },
        "totalEquity": {
            "type": "string",
            "description": "The total shareholders' equity in string format.",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(
        self,
        data: str,  # A string representing the data. Must be a valid CSV
        calculation_type: str,  # Type of calculation to perform on the data
        year1: Optional[str] = None,
        metric: Optional[str] = None,
        year2: Optional[str] = None,
        revenue: Optional[str] = None,
        netIncome: Optional[str] = None,
        endYear: Optional[str] = None,
        year: Optional[str] = None,
        startYear: Optional[str] = None,
        totalAssets: Optional[str] = None,
        totalDebt: Optional[str] = None,
        totalEquity: Optional[str] = None,
    ) -> str:
        """
        Performs the specified financial calculation.

        Args:
            data: A string representing the data. Must be a valid CSV.
            calculation_type: Type of calculation to perform on the data.
            year1: Start year for the growth_rate calculation.
            year2: End year for the growth_rate calculation.
            metric: CSV column name to compare (growth_rate and CAGR).
            revenue: Revenue figure for profit_margin.
            netIncome: Net income figure for profit_margin.
            year: Fiscal year for profit_margin and debt_to_equity.
            startYear: Start year for CAGR.
            endYear: End year for CAGR.
            totalAssets: Total assets figure (currently unused).
            totalDebt: Total debt figure for debt_to_equity.
            totalEquity: Total shareholders' equity for debt_to_equity.

        Returns:
            A string representing the result of the calculation. If an error occurs, the string will start with "Error: ".
        """
        try:
            df = pd.read_csv(io.StringIO(data))
            # Year lookups below compare against string arguments, so coerce
            # the Year column (often parsed as int64) to string first.
            if "Year" in df.columns:
                df["Year"] = df["Year"].astype(str)
        except Exception as e:
            return f"Error reading data: {e}. Ensure that the input provided is a valid CSV, AND has headers (no comments or empty rows)."

        try:
            if calculation_type == "growth_rate":
                if not (year1 and year2 and metric):
                    return "Error: Missing year1, year2, or metric for growth_rate calculation."

                value1 = df.loc[df["Year"] == year1][metric].values[0]
                value2 = df.loc[df["Year"] == year2][metric].values[0]

                growth_rate = ((value2 - value1) / value1) * 100
                return f"{growth_rate:.2f}%"

            elif calculation_type == "profit_margin":
                if not year or not revenue or not netIncome:
                    return "Error: Missing year, revenue, or netIncome for profit_margin calculation."

                profit_margin = (float(netIncome) / float(revenue)) * 100
                return f"{profit_margin:.2f}%"

            elif calculation_type == "debt_to_equity":
                if not year or not totalDebt or not totalEquity:
                    return "Error: Missing year, totalDebt, or totalEquity for debt_to_equity calculation."

                debt_to_equity = float(totalDebt) / float(totalEquity)
                return f"{debt_to_equity:.2f}"

            elif calculation_type == "CAGR":
                if not (startYear and endYear and metric):
                    return "Error: Missing startYear, endYear, or metric for CAGR calculation."

                try:
                    start_value = float(df[df["Year"] == startYear][metric].values[0])
                    end_value = float(df[df["Year"] == endYear][metric].values[0])
                except Exception as exception:
                    return (
                        f"Error: could not read '{metric}' for years {startYear} and {endYear}. "
                        f"Ensure the CSV headers are present and valid. Original exception: {exception}"
                    )
                try:
                    n = int(endYear) - int(startYear)
                    cagr = (end_value / start_value) ** (1 / n) - 1
                    return f"{cagr:.2f}"
                except Exception:
                    return (
                        f"Error: invalid CAGR inputs (startYear={startYear}, endYear={endYear}, "
                        f"start value={start_value}, end value={end_value})."
                    )

            else:
                return f"Error: Unsupported Calculation Type: {calculation_type}. Consider growth_rate, profit_margin, debt_to_equity, CAGR."
        except Exception as e:
            return f"Error performing calculation: {e}"

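
# Illustrative sketch (not part of the original commit): calling the
# calculator directly via forward(). The CSV and figures are made-up.
def _demo_financial_calculator() -> None:
    calc = FinancialCalculatorTool()
    csv_data = "Year,Revenue\n2020,100\n2021,125\n2022,150"
    print(calc.forward(data=csv_data, calculation_type="growth_rate",
                       year1="2020", year2="2021", metric="Revenue"))  # "25.00%"
    print(calc.forward(data=csv_data, calculation_type="CAGR",
                       startYear="2020", endYear="2022", metric="Revenue"))  # "0.22"
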
class DataVisualizationTool(Tool):
    """
    Generates visualizations (charts, graphs) from structured data to help identify trends.
    Be thoughtful about the data AND type of graph: they must match.
    You CANNOT import things other than csv, so make sure to follow the instructions.
    """

    name = "data_visualization"
    description = """
    Generates visualizations (charts, graphs) from structured data to help identify trends. Be thoughtful about the data AND type of graph: they must match. You CANNOT import things other than csv, so make sure to follow the instructions.

    Input:
    - `data` (str): A valid CSV string that represents values to graph: MUST start with a HEADER row, then be followed by valid CSV syntax.
    - `chart_type` (str): The type of chart/graph to generate, MUST be one of: 'line', 'bar', 'scatter'.
    - `x_axis_label` (str): Label for the x axis. If unsure, set as "years".
    - `y_axis_label` (str): Label for the y axis. If unsure, set as "net income".

    Output:
    - `plot_string` (str): A verbal description of the plot, especially its overall trend. A short trend is sufficient.
    """
    inputs = {
        "data": {
            "type": "string",
            "description": "CSV data representing a time series: Start this with headers followed by values!!",
        },
        "chart_type": {
            "type": "string",
            "description": "Type of chart to generate (MUST be one of 'line', 'bar', 'scatter').",
        },
        "x_axis_label": {
            "type": "string",
            "description": "Label of x-axis, such as 'years' or 'quarters'",
        },
        "y_axis_label": {
            "type": "string",
            "description": "Label of y-axis, such as 'net income' or 'revenue'",
        },
    }
    output_type = "string"

    def forward(
        self, data: str, chart_type: str, x_axis_label: str, y_axis_label: str
    ) -> str:
        """
        Plot the chart and describe it.

        Args:
            data (str): string CSV in the correct format
            chart_type (str): one of scatter, line, bar
            x_axis_label (str): label
            y_axis_label (str): label

        Returns:
            str: A verbal description of the plot, especially its overall trend.
        """
        if not data:
            return "Error: No data provided."
        if not chart_type:
            return "Error: No chart type provided."
        if not x_axis_label:
            return "Error: No x-axis label provided."
        if not y_axis_label:
            return "Error: No y-axis label provided."
        try:
            df = pd.read_csv(io.StringIO(data))
        except Exception as e:
            return f"Problem building data {data}: {e}"
        if len(df.columns) < 2:
            return "Error: Data must have at least two columns."
        try:
            plt.figure(figsize=(10, 6))  # Adjust the figure size for better readability
            if chart_type == "line":
                plt.xlabel(x_axis_label)
                plt.ylabel(y_axis_label)
                plt.plot(df[df.columns[0]], df[df.columns[1]])
            elif chart_type == "bar":
                plt.ylabel(y_axis_label)
                plt.xlabel(x_axis_label)
                plt.bar(df[df.columns[0]], df[df.columns[1]])
            elif chart_type == "scatter":
                plt.ylabel(y_axis_label)
                plt.xlabel(x_axis_label)
                plt.scatter(df[df.columns[0]], df[df.columns[1]])
            else:
                raise ValueError(f"Unsupported chart type: {chart_type}")
            chart_summary = f"Chart generated, which shows the {chart_type} of {df.columns[1]} with respect to {df.columns[0]}. "
            plt.title(y_axis_label + " vs. " + x_axis_label)  # What we're graphing
            plt.show()  # Display the chart via the matplotlib backend
            return chart_summary
        except Exception as e:
            return f"Problem with chart plotting: {e}"

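
# Illustrative sketch (not part of the original commit): plotting a small
# made-up revenue series; the returned summary string is what the agent sees.
def _demo_data_visualization() -> None:
    viz = DataVisualizationTool()
    csv_data = "Year,Revenue\n2020,100\n2021,125\n2022,150"
    summary = viz.forward(data=csv_data, chart_type="bar",
                          x_axis_label="years", y_axis_label="revenue")
    print(summary)  # e.g. "Chart generated, which shows the bar of Revenue ..."
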
class TrendAnalysisTool(Tool):
    """
    You can retrieve year-over-year increase percentages for a specific category by setting the category.
    Please provide a valid CSV. MAKE SURE headers = columns, and that it is in the correct format.
    """

    name = "trend_analysis"
    description = """
    You can retrieve year-over-year increase percentages for a specific category by setting the category. Please provide a valid CSV. MAKE SURE headers = columns, and that it is in the correct format.
    """
    inputs = {
        "data": {
            "type": "string",
            "description": "A string representing the data (e.g., CSV format) - MUST HAVE HEADERS. MUST specify all columns.",
        },
        "category": {
            "type": "string",
            "description": "The category we want to compare, such as revenue. Check to know WHAT the name is!!",
        },
    }
    output_type = "string"

    def forward(self, data: str, category: str) -> str:
        """Compute year-over-year increases for a given CSV.

        Args:
            data: all the data
            category: the category we want to compare, such as revenue

        Returns:
            The dataframe with a formatted 'YoY Change' column, as a string.
        """
        try:
            df = pd.read_csv(io.StringIO(data))
        except Exception as e:
            return f"Error reading data: {e}. Ensure the CSV is valid and headers are present!!"
        try:
            df["YoY Change"] = df[category].pct_change() * 100
            df["YoY Change"] = df["YoY Change"].map("{:.2f}%".format)
            change_description = df.to_string()
            return change_description
        except Exception as e:
            return f"Error with trend analysis: {e}. Check the category name or data!!"

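
# Illustrative sketch (not part of the original commit): year-over-year
# changes for a made-up revenue column; the first row is NaN by definition.
def _demo_trend_analysis() -> None:
    trend = TrendAnalysisTool()
    csv_data = "Year,Revenue\n2020,100\n2021,125\n2022,150"
    print(trend.forward(data=csv_data, category="Revenue"))
    # YoY Change column: nan%, 25.00%, 20.00%
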
# ###########################
# # Example loading the tools:
# ###########################

# # def load_finance_tools():
# #     finance_tools = [
# #         get_stock_quote_data,
# #         get_company_overview_data,
# #         get_earnings_data,
# #         get_income_statement_data,
# #         get_balance_sheet_data,
# #         get_cash_flow_data,
# #         get_time_series_daily,
# #         search_symbols,
# #         DataVisualizationTool(),
# #         FinancialCalculatorTool(),
# #         TrendAnalysisTool()
# #     ]
# #     return finance_tools


def load_finance_tools():
    """Initialize and return finance tools for data retrieval and analysis.
    All of the tools defined above MUST be registered here, or the agent cannot use them.
    """

    finance_tools = []
    # finance_tools_names = []  # was getting errors on loading

    def safe_tool_load(tool_func, tool_name):
        """Helper to safely load and append a finance tool."""
        try:
            finance_tools.append(tool_func)
            # finance_tools_names.append(tool_func.__name__)  # was getting errors on loading
            logging.info(f"Loaded {tool_name} tool successfully")
        except Exception as e:
            logging.error(f"Failed to load tool {tool_name}: {e}")
            logging.error(traceback.format_exc())  # Print the stack trace

    # Financial calculation tools first
    safe_tool_load(DataVisualizationTool(), "DataVisualizationTool")
    safe_tool_load(FinancialCalculatorTool(), "FinancialCalculatorTool")
    safe_tool_load(TrendAnalysisTool(), "TrendAnalysisTool")
    # Raw data retrieval tools last
    safe_tool_load(get_stock_quote_data, "get_stock_quote_data")
    safe_tool_load(get_company_overview_data, "get_company_overview_data")
    safe_tool_load(get_earnings_data, "get_earnings_data")
    safe_tool_load(get_income_statement_data, "get_income_statement_data")
    safe_tool_load(get_balance_sheet_data, "get_balance_sheet_data")
    safe_tool_load(get_cash_flow_data, "get_cash_flow_data")
    safe_tool_load(get_time_series_daily, "get_time_series_daily")
    safe_tool_load(search_symbols, "search_symbols")
    safe_tool_load(get_market_news_sentiment, "get_market_news_sentiment")

    return finance_tools


__all__ = [
    "get_stock_quote_data",
    "get_company_overview_data",
    "get_earnings_data",
    "get_income_statement_data",
    "get_balance_sheet_data",
    "get_cash_flow_data",
    "get_time_series_daily",
    "search_symbols",
    "get_market_news_sentiment",
    "DataVisualizationTool",
    "FinancialCalculatorTool",
    "TrendAnalysisTool",
]
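
# ###########################
# # Example running the tools with an agent (illustrative sketch, not part of
# # the original commit; the CodeAgent/HfApiModel wiring here is an assumption,
# # mirroring how app.py loads its other tools):
# ###########################

# # from smolagents import CodeAgent, HfApiModel
# # from scripts.finance_tools import load_finance_tools
# #
# # agent = CodeAgent(tools=load_finance_tools(), model=HfApiModel())
# # agent.run("How did AAPL's profit margin trend over the last three years?")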
scripts/flux_lora_tool.py
CHANGED
@@ -12,30 +12,28 @@ Usage:
     agent = CodeAgent(tools=[flux_tool], ...)
 """
 
 import os
-import uuid
 import tempfile
-import
-from typing import Dict, Any, Optional, List, Union, Tuple
 from dataclasses import dataclass
-import
-from pathlib import Path
 
 # Third-party
 import requests
-from PIL import Image
 from gradio_client import Client
-
-# Smolagents
 from smolagents import Tool
 
 # -----------------------------------------------------------------------------
 # CONSTANTS AND TYPE DEFINITIONS
 # -----------------------------------------------------------------------------
 
 @dataclass
 class LoRAModelInfo:
     """Value object representing LoRA model information."""
     name: str
     description: Optional[str] = None
     example_image_url: Optional[str] = None
@@ -44,6 +42,7 @@ class LoRAModelInfo:
 @dataclass
 class ImageGenerationResult:
     """Value object representing a generated image result."""
     image_path: str
     seed: int
     metadata: Optional[Dict[str, Any]] = None
@@ -53,14 +52,15 @@ class ImageGenerationResult:
 # CORE TOOL IMPLEMENTATION
 # -----------------------------------------------------------------------------
 
 class FluxLoRATool(Tool):
     """
     Tool for generating images using FLUX-LoRA-DLC API.
-
     This tool implements the Zhou Protocol integration patterns to provide
     a clean, efficient interface for image generation using LoRA models.
     """
-
     name = "flux_lora_generator"
     description = """
     Generates high-quality images using FLUX-LoRA models.
@@ -68,74 +68,74 @@ class FluxLoRATool(Tool):
     """
     inputs = {
         "prompt": {
-            "type": "string",
-            "description": "Detailed description of the desired image."
         },
         "image_input": {
-            "type": "string",
             "description": "Optional URL or file path to input image for img2img generation.",
-            "optional": True
         },
         "image_strength": {
             "type": "float",
             "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
             "optional": True,
-            "default": 0.75
         },
         "cfg_scale": {
             "type": "float",
             "description": "Guidance scale for prompt adherence (1.0-30.0).",
             "optional": True,
-            "default": 3.5
         },
         "steps": {
             "type": "integer",
             "description": "Number of sampling steps (10-100).",
             "optional": True,
-            "default": 28
         },
         "seed": {
             "type": "integer",
             "description": "Random seed for reproducibility. Use -1 for random seed.",
             "optional": True,
-            "default": -1
         },
         "width": {
             "type": "integer",
             "description": "Image width in pixels.",
             "optional": True,
-            "default": 1024
         },
         "height": {
             "type": "integer",
             "description": "Image height in pixels.",
             "optional": True,
-            "default": 1024
         },
         "lora_scale": {
             "type": "float",
             "description": "LoRA influence scale (0.0-1.0).",
             "optional": True,
-            "default": 0.95
         },
         "custom_lora": {
             "type": "string",
             "description": "Custom LoRA model to use. Leave empty for default.",
-            "optional": True
-        }
     }
     output_type = "string"
-
     def __init__(
-        self,
         api_url: str = "xkerser/FLUX-LoRA-DLC",
         image_save_dir: Optional[str] = None,
         connection_timeout: int = 60,
-        verbose: bool = False
     ):
         """
         Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.
-
         Args:
             api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
             image_save_dir: Directory to save generated images (created if doesn't exist)
@@ -143,66 +143,67 @@ class FluxLoRATool(Tool):
             verbose: Enable detailed logging
         """
         super().__init__()
-
         # Initialize logging
         self.logger = logging.getLogger("flux_lora_tool")
         self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
-
         # Set up client and storage directories
         self.api_url = api_url
         self.connection_timeout = connection_timeout
         self._client = None  # Lazy initialization
-
         # Set up image storage directory
-        self.image_save_dir = image_save_dir or os.path.join(
         os.makedirs(self.image_save_dir, exist_ok=True)
-        self.logger.info(
-
     @property
     def client(self) -> Client:
         """
         Get or initialize the Gradio client with proper connection handling.
-
         Returns:
             Initialized Gradio client
-
         Raises:
             ConnectionError: If client initialization fails
         """
         if self._client is None:
             try:
-                self._client = Client(
-                    self.api_url,
-                    timeout=self.connection_timeout
-                )
                 self.logger.debug(f"Gradio client initialized for: {self.api_url}")
             except Exception as e:
                 error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
                 self.logger.error(error_msg)
                 raise ConnectionError(error_msg) from e
-
         return self._client
-
     def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
         """
         Validate and normalize input parameters with Zhou Protocol validation patterns.
-
         Args:
             **kwargs: Input parameters
-
         Returns:
             Validated and normalized parameters
-
         Raises:
             ValueError: If input validation fails
         """
         validated = {}
-
         # Required parameter: prompt
         if not kwargs.get("prompt"):
             raise ValueError("Prompt is required for image generation")
         validated["prompt"] = kwargs["prompt"]
-
         # Image input handling
         if "image_input" in kwargs and kwargs["image_input"]:
             input_image = kwargs["image_input"]
@@ -215,7 +216,7 @@ class FluxLoRATool(Tool):
             if not os.path.exists(input_image):
                 raise ValueError(f"Image file not found: {input_image}")
             validated["image_input"] = input_image
-
         # Numeric parameter validation with constraints
         numeric_params = {
             "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
@@ -223,13 +224,13 @@ class FluxLoRATool(Tool):
             "steps": {"min": 10, "max": 100, "default": 28},
             "width": {"min": 128, "max": 2048, "default": 1024},
             "height": {"min": 128, "max": 2048, "default": 1024},
-            "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95}
         }
-
         for param, constraints in numeric_params.items():
             if param in kwargs and kwargs[param] is not None:
                 value = kwargs[param]
-
                 # Type conversion if needed
                 if param in ["steps", "width", "height"]:
                     try:
@@ -241,17 +242,17 @@ class FluxLoRATool(Tool):
                         value = float(value)
                     except (ValueError, TypeError):
                         raise ValueError(f"Parameter '{param}' must be a number")
-
                 # Range validation
                 if value < constraints["min"] or value > constraints["max"]:
                     raise ValueError(
                         f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
                     )
-
                 validated[param] = value
             else:
                 validated[param] = constraints["default"]
-
         # Special handling for seed
         if "seed" in kwargs and kwargs["seed"] is not None:
             try:
@@ -264,6 +265,7 @@ class FluxLoRATool(Tool):
                 self.logger.warning(f"Failed to get random seed from API: {e}")
                 # Fallback to Python's random
                 import random
+
                 seed = random.randint(0, 2**32 - 1)
                 validated["seed"] = seed
             except (ValueError, TypeError):
@@ -271,57 +273,56 @@ class FluxLoRATool(Tool):
         else:
             # Default to random seed
             validated["seed"] = self._get_random_seed()
-
         # Custom LoRA handling
         if "custom_lora" in kwargs and kwargs["custom_lora"]:
             validated["custom_lora"] = kwargs["custom_lora"]
-
         return validated
-
     def _download_image(self, url: str) -> str:
         """
         Download image from URL and save to local file.
-
         Args:
             url: Image URL
-
         Returns:
             Local file path
-
         Raises:
             ConnectionError: If download fails
         """
         try:
             response = requests.get(url, stream=True, timeout=30)
             response.raise_for_status()
-
             # Generate temporary file path
             file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
             temp_path = os.path.join(
-                self.image_save_dir,
-                f"input_{uuid.uuid4().hex}{file_ext}"
             )
-
             # Save image
             with open(temp_path, "wb") as f:
                 for chunk in response.iter_content(chunk_size=8192):
                     f.write(chunk)
-
             self.logger.debug(f"Downloaded image from {url} to {temp_path}")
             return temp_path
-
         except Exception as e:
             error_msg = f"Failed to download image from {url}: {str(e)}"
             self.logger.error(error_msg)
             raise ConnectionError(error_msg) from e
-
     def _guess_extension(self, content_type: str) -> str:
         """
         Guess file extension from content type.
-
         Args:
             content_type: HTTP Content-Type header
-
         Returns:
             File extension (with dot)
         """
@@ -336,14 +337,14 @@ class FluxLoRATool(Tool):
             return ".gif"
         else:
             return ".png"  # Default to PNG
-
     def _get_random_seed(self) -> int:
         """
         Get a random seed from the API.
-
         Returns:
             Random seed value
-
         Raises:
             RuntimeError: If random seed retrieval fails
         """
@@ -357,14 +358,14 @@ class FluxLoRATool(Tool):
             # Just log and re-raise as we have fallback in the validation method
             self.logger.warning(f"Failed to get random seed: {e}")
             raise
-
     def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
         """
         Add or remove custom LoRA model.
-
         Args:
             custom_lora: Custom LoRA model string
-
         Raises:
             RuntimeError: If LoRA handling fails
         """
@@ -381,15 +382,14 @@ class FluxLoRATool(Tool):
         # Add custom LoRA
         try:
             self.client.predict(
-                custom_lora=custom_lora,
-                api_name="/add_custom_lora"
             )
             self.logger.debug(f"Added custom LoRA: {custom_lora}")
         except Exception as e:
             error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
             self.logger.error(error_msg)
             raise RuntimeError(error_msg) from e
-
     def forward(
         self,
         prompt: str,
@@ -401,11 +401,11 @@ class FluxLoRATool(Tool):
         width: Optional[int] = None,
         height: Optional[int] = None,
         lora_scale: Optional[float] = None,
-        custom_lora: Optional[str] = None
    ) -> str:
        """
        Generate an image with FLUX-LoRA.
-
        Args:
            prompt: Text description of the desired image
            image_input: Optional path or URL to input image for img2img
@@ -417,10 +417,10 @@ class FluxLoRATool(Tool):
            height: Image height in pixels (128-2048)
            lora_scale: LoRA influence scale (0.0-1.0)
            custom_lora: Custom LoRA model to use
-
        Returns:
            Formatted string with image generation results
-
        Raises:
            ValueError: If input validation fails
            ConnectionError: If API communication fails
@@ -438,12 +438,12 @@ class FluxLoRATool(Tool):
                width=width,
                height=height,
                lora_scale=lora_scale,
-                custom_lora=custom_lora
            )
            self.logger.debug(f"Validated parameters: {params}")
        except ValueError as e:
            return f"Parameter validation failed: {str(e)}"
-
        # Step 2: Handle custom LoRA if specified
        if "custom_lora" in params:
            try:
@@ -451,15 +451,16 @@ class FluxLoRATool(Tool):
                self._handle_custom_lora(custom_lora_value)
            except RuntimeError as e:
                return f"Custom LoRA setup failed: {str(e)}"
-
        # Step 3: Generate image
        try:
            # Prepare image input if provided
            img_param = None
            if "image_input" in params and params["image_input"]:
                from gradio_client import handle_file
+
                img_param = handle_file(params.pop("image_input"))
-
            # Call the API
            generation_args = {
                "prompt": params["prompt"],
@@ -472,27 +473,23 @@ class FluxLoRATool(Tool):
                "height": params["height"],
                "lora_scale": params["lora_scale"],
            }
-
            # Add image input if available
            if img_param:
                generation_args["image_input"] = img_param
-
            self.logger.info(f"Generating image with params: {generation_args}")
-            result = self.client.predict(
-
-                **generation_args
-            )
-
            # Process result
            if isinstance(result, tuple) and len(result) >= 2:
                image_path, actual_seed = result[0], result[1]
-
                # Save image to our directory
                try:
                    output_path = self._save_image(image_path)
                    image_result = ImageGenerationResult(
-                        image_path=output_path,
-                        seed=int(actual_seed)
                    )
                    return self._format_result(image_result, params["prompt"])
                except Exception as e:
@@ -500,69 +497,69 @@ class FluxLoRATool(Tool):
                    return f"Image generated but failed to save: {str(e)}"
            else:
                raise ValueError(f"Unexpected API response format: {result}")
-
        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            self.logger.error(error_msg)
            return error_msg
-
    def _save_image(self, image_path: str) -> str:
        """
        Save generated image to specified directory.
-
        Args:
            image_path: Path to generated image from API
-
        Returns:
            Path to saved image
-
        Raises:
            IOError: If image saving fails
        """
        try:
            # Load the image
            img = Image.open(image_path)
-
            # Generate timestamp-based filename
            timestamp = uuid.uuid4().hex[:8]
            output_filename = f"flux_lora_{timestamp}.png"
            output_path = os.path.join(self.image_save_dir, output_filename)
-
            # Save to our directory
            img.save(output_path)
            self.logger.debug(f"Saved image to {output_path}")
-
            return output_path
-
        except Exception as e:
            error_msg = f"Failed to save image: {str(e)}"
            self.logger.error(error_msg)
            raise IOError(error_msg) from e
-
    def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
        """
        Format the image generation result as a string.
-
        Args:
            result: Image generation result
            prompt: Original prompt
-
        Returns:
            Formatted string with generation details
        """
        lines = [
-
            f"🖼️ Image saved to: {result.image_path}",
            f"🌱 Seed used: {result.seed}",
            f"📝 Original prompt: {prompt}",
        ]
-
        # Add metadata if available
        if result.metadata:
            lines.append("📊 Additional metadata:")
            for key, value in result.metadata.items():
                lines.append(f"  - {key}: {value}")
-
        return "\n".join(lines)
 
 
@@ -570,17 +567,18 @@ class FluxLoRATool(Tool):
 # UTILITY FUNCTIONS
 # -----------------------------------------------------------------------------
 
 def download_image(url: str, output_dir: Optional[str] = None) -> str:
     """
     Standalone utility to download an image from a URL.
-
     Args:
         url: Image URL
         output_dir: Directory to save image (created if doesn't exist)
-
     Returns:
         Path to downloaded image
-
     Raises:
         ValueError: If URL is invalid
         ConnectionError: If download fails
@@ -588,31 +586,30 @@ def download_image(url: str, output_dir: Optional[str] = None) -> str:
     """
     if not url.startswith(("http://", "https://")):
         raise ValueError(f"Invalid URL: {url}")
-
     # Setup output directory
     if output_dir is None:
         output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
     os.makedirs(output_dir, exist_ok=True)
-
     try:
         # Download image
         response = requests.get(url, stream=True, timeout=30)
         response.raise_for_status()
-
         # Determine file extension
         content_type = response.headers.get("Content-Type", "")
         ext = ".jpg" if "jpeg" in content_type.lower() else ".png"
-
         # Save image
         output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
         with open(output_path, "wb") as f:
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
-
         return output_path
-
     except requests.RequestException as e:
         raise ConnectionError(f"Failed to download image: {str(e)}")
     except IOError as e:
         raise IOError(f"Failed to save image: {str(e)}")
-
     agent = CodeAgent(tools=[flux_tool], ...)
 """
 
+import logging
 import os
 import tempfile
+import uuid
 from dataclasses import dataclass
+from typing import Any, Dict, Optional
 
 # Third-party
 import requests
 from gradio_client import Client
+from PIL import Image
 from smolagents import Tool
 
 # -----------------------------------------------------------------------------
 # CONSTANTS AND TYPE DEFINITIONS
 # -----------------------------------------------------------------------------
 
+
 @dataclass
 class LoRAModelInfo:
     """Value object representing LoRA model information."""
+
     name: str
     description: Optional[str] = None
     example_image_url: Optional[str] = None
 
 
 @dataclass
 class ImageGenerationResult:
     """Value object representing a generated image result."""
+
     image_path: str
     seed: int
     metadata: Optional[Dict[str, Any]] = None
 
 
 # CORE TOOL IMPLEMENTATION
 # -----------------------------------------------------------------------------
|
| 54 |
|
| 55 |
+
|
| 56 |
class FluxLoRATool(Tool):
|
| 57 |
"""
|
| 58 |
Tool for generating images using FLUX-LoRA-DLC API.
|
| 59 |
+
|
| 60 |
This tool implements the Zhou Protocol integration patterns to provide
|
| 61 |
a clean, efficient interface for image generation using LoRA models.
|
| 62 |
"""
|
| 63 |
+
|
| 64 |
name = "flux_lora_generator"
|
| 65 |
description = """
|
| 66 |
Generates high-quality images using FLUX-LoRA models.
|
|
|
|
| 68 |
"""
|
| 69 |
inputs = {
|
| 70 |
"prompt": {
|
| 71 |
+
"type": "string",
|
| 72 |
+
"description": "Detailed description of the desired image.",
|
| 73 |
},
|
| 74 |
"image_input": {
|
| 75 |
+
"type": "string",
|
| 76 |
"description": "Optional URL or file path to input image for img2img generation.",
|
| 77 |
+
"optional": True,
|
| 78 |
},
|
| 79 |
"image_strength": {
|
| 80 |
"type": "float",
|
| 81 |
"description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
|
| 82 |
"optional": True,
|
| 83 |
+
"default": 0.75,
|
| 84 |
},
|
| 85 |
"cfg_scale": {
|
| 86 |
"type": "float",
|
| 87 |
"description": "Guidance scale for prompt adherence (1.0-30.0).",
|
| 88 |
"optional": True,
|
| 89 |
+
"default": 3.5,
|
| 90 |
},
|
| 91 |
"steps": {
|
| 92 |
"type": "integer",
|
| 93 |
"description": "Number of sampling steps (10-100).",
|
| 94 |
"optional": True,
|
| 95 |
+
"default": 28,
|
| 96 |
},
|
| 97 |
"seed": {
|
| 98 |
"type": "integer",
|
| 99 |
"description": "Random seed for reproducibility. Use -1 for random seed.",
|
| 100 |
"optional": True,
|
| 101 |
+
"default": -1,
|
| 102 |
},
|
| 103 |
"width": {
|
| 104 |
"type": "integer",
|
| 105 |
"description": "Image width in pixels.",
|
| 106 |
"optional": True,
|
| 107 |
+
"default": 1024,
|
| 108 |
},
|
| 109 |
"height": {
|
| 110 |
"type": "integer",
|
| 111 |
"description": "Image height in pixels.",
|
| 112 |
"optional": True,
|
| 113 |
+
"default": 1024,
|
| 114 |
},
|
| 115 |
"lora_scale": {
|
| 116 |
"type": "float",
|
| 117 |
"description": "LoRA influence scale (0.0-1.0).",
|
| 118 |
"optional": True,
|
| 119 |
+
"default": 0.95,
|
| 120 |
},
|
| 121 |
"custom_lora": {
|
| 122 |
"type": "string",
|
| 123 |
"description": "Custom LoRA model to use. Leave empty for default.",
|
| 124 |
+
"optional": True,
|
| 125 |
+
},
|
| 126 |
}
|
| 127 |
output_type = "string"
|
| 128 |
+
|
| 129 |
def __init__(
|
| 130 |
+
self,
|
| 131 |
api_url: str = "xkerser/FLUX-LoRA-DLC",
|
| 132 |
image_save_dir: Optional[str] = None,
|
| 133 |
connection_timeout: int = 60,
|
| 134 |
+
verbose: bool = False,
|
| 135 |
):
|
| 136 |
"""
|
| 137 |
Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.
|
| 138 |
+
|
| 139 |
Args:
|
| 140 |
api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
|
| 141 |
image_save_dir: Directory to save generated images (created if doesn't exist)
|
|
|
|
| 143 |
verbose: Enable detailed logging
|
| 144 |
"""
|
| 145 |
super().__init__()
|
| 146 |
+
|
| 147 |
# Initialize logging
|
| 148 |
self.logger = logging.getLogger("flux_lora_tool")
|
| 149 |
self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
| 150 |
+
|
| 151 |
# Set up client and storage directories
|
| 152 |
self.api_url = api_url
|
| 153 |
self.connection_timeout = connection_timeout
|
| 154 |
self._client = None # Lazy initialization
|
| 155 |
+
|
| 156 |
# Set up image storage directory
|
| 157 |
+
self.image_save_dir = image_save_dir or os.path.join(
|
| 158 |
+
tempfile.gettempdir(), "flux_lora_images"
|
| 159 |
+
)
|
| 160 |
os.makedirs(self.image_save_dir, exist_ok=True)
|
| 161 |
+
self.logger.info(
|
| 162 |
+
f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
@property
|
| 166 |
def client(self) -> Client:
|
| 167 |
"""
|
| 168 |
Get or initialize the Gradio client with proper connection handling.
|
| 169 |
+
|
| 170 |
Returns:
|
| 171 |
Initialized Gradio client
|
| 172 |
+
|
| 173 |
Raises:
|
| 174 |
ConnectionError: If client initialization fails
|
| 175 |
"""
|
| 176 |
if self._client is None:
|
| 177 |
try:
|
| 178 |
+
self._client = Client(self.api_url, timeout=self.connection_timeout)
|
|
|
|
|
|
|
|
|
|
| 179 |
self.logger.debug(f"Gradio client initialized for: {self.api_url}")
|
| 180 |
except Exception as e:
|
| 181 |
error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
|
| 182 |
self.logger.error(error_msg)
|
| 183 |
raise ConnectionError(error_msg) from e
|
| 184 |
+
|
| 185 |
return self._client
|
| 186 |
+
|
| 187 |
def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
|
| 188 |
"""
|
| 189 |
Validate and normalize input parameters with Zhou Protocol validation patterns.
|
| 190 |
+
|
| 191 |
Args:
|
| 192 |
**kwargs: Input parameters
|
| 193 |
+
|
| 194 |
Returns:
|
| 195 |
Validated and normalized parameters
|
| 196 |
+
|
| 197 |
Raises:
|
| 198 |
ValueError: If input validation fails
|
| 199 |
"""
|
| 200 |
validated = {}
|
| 201 |
+
|
| 202 |
# Required parameter: prompt
|
| 203 |
if not kwargs.get("prompt"):
|
| 204 |
raise ValueError("Prompt is required for image generation")
|
| 205 |
validated["prompt"] = kwargs["prompt"]
|
| 206 |
+
|
| 207 |
# Image input handling
|
| 208 |
if "image_input" in kwargs and kwargs["image_input"]:
|
| 209 |
input_image = kwargs["image_input"]
|
|
|
|
| 216 |
if not os.path.exists(input_image):
|
| 217 |
raise ValueError(f"Image file not found: {input_image}")
|
| 218 |
validated["image_input"] = input_image
|
| 219 |
+
|
| 220 |
# Numeric parameter validation with constraints
|
| 221 |
numeric_params = {
|
| 222 |
"image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
|
|
|
|
| 224 |
"steps": {"min": 10, "max": 100, "default": 28},
|
| 225 |
"width": {"min": 128, "max": 2048, "default": 1024},
|
| 226 |
"height": {"min": 128, "max": 2048, "default": 1024},
|
| 227 |
+
"lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95},
|
| 228 |
}
|
| 229 |
+
|
| 230 |
for param, constraints in numeric_params.items():
|
| 231 |
if param in kwargs and kwargs[param] is not None:
|
| 232 |
value = kwargs[param]
|
| 233 |
+
|
| 234 |
# Type conversion if needed
|
| 235 |
if param in ["steps", "width", "height"]:
|
| 236 |
try:
|
|
|
|
| 242 |
value = float(value)
|
| 243 |
except (ValueError, TypeError):
|
| 244 |
raise ValueError(f"Parameter '{param}' must be a number")
|
| 245 |
+
|
| 246 |
# Range validation
|
| 247 |
if value < constraints["min"] or value > constraints["max"]:
|
| 248 |
raise ValueError(
|
| 249 |
f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
|
| 250 |
)
|
| 251 |
+
|
| 252 |
validated[param] = value
|
| 253 |
else:
|
| 254 |
validated[param] = constraints["default"]
|
| 255 |
+
|
| 256 |
# Special handling for seed
|
| 257 |
if "seed" in kwargs and kwargs["seed"] is not None:
|
| 258 |
try:
|
|
|
|
| 265 |
self.logger.warning(f"Failed to get random seed from API: {e}")
|
| 266 |
# Fallback to Python's random
|
| 267 |
import random
|
| 268 |
+
|
| 269 |
seed = random.randint(0, 2**32 - 1)
|
| 270 |
validated["seed"] = seed
|
| 271 |
except (ValueError, TypeError):
|
|
|
|
| 273 |
else:
|
| 274 |
# Default to random seed
|
| 275 |
validated["seed"] = self._get_random_seed()
|
| 276 |
+
|
| 277 |
# Custom LoRA handling
|
| 278 |
if "custom_lora" in kwargs and kwargs["custom_lora"]:
|
| 279 |
validated["custom_lora"] = kwargs["custom_lora"]
|
| 280 |
+
|
| 281 |
return validated
|
| 282 |
+
|
| 283 |
def _download_image(self, url: str) -> str:
|
| 284 |
"""
|
| 285 |
Download image from URL and save to local file.
|
| 286 |
+
|
| 287 |
Args:
|
| 288 |
url: Image URL
|
| 289 |
+
|
| 290 |
Returns:
|
| 291 |
Local file path
|
| 292 |
+
|
| 293 |
Raises:
|
| 294 |
ConnectionError: If download fails
|
| 295 |
"""
|
| 296 |
try:
|
| 297 |
response = requests.get(url, stream=True, timeout=30)
|
| 298 |
response.raise_for_status()
|
| 299 |
+
|
| 300 |
# Generate temporary file path
|
| 301 |
file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
|
| 302 |
temp_path = os.path.join(
|
| 303 |
+
self.image_save_dir, f"input_{uuid.uuid4().hex}{file_ext}"
|
|
|
|
| 304 |
)
|
| 305 |
+
|
| 306 |
# Save image
|
| 307 |
with open(temp_path, "wb") as f:
|
| 308 |
for chunk in response.iter_content(chunk_size=8192):
|
| 309 |
f.write(chunk)
|
| 310 |
+
|
| 311 |
self.logger.debug(f"Downloaded image from {url} to {temp_path}")
|
| 312 |
return temp_path
|
| 313 |
+
|
| 314 |
except Exception as e:
|
| 315 |
error_msg = f"Failed to download image from {url}: {str(e)}"
|
| 316 |
self.logger.error(error_msg)
|
| 317 |
raise ConnectionError(error_msg) from e
|
| 318 |
+
|
| 319 |
def _guess_extension(self, content_type: str) -> str:
|
| 320 |
"""
|
| 321 |
Guess file extension from content type.
|
| 322 |
+
|
| 323 |
Args:
|
| 324 |
content_type: HTTP Content-Type header
|
| 325 |
+
|
| 326 |
Returns:
|
| 327 |
File extension (with dot)
|
| 328 |
"""
|
|
|
|
| 337 |
return ".gif"
|
| 338 |
else:
|
| 339 |
return ".png" # Default to PNG
|
| 340 |
+
|
| 341 |
def _get_random_seed(self) -> int:
|
| 342 |
"""
|
| 343 |
Get a random seed from the API.
|
| 344 |
+
|
| 345 |
Returns:
|
| 346 |
Random seed value
|
| 347 |
+
|
| 348 |
Raises:
|
| 349 |
RuntimeError: If random seed retrieval fails
|
| 350 |
"""
|
|
|
|
| 358 |
# Just log and re-raise as we have fallback in the validation method
|
| 359 |
self.logger.warning(f"Failed to get random seed: {e}")
|
| 360 |
raise
|
| 361 |
+
|
| 362 |
def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
|
| 363 |
"""
|
| 364 |
Add or remove custom LoRA model.
|
| 365 |
+
|
| 366 |
Args:
|
| 367 |
custom_lora: Custom LoRA model string
|
| 368 |
+
|
| 369 |
Raises:
|
| 370 |
RuntimeError: If LoRA handling fails
|
| 371 |
"""
|
|
|
|
| 382 |
# Add custom LoRA
|
| 383 |
try:
|
| 384 |
self.client.predict(
|
| 385 |
+
custom_lora=custom_lora, api_name="/add_custom_lora"
|
|
|
|
| 386 |
)
|
| 387 |
self.logger.debug(f"Added custom LoRA: {custom_lora}")
|
| 388 |
except Exception as e:
|
| 389 |
error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
|
| 390 |
self.logger.error(error_msg)
|
| 391 |
raise RuntimeError(error_msg) from e
|
| 392 |
+
|
| 393 |
def forward(
|
| 394 |
self,
|
| 395 |
prompt: str,
|
|
|
|
| 401 |
width: Optional[int] = None,
|
| 402 |
height: Optional[int] = None,
|
| 403 |
lora_scale: Optional[float] = None,
|
| 404 |
+
custom_lora: Optional[str] = None,
|
| 405 |
) -> str:
|
| 406 |
"""
|
| 407 |
Generate an image with FLUX-LoRA.
|
| 408 |
+
|
| 409 |
Args:
|
| 410 |
prompt: Text description of the desired image
|
| 411 |
image_input: Optional path or URL to input image for img2img
|
|
|
|
| 417 |
height: Image height in pixels (128-2048)
|
| 418 |
lora_scale: LoRA influence scale (0.0-1.0)
|
| 419 |
custom_lora: Custom LoRA model to use
|
| 420 |
+
|
| 421 |
Returns:
|
| 422 |
Formatted string with image generation results
|
| 423 |
+
|
| 424 |
Raises:
|
| 425 |
ValueError: If input validation fails
|
| 426 |
ConnectionError: If API communication fails
|
|
|
|
| 438 |
width=width,
|
| 439 |
height=height,
|
| 440 |
lora_scale=lora_scale,
|
| 441 |
+
custom_lora=custom_lora,
|
| 442 |
)
|
| 443 |
self.logger.debug(f"Validated parameters: {params}")
|
| 444 |
except ValueError as e:
|
| 445 |
return f"Parameter validation failed: {str(e)}"
|
| 446 |
+
|
| 447 |
# Step 2: Handle custom LoRA if specified
|
| 448 |
if "custom_lora" in params:
|
| 449 |
try:
|
|
|
|
| 451 |
self._handle_custom_lora(custom_lora_value)
|
| 452 |
except RuntimeError as e:
|
| 453 |
return f"Custom LoRA setup failed: {str(e)}"
|
| 454 |
+
|
| 455 |
# Step 3: Generate image
|
| 456 |
try:
|
| 457 |
# Prepare image input if provided
|
| 458 |
img_param = None
|
| 459 |
if "image_input" in params and params["image_input"]:
|
| 460 |
from gradio_client import handle_file
|
| 461 |
+
|
| 462 |
img_param = handle_file(params.pop("image_input"))
|
| 463 |
+
|
| 464 |
# Call the API
|
| 465 |
generation_args = {
|
| 466 |
"prompt": params["prompt"],
|
|
|
|
| 473 |
"height": params["height"],
|
| 474 |
"lora_scale": params["lora_scale"],
|
| 475 |
}
|
| 476 |
+
|
| 477 |
# Add image input if available
|
| 478 |
if img_param:
|
| 479 |
generation_args["image_input"] = img_param
|
| 480 |
+
|
| 481 |
self.logger.info(f"Generating image with params: {generation_args}")
|
| 482 |
+
result = self.client.predict(api_name="/run_lora", **generation_args)
|
| 483 |
+
|
|
|
|
|
|
|
|
|
|
| 484 |
# Process result
|
| 485 |
if isinstance(result, tuple) and len(result) >= 2:
|
| 486 |
image_path, actual_seed = result[0], result[1]
|
| 487 |
+
|
| 488 |
# Save image to our directory
|
| 489 |
try:
|
| 490 |
output_path = self._save_image(image_path)
|
| 491 |
image_result = ImageGenerationResult(
|
| 492 |
+
image_path=output_path, seed=int(actual_seed)
|
|
|
|
| 493 |
)
|
| 494 |
return self._format_result(image_result, params["prompt"])
|
| 495 |
except Exception as e:
|
|
|
|
| 497 |
return f"Image generated but failed to save: {str(e)}"
|
| 498 |
else:
|
| 499 |
raise ValueError(f"Unexpected API response format: {result}")
|
| 500 |
+
|
| 501 |
except Exception as e:
|
| 502 |
error_msg = f"Image generation failed: {str(e)}"
|
| 503 |
self.logger.error(error_msg)
|
| 504 |
return error_msg
|
| 505 |
+
|
| 506 |
def _save_image(self, image_path: str) -> str:
|
| 507 |
"""
|
| 508 |
Save generated image to specified directory.
|
| 509 |
+
|
| 510 |
Args:
|
| 511 |
image_path: Path to generated image from API
|
| 512 |
+
|
| 513 |
Returns:
|
| 514 |
Path to saved image
|
| 515 |
+
|
| 516 |
Raises:
|
| 517 |
IOError: If image saving fails
|
| 518 |
"""
|
| 519 |
try:
|
| 520 |
# Load the image
|
| 521 |
img = Image.open(image_path)
|
| 522 |
+
|
| 523 |
# Generate timestamp-based filename
|
| 524 |
timestamp = uuid.uuid4().hex[:8]
|
| 525 |
output_filename = f"flux_lora_{timestamp}.png"
|
| 526 |
output_path = os.path.join(self.image_save_dir, output_filename)
|
| 527 |
+
|
| 528 |
# Save to our directory
|
| 529 |
img.save(output_path)
|
| 530 |
self.logger.debug(f"Saved image to {output_path}")
|
| 531 |
+
|
| 532 |
return output_path
|
| 533 |
+
|
| 534 |
except Exception as e:
|
| 535 |
error_msg = f"Failed to save image: {str(e)}"
|
| 536 |
self.logger.error(error_msg)
|
| 537 |
raise IOError(error_msg) from e
|
| 538 |
+
|
| 539 |
def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
|
| 540 |
"""
|
| 541 |
Format the image generation result as a string.
|
| 542 |
+
|
| 543 |
Args:
|
| 544 |
result: Image generation result
|
| 545 |
prompt: Original prompt
|
| 546 |
+
|
| 547 |
Returns:
|
| 548 |
Formatted string with generation details
|
| 549 |
"""
|
| 550 |
lines = [
|
| 551 |
+
"📷 Image generated successfully!",
|
| 552 |
f"🖼️ Image saved to: {result.image_path}",
|
| 553 |
f"🌱 Seed used: {result.seed}",
|
| 554 |
f"📝 Original prompt: {prompt}",
|
| 555 |
]
|
| 556 |
+
|
| 557 |
# Add metadata if available
|
| 558 |
if result.metadata:
|
| 559 |
lines.append("📊 Additional metadata:")
|
| 560 |
for key, value in result.metadata.items():
|
| 561 |
lines.append(f" - {key}: {value}")
|
| 562 |
+
|
| 563 |
return "\n".join(lines)
|
| 564 |
|
| 565 |
|
|
|
|
| 567 |
# UTILITY FUNCTIONS
|
| 568 |
# -----------------------------------------------------------------------------
|
| 569 |
|
| 570 |
+
|
| 571 |
def download_image(url: str, output_dir: Optional[str] = None) -> str:
|
| 572 |
"""
|
| 573 |
Standalone utility to download an image from a URL.
|
| 574 |
+
|
| 575 |
Args:
|
| 576 |
url: Image URL
|
| 577 |
output_dir: Directory to save image (created if doesn't exist)
|
| 578 |
+
|
| 579 |
Returns:
|
| 580 |
Path to downloaded image
|
| 581 |
+
|
| 582 |
Raises:
|
| 583 |
ValueError: If URL is invalid
|
| 584 |
ConnectionError: If download fails
|
|
|
|
| 586 |
"""
|
| 587 |
if not url.startswith(("http://", "https://")):
|
| 588 |
raise ValueError(f"Invalid URL: {url}")
|
| 589 |
+
|
| 590 |
# Setup output directory
|
| 591 |
if output_dir is None:
|
| 592 |
output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
|
| 593 |
os.makedirs(output_dir, exist_ok=True)
|
| 594 |
+
|
| 595 |
try:
|
| 596 |
# Download image
|
| 597 |
response = requests.get(url, stream=True, timeout=30)
|
| 598 |
response.raise_for_status()
|
| 599 |
+
|
| 600 |
# Determine file extension
|
| 601 |
content_type = response.headers.get("Content-Type", "")
|
| 602 |
ext = ".jpg" if "jpeg" in content_type.lower() else ".png"
|
| 603 |
+
|
| 604 |
# Save image
|
| 605 |
output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
|
| 606 |
with open(output_path, "wb") as f:
|
| 607 |
for chunk in response.iter_content(chunk_size=8192):
|
| 608 |
f.write(chunk)
|
| 609 |
+
|
| 610 |
return output_path
|
| 611 |
+
|
| 612 |
except requests.RequestException as e:
|
| 613 |
raise ConnectionError(f"Failed to download image: {str(e)}")
|
| 614 |
except IOError as e:
|
| 615 |
raise IOError(f"Failed to save image: {str(e)}")
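Note: below is a minimal, untested usage sketch for the reworked tool. The tool class, its import path, its defaults, and the "/run_lora" endpoint come from the diff above; the HfApiModel backend and the prompts are illustrative assumptions, not part of this commit.

# Usage sketch (assumptions: HF credentials are configured, the
# xkerser/FLUX-LoRA-DLC Space is reachable, and HfApiModel is one
# possible smolagents model backend, chosen here only for illustration).
from smolagents import CodeAgent, HfApiModel

from scripts.flux_lora_tool import FluxLoRATool

flux_tool = FluxLoRATool(verbose=True)

# Direct call: forward() validates inputs, resolves the seed, then hits
# the Space's /run_lora endpoint through the lazily created Gradio client.
print(
    flux_tool.forward(
        prompt="a watercolor fox in a misty forest",
        width=768,
        height=768,
    )
)

# Or hand the tool to an agent, as the module docstring suggests:
agent = CodeAgent(tools=[flux_tool], model=HfApiModel())
agent.run("Generate an image of a watercolor fox and report the seed used.")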
scripts/frontmatter_tool.py
DELETED
@@ -1,402 +0,0 @@
"""
Frontmatter Generator Tool for Smolagents

This tool helps generate consistent YAML frontmatter for documents,
useful for RAG systems, static site generators, and document organization.
Integrates with TextInspectorTool and MarkdownConverter for a complete
document processing pipeline.
"""

import re
import yaml
import json
from datetime import datetime
from typing import Dict, List, Optional, Any, Union
from smolagents import Tool


class FrontmatterGeneratorTool(Tool):
    """Tool for generating and manipulating YAML frontmatter in documents."""

    name = "frontmatter_generator"
    description = """
    Generates or extracts YAML frontmatter for documents. Frontmatter provides structured
    metadata for documents including title, author, date, description, and tags.
    Useful for document organization, RAG systems, and static site generators.
    Works with content from the inspect_file_as_text tool to add metadata to documents.
    """

    inputs = {
        "content": {
            "type": "string",
            "description": "Document content (with or without existing frontmatter)",
        },
        "title": {"type": "string", "description": "Document title", "nullable": True},
        "author": {
            "type": "string",
            "description": "Document author(s)",
            "nullable": True,
        },
        "date": {
            "type": "string",
            "description": "Document date in YYYY-MM-DD format (defaults to today if not provided)",
            "nullable": True,
        },
        "date_format": {
            "type": "string",
            "description": "Format string for the document date (e.g., '%Y-%m-%d', '%d/%m/%Y'). Defaults to '%Y-%m-%d'",
            "nullable": True,
            "default": "%Y-%m-%d",
        },
        "description": {
            "type": "string",
            "description": "Brief description of the document",
            "nullable": True,
        },
        "tags": {
            "type": "string",
            "description": "Comma-separated list of tags",
            "nullable": True,
        },
        "additional_fields": {
            "type": "string",
            "description": "JSON string with additional frontmatter fields",
            "nullable": True,
        },
        "mode": {
            "type": "string",
            "description": "Operation mode: 'generate' (create new), 'extract' (get existing), 'update' (modify existing), or 'strip' (remove)",
            "default": "generate",
        },
    }
    output_type = "string"

    # Regular expression to detect and extract YAML frontmatter
    FRONTMATTER_PATTERN = r"^---\s*\n(.*?)\n---\s*\n"

    def forward(
        self,
        content: str,
        title: Optional[str] = None,
        author: Optional[str] = None,
        date: Optional[str] = None,
        date_format: Optional[str] = "%Y-%m-%d",
        description: Optional[str] = None,
        tags: Optional[str] = None,
        additional_fields: Optional[str] = None,
        mode: str = "generate",
    ) -> str:
        """
        Process document content based on specified mode.

        Args:
            content: Document content with or without frontmatter
            title: Document title
            author: Document author(s)
            date: Document date (YYYY-MM-DD)
            date_format: strftime format string
            description: Brief document description
            tags: Comma-separated list of tags
            additional_fields: JSON string with additional fields
            mode: Operation mode (generate, extract, update, strip)

        Returns:
            Processed document or extracted frontmatter
        """
        # Validate inputs
        if not isinstance(content, str):
            return "Error: Content must be a string"
        if title and not isinstance(title, str):
            return "Error: Title must be a string"
        if author and not isinstance(author, str):
            return "Error: Author must be a string"
        if date and not isinstance(date, str):
            return "Error: Date must be a string"
        if description and not isinstance(description, str):
            return "Error: Description must be a string"
        if tags and not isinstance(tags, str):
            return "Error: Tags must be a string"
        if additional_fields and not isinstance(additional_fields, str):
            return "Error: Additional_fields must be a string"
        if not isinstance(mode, str):
            return "Error: Mode must be a string"

        # Validate mode
        valid_modes = ["generate", "extract", "update", "strip"]
        if mode not in valid_modes:
            return f"Error: Invalid mode '{mode}'. Valid options are: {', '.join(valid_modes)}"

        # Handle empty content
        if not content or not content.strip():
            if mode == "generate":
                # We can still generate frontmatter from provided fields
                content = ""
            else:
                return "Error: Empty content provided"

        # Special handling for TextInspectorTool output
        if content.startswith("Document content:"):
            content = content[len("Document content:"):].strip()

        # Process based on mode
        try:
            if mode == "extract":
                return self._extract_frontmatter(content)
            elif mode == "strip":
                return self._strip_frontmatter(content)
            elif mode == "update":
                return self._update_frontmatter(
                    content,
                    title,
                    author,
                    date,
                    description,
                    tags,
                    additional_fields,
                    date_format,
                )
            else:  # generate
                return self._generate_frontmatter(
                    content,
                    title,
                    author,
                    date,
                    description,
                    tags,
                    additional_fields,
                    date_format,
                )
        except Exception as e:
            return f"Error processing frontmatter: {str(e)}"

    def _extract_frontmatter(self, content: str) -> str:
        """Extract and return existing frontmatter as formatted YAML."""
        match = re.search(self.FRONTMATTER_PATTERN, content, re.DOTALL)
        if not match:
            return "No frontmatter found in the document"

        try:
            yaml_content = match.group(1)
            # Parse and reformat for consistency
            frontmatter_dict = yaml.safe_load(yaml_content)
            return f"Extracted frontmatter:\n\n```yaml\n{yaml.dump(frontmatter_dict, sort_keys=False, default_flow_style=False)}```"
        except yaml.YAMLError:
            return "Found frontmatter but failed to parse it as valid YAML"

    def _strip_frontmatter(self, content: str) -> str:
        """Remove frontmatter from document and return clean content."""
        result = re.sub(self.FRONTMATTER_PATTERN, "", content, count=1, flags=re.DOTALL)

        # Check if anything was actually removed
        if result == content:
            return "No frontmatter found to strip. Content unchanged."

        return result.strip()

    def _parse_additional_fields(self, additional_fields: str) -> Dict[str, Any]:
        """Parse the additional_fields JSON string into a dictionary."""
        if not additional_fields:
            return {}

        try:
            return json.loads(additional_fields)
        except json.JSONDecodeError:
            raise ValueError("additional_fields must be a valid JSON string")

    def _infer_title_from_content(self, content: str) -> Optional[str]:
        """Attempt to infer document title from content."""
        # Try to find the first heading
        heading_match = re.search(r"^#\s+(.+)$", content, re.MULTILINE)
        if heading_match:
            return heading_match.group(1).strip()

        # Try to find the first non-empty line
        lines = content.split("\n")
        for line in lines:
            if line.strip():
                # Limit to a reasonable title length
                return line.strip()[:100]

        return None

    def _parse_tags(self, tags_string: str) -> List[str]:
        """Parse comma-separated tags into a list."""
        if not tags_string:
            return []

        # Split by comma and clean each tag
        tag_list = [tag.strip() for tag in tags_string.split(",")]
        # Remove any empty tags
        return [tag for tag in tag_list if tag]

    def _parse_flexible_date(
        self, date_str: str, date_format: Optional[str] = None
    ) -> str:
        """
        Try to parse dates in various formats and convert to YYYY-MM-DD.

        Args:
            date_str: The date string to parse
            date_format: Optional preferred format to try first

        Returns:
            Formatted date as string (YYYY-MM-DD by default)
        """
        if not date_str:
            return datetime.now().strftime("%Y-%m-%d")

        # If a specific format is provided, try it first
        if date_format:
            try:
                parsed_date = datetime.strptime(date_str, date_format)
                return parsed_date.strftime("%Y-%m-%d")
            except ValueError:
                # If it fails, continue with other formats
                pass

        # Common formats to try
        formats = [
            "%Y-%m-%d",  # 2013-03-13
            "%d %B %Y",  # 13 March 2013
            "%B %Y",  # September 2013
            "%Y",  # 1958
            "%d/%m/%Y",  # 13/03/2013
            "%m/%d/%Y",  # 03/13/2013
            "%d-%m-%Y",  # 13-03-2013
            "%m-%d-%Y",  # 03-13-2013
            "%Y/%m/%d",  # 2013/03/13
        ]

        for fmt in formats:
            try:
                parsed_date = datetime.strptime(date_str, fmt)
                return parsed_date.strftime("%Y-%m-%d")
            except ValueError:
                continue

        # If no format matched, return the original string
        return date_str

    def _update_frontmatter(
        self,
        content: str,
        title: Optional[str] = None,
        author: Optional[str] = None,
        date: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[str] = None,
        additional_fields: Optional[str] = None,
        date_format: Optional[str] = None,
    ) -> str:
        """Update existing frontmatter with new values."""
        # Check if frontmatter exists
        match = re.search(self.FRONTMATTER_PATTERN, content, re.DOTALL)
        if not match:
            # If no frontmatter exists, generate new one
            return self._generate_frontmatter(
                content,
                title,
                author,
                date,
                description,
                tags,
                additional_fields,
                date_format,
            )

        # Parse existing frontmatter
        yaml_content = match.group(1)
        try:
            frontmatter_dict = yaml.safe_load(yaml_content) or {}
        except yaml.YAMLError:
            frontmatter_dict = {}

        # Update with new values if provided
        if title:
            frontmatter_dict["title"] = title
        if author:
            frontmatter_dict["author"] = author
        if date:
            # Try to parse the date with the flexible parser
            frontmatter_dict["date"] = self._parse_flexible_date(date, date_format)
        if description:
            frontmatter_dict["description"] = description
        if tags:
            frontmatter_dict["tags"] = self._parse_tags(tags)

        # Add additional fields
        if additional_fields:
            additional_dict = self._parse_additional_fields(additional_fields)
            frontmatter_dict.update(additional_dict)

        # Generate new frontmatter
        new_frontmatter = yaml.dump(
            frontmatter_dict, sort_keys=False, default_flow_style=False
        )
        new_frontmatter = f"---\n{new_frontmatter}---\n\n"

        # Replace old frontmatter with new one
        return re.sub(
            self.FRONTMATTER_PATTERN, new_frontmatter, content, count=1, flags=re.DOTALL
        )

    def _generate_frontmatter(
        self,
        content: str,
        title: Optional[str] = None,
        author: Optional[str] = None,
        date: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[str] = None,
        additional_fields: Optional[str] = None,
        date_format: Optional[str] = None,
    ) -> str:
        """Generate new frontmatter and prepend to content."""
        # Strip any existing frontmatter
        clean_content = (
            self._strip_frontmatter(content) if isinstance(content, str) else ""
        )

        # Build frontmatter dictionary
        frontmatter_dict = {}

        # Try to infer title if not provided
        if title:
            frontmatter_dict["title"] = title
        else:
            inferred_title = self._infer_title_from_content(clean_content)
            if inferred_title:
                frontmatter_dict["title"] = inferred_title

        # Add other fields if provided
        if author:
            frontmatter_dict["author"] = author

        # Process date with flexible parser
        if date:
            frontmatter_dict["date"] = self._parse_flexible_date(date, date_format)
        else:
            # Use current date with provided format or default
            format_to_use = date_format or "%Y-%m-%d"
            frontmatter_dict["date"] = datetime.now().strftime(format_to_use)

        if description:
            frontmatter_dict["description"] = description

        if tags:
            frontmatter_dict["tags"] = self._parse_tags(tags)

        # Add additional fields
        if additional_fields:
            additional_dict = self._parse_additional_fields(additional_fields)
            frontmatter_dict.update(additional_dict)

        # Generate YAML frontmatter
        frontmatter_yaml = yaml.dump(
            frontmatter_dict, sort_keys=False, default_flow_style=False
        )
        frontmatter = f"---\n{frontmatter_yaml}---\n\n"

        # Combine frontmatter with content
        return frontmatter + clean_content
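For reference, the removed tool's extract mode behaved roughly as sketched below; this is reconstructed from the deleted source above and is not part of the commit.

# Sketch of the deleted FrontmatterGeneratorTool in extract mode.
tool = FrontmatterGeneratorTool()
doc = "---\ntitle: Notes\ntags:\n- demo\n---\n\n# Notes\nBody text."

# FRONTMATTER_PATTERN matches the leading '---' block, yaml.safe_load
# parses it, and the tool returns it re-dumped for consistent formatting.
print(tool.forward(content=doc, mode="extract"))
# -> "Extracted frontmatter:" followed by a yaml block with title and tags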
scripts/gaia_scorer.py
DELETED
@@ -1,124 +0,0 @@
import re
import string
import warnings


def normalize_number_str(number_str: str) -> float:
    # we replace these common units and commas to allow
    # conversion to float
    for char in ["$", "%", ","]:
        number_str = number_str.replace(char, "")
    try:
        return float(number_str)
    except ValueError:
        print(f"String {number_str} cannot be normalized to number str.")
        return float("inf")


def split_string(
    s: str,
    char_list: list[str] = [",", ";"],
) -> list[str]:
    pattern = f"[{''.join(char_list)}]"
    return re.split(pattern, s)


def is_float(element: any) -> bool:
    try:
        float(element)
        return True
    except ValueError:
        return False


def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    # if gt is a number
    if is_float(ground_truth):
        normalized_answer = normalize_number_str(str(model_answer))
        return normalized_answer == float(ground_truth)

    # if gt is a list
    elif any(char in ground_truth for char in [",", ";"]):
        # question with the fish: normalization removes punct

        gt_elems = split_string(ground_truth)
        ma_elems = split_string(model_answer)

        # check length is the same
        if len(gt_elems) != len(ma_elems):
            warnings.warn("Answer lists have different lengths, returning False.", UserWarning)
            return False

        # compare each element as float or str
        comparisons = []
        for ma_elem, gt_elem in zip(ma_elems, gt_elems):
            if is_float(gt_elem):
                normalized_ma_elem = normalize_number_str(ma_elem)
                comparisons.append(normalized_ma_elem == float(gt_elem))
            else:
                # we do not remove punct since comparisons can include punct
                comparisons.append(
                    normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)
                )
        return all(comparisons)

    # if gt is a str
    else:
        return normalize_str(model_answer) == normalize_str(ground_truth)


def check_prediction_contains_answer_letters_in_order(prediction, true_answer):
    prediction = prediction.lower()
    true_answer = true_answer.lower()
    if len(prediction) > len(true_answer) * 3:
        return False
    i = 0
    for letter in true_answer:
        if letter in prediction[i:]:
            i += prediction[i:].index(letter)
        else:
            return False
    return True


def check_close_call(prediction, true_answer, is_correct):
    if is_correct:
        return True
    else:
        if is_float(true_answer):
            return is_correct
        else:
            if (
                check_prediction_contains_answer_letters_in_order(str(prediction), str(true_answer))
                and len(str(true_answer)) * 0.5 <= len(str(prediction)) <= len(str(true_answer)) * 2
            ):
                print(f"Close call: {prediction} vs {true_answer}")
                return True
            else:
                return False


def normalize_str(input_str, remove_punct=True) -> str:
    """
    Normalize a string by:
    - Removing all white spaces
    - Optionally removing punctuation (if remove_punct is True)
    - Converting to lowercase
    Parameters:
    - input_str: str, the string to normalize
    - remove_punct: bool, whether to remove punctuation (default: True)
    Returns:
    - str, the normalized string
    """
    # Remove all white spaces. Required e.g for seagull vs. sea gull
    no_spaces = re.sub(r"\s", "", input_str)

    # Remove punctuation, if specified.
    if remove_punct:
        translator = str.maketrans("", "", string.punctuation)
        return no_spaces.lower().translate(translator)
    else:
        return no_spaces.lower()
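For reference, the removed question_scorer applied three matching rules (numeric, list, and plain string); the sketch below is reconstructed from the deleted source above and is not part of the commit.

# Sketch of the deleted scorer's behavior.
question_scorer("$1,234", "1234")       # True: "$" and "," stripped, then compared as floats
question_scorer("sea gull", "seagull")  # True: normalize_str removes whitespace and lowercases
question_scorer("a; b; c", "a,b,c")     # True: split on "," or ";", compared element-wise
question_scorer("a, b", "a, b, c")      # False: length mismatch warns and returns False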
scripts/mdconvert.py
CHANGED

@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+# coding=utf-8
# This is copied from Magentic-one's great repo: https://github.com/microsoft/autogen/blob/v0.4.4/python/packages/autogen-magentic-one/src/autogen_magentic_one/markdown_browser/mdconvert.py
# Thanks to Microsoft researchers for open-sourcing this!
# type: ignore

@@ -22,7 +24,6 @@ import pandas as pd
import pdfminer
import pdfminer.high_level
import pptx
-
# File-format detection
import puremagic
import pydub

@@ -86,7 +87,11 @@ class _CustomMarkdownify(markdownify.MarkdownConverter):
        if self.options["default_title"] and not title:
            title = href
        title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
-        return "%s[%s](%s%s)%s" % (prefix, text, href, title_part, suffix) if href else text
+        return (
+            "%s[%s](%s%s)%s" % (prefix, text, href, title_part, suffix)
+            if href
+            else text
+        )

    def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:
        """Same as usual converter, but removes data URIs"""

@@ -95,7 +100,10 @@ class _CustomMarkdownify(markdownify.MarkdownConverter):
        src = el.attrs.get("src", None) or ""
        title = el.attrs.get("title", None) or ""
        title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
-        if convert_as_inline and el.parent.name not in self.options["keep_inline_images_in"]:
+        if (
+            convert_as_inline
+            and el.parent.name not in self.options["keep_inline_images_in"]
+        ):
            return alt

        # Remove dataURIs

@@ -119,16 +127,22 @@ class DocumentConverterResult:
class DocumentConverter:
    """Abstract superclass of all DocumentConverters."""

-    def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
+    def convert(
+        self, local_path: str, **kwargs: Any
+    ) -> Union[None, DocumentConverterResult]:
        raise NotImplementedError()


class PlainTextConverter(DocumentConverter):
    """Anything with content type text/plain"""

-    def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
+    def convert(
+        self, local_path: str, **kwargs: Any
+    ) -> Union[None, DocumentConverterResult]:
        # Guess the content type from any file extension that might be around
-        content_type, _ = mimetypes.guess_type("__placeholder" + kwargs.get("file_extension", ""))
+        content_type, _ = mimetypes.guess_type(
+            "__placeholder" + kwargs.get("file_extension", "")
+        )

        # Only accept text files
        if content_type is None:

@@ -148,7 +162,9 @@ class PlainTextConverter(DocumentConverter):
class HtmlConverter(DocumentConverter):
    """Anything with content type text/html"""

-    def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
+    def convert(
+        self, local_path: str, **kwargs: Any
+    ) -> Union[None, DocumentConverterResult]:
        # Bail if not html
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:

@@ -181,14 +197,17 @@ class HtmlConverter(DocumentConverter):
        assert isinstance(webpage_text, str)

        return DocumentConverterResult(
-            title=None if soup.title is None else soup.title.string, text_content=webpage_text
+            title=None if soup.title is None else soup.title.string,
+            text_content=webpage_text,
        )


class WikipediaConverter(DocumentConverter):
    """Handle Wikipedia pages separately, focusing only on the main document content."""

-    def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
+    def convert(
+        self, local_path: str, **kwargs: Any
+    ) -> Union[None, DocumentConverterResult]:
        # Bail if not Wikipedia
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:

@@ -220,7 +239,9 @@ class WikipediaConverter(DocumentConverter):
        assert isinstance(main_title, str)

        # Convert the page
-        webpage_text = f"# {main_title}\n\n" + _CustomMarkdownify().convert_soup(body_elm)
+        webpage_text = f"# {main_title}\n\n" + _CustomMarkdownify().convert_soup(
+            body_elm
+        )
    else:
        webpage_text = _CustomMarkdownify().convert_soup(soup)

@@ -233,7 +254,9 @@ class WikipediaConverter(DocumentConverter):
class YouTubeConverter(DocumentConverter):
    """Handle YouTube specially, focusing on the video title, description, and transcript."""

-    def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
+    def convert(
+        self, local_path: str, **kwargs: Any
+    ) -> Union[None, DocumentConverterResult]:
        # Bail if not YouTube
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:

@@ -327,7 +350,12 @@ class YouTubeConverter(DocumentConverter):
            text_content=webpage_text,
        )

-    def _get(self, metadata: Dict[str, str], keys: List[str], default: Union[str, None] = None) -> Union[str, None]:
+    def _get(
+        self,
+        metadata: Dict[str, str],
+        keys: List[str],
+        default: Union[str, None] = None,
+    ) -> Union[str, None]:
        for k in keys:
            if k in metadata:
                return metadata[k]

@@ -444,7 +472,13 @@ class PptxConverter(HtmlConverter):

            # A placeholder name
            filename = re.sub(r"\W", "", shape.name) + ".jpg"
-            md_content += "\n![" + (alt_text if alt_text else shape.name) + "](" + filename + ")\n"
+            md_content += (
+                "\n!["
+                + (alt_text if alt_text else shape.name)
+                + "]("
+                + filename
+                + ")\n"
+            )

            # Tables
            if self._is_table(shape):

@@ -460,7 +494,9 @@ class PptxConverter(HtmlConverter):
                    html_table += "</tr>"
                    first_row = False
                html_table += "</table></body></html>"
-                md_content += "\n" + self._convert(html_table).text_content.strip() + "\n"
+                md_content += (
+                    "\n" + self._convert(html_table).text_content.strip() + "\n"
+                )

            # Text areas
            elif shape.has_text_frame:

@@ -508,7 +544,9 @@ class MediaConverter(DocumentConverter):
            return None
        else:
            try:
-                result = subprocess.run([exiftool, "-json", local_path], capture_output=True, text=True).stdout
+                result = subprocess.run(
+                    [exiftool, "-json", local_path], capture_output=True, text=True
+                ).stdout
                return json.loads(result)[0]
            except Exception:
                return None

@@ -548,9 +586,13 @@ class WavConverter(MediaConverter):
        # Transcribe
        try:
            transcript = self._transcribe_audio(local_path)
-            md_content += "\n\n### Audio Transcript:\n" + ("[No speech detected]" if transcript == "" else transcript)
+            md_content += "\n\n### Audio Transcript:\n" + (
+                "[No speech detected]" if transcript == "" else transcript
+            )
        except Exception:
-            md_content += "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
+            md_content += (
+                "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
+            )

        return DocumentConverterResult(
            title=None,

@@ -612,7 +654,9 @@ class Mp3Converter(WavConverter):
                    "[No speech detected]" if transcript == "" else transcript
                )
            except Exception:
-                md_content += "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
+                md_content += (
+                    "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
+                )

            finally:
                os.unlink(temp_path)

@@ -662,7 +706,11 @@ class ImageConverter(MediaConverter):
            md_content += (
                "\n# Description:\n"
                + self._get_mlm_description(
-                    local_path, extension, mlm_client, mlm_model, prompt=kwargs.get("mlm_prompt")
+                    local_path,
+                    extension,
+                    mlm_client,
+                    mlm_model,
+                    prompt=kwargs.get("mlm_prompt"),
                ).strip()
                + "\n"
            )

@@ -759,7 +807,11 @@ class MarkdownConverter:

        # Local path or url
        if isinstance(source, str):
-            if source.startswith("http://") or source.startswith("https://") or source.startswith("file://"):
+            if (
+                source.startswith("http://")
+                or source.startswith("https://")
+                or source.startswith("file://")
+            ):
                return self.convert_url(source, **kwargs)
            else:
                return self.convert_local(source, **kwargs)

@@ -767,7 +819,9 @@ class MarkdownConverter:
        elif isinstance(source, requests.Response):
            return self.convert_response(source, **kwargs)

-    def convert_local(self, path: str, **kwargs: Any) -> DocumentConverterResult:  # TODO: deal with kwargs
+    def convert_local(
+        self, path: str, **kwargs: Any
+    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []

@@ -781,7 +835,9 @@ class MarkdownConverter:
        return self._convert(path, extensions, **kwargs)

    # TODO what should stream's type be?
-    def convert_stream(self, stream: Any, **kwargs: Any) -> DocumentConverterResult:  # TODO: deal with kwargs
+    def convert_stream(
+        self, stream: Any, **kwargs: Any
+    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []

@@ -814,10 +870,14 @@ class MarkdownConverter:

        return result

-    def convert_url(self, url: str, **kwargs: Any) -> DocumentConverterResult:  # TODO: fix kwargs type
+    def convert_url(
+        self, url: str, **kwargs: Any
+    ) -> DocumentConverterResult:  # TODO: fix kwargs type
        # Send a HTTP request to the URL
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
-        response = self._requests_session.get(url, stream=True, headers={"User-Agent": user_agent})
+        response = self._requests_session.get(
+            url, stream=True, headers={"User-Agent": user_agent}
+        )
        response.raise_for_status()
        return self.convert_response(response, **kwargs)

@@ -871,7 +931,9 @@ class MarkdownConverter:

        return result

-    def _convert(self, local_path: str, extensions: List[Union[str, None]], **kwargs) -> DocumentConverterResult:
+    def _convert(
+        self, local_path: str, extensions: List[Union[str, None]], **kwargs
+    ) -> DocumentConverterResult:
        error_trace = ""
        for ext in extensions + [None]:  # Try last with no extension
            for converter in self._page_converters:

@@ -899,7 +961,9 @@ class MarkdownConverter:

        if res is not None:
            # Normalize the content
-            res.text_content = "\n".join([line.rstrip() for line in re.split(r"\r?\n", res.text_content)])
+            res.text_content = "\n".join(
+                [line.rstrip() for line in re.split(r"\r?\n", res.text_content)]
[line.rstrip() for line in re.split(r"\r?\n", res.text_content)]
|
| 966 |
+
)
|
| 967 |
res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)
|
| 968 |
|
| 969 |
# Todo
|
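The reformatting above is behavior-preserving, so the converter's public entry point is unchanged. For orientation, a minimal sketch of driving the converter directly, assuming a no-argument constructor and relying only on the convert() dispatch and the DocumentConverterResult fields (title, text_content) visible in this diff:

from scripts.mdconvert import MarkdownConverter

converter = MarkdownConverter()
# Per the dispatch above: http(s):// and file:// sources go through
# convert_url(); anything else is treated as a local path via convert_local().
result = converter.convert("https://en.wikipedia.org/wiki/Markdown")
print(result.title)
print(result.text_content[:500])
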
scripts/reformulator.py
DELETED

@@ -1,86 +0,0 @@
-# Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
-# https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
-import copy
-
-from smolagents.models import MessageRole, Model
-
-
-def prepare_response(original_task: str, inner_messages, reformulation_model: Model) -> str:
-    messages = [
-        {
-            "role": MessageRole.SYSTEM,
-            "content": [
-                {
-                    "type": "text",
-                    "text": f"""Earlier you were asked the following:
-
-{original_task}
-
-Your team then worked diligently to address that request. Read below a transcript of that conversation:""",
-                }
-            ],
-        }
-    ]
-
-    # The first message just repeats the question, so remove it
-    # if len(inner_messages) > 1:
-    #     del inner_messages[0]
-
-    # copy them to this context
-    try:
-        for message in inner_messages:
-            if not message.get("content"):
-                continue
-            message = copy.deepcopy(message)
-            message["role"] = MessageRole.USER
-            messages.append(message)
-    except Exception:
-        messages += [{"role": MessageRole.ASSISTANT, "content": str(inner_messages)}]
-
-    # ask for the final answer
-    messages.append(
-        {
-            "role": MessageRole.USER,
-            "content": [
-                {
-                    "type": "text",
-                    "text": f"""
-Read the above conversation and output a FINAL ANSWER to the question. The question is repeated here for convenience:
-
-{original_task}
-
-To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
-Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
-ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
-If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and DO NOT INCLUDE UNITS such as $ or USD or percent signs unless specified otherwise.
-If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
-If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
-If you are unable to determine the final answer, output 'FINAL ANSWER: Unable to determine'
-""",
-                }
-            ],
-        }
-    )
-
-    response = reformulation_model(messages).content
-
-    final_answer = response.split("FINAL ANSWER: ")[-1].strip()
-    print("> Reformulated answer: ", final_answer)
-
-    # if "unable to determine" in final_answer.lower():
-    #     messages.append({"role": MessageRole.ASSISTANT, "content": response })
-    #     messages.append({"role": MessageRole.USER, "content": [{"type": "text", "text": """
-# I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation.
-
-# To output the educated guess, use the following template: EDUCATED GUESS: [YOUR EDUCATED GUESS]
-# Your EDUCATED GUESS should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. DO NOT OUTPUT 'I don't know', 'Unable to determine', etc.
-# ADDITIONALLY, your EDUCATED GUESS MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
-# If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
-# If you are asked for a string, don't use articles or abbreviations (e.g. cit for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
-# If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
-# """.strip()}]})
-
-    #     response = model(messages).content
-    #     print("\n>>>Making an educated guess.\n", response)
-    #     final_answer = response.split("EDUCATED GUESS: ")[-1].strip()
-    return final_answer
scripts/run_agents.py
DELETED

@@ -1,87 +0,0 @@
-import json
-import os
-import shutil
-import textwrap
-from pathlib import Path
-
-# import tqdm.asyncio
-from smolagents.utils import AgentError
-
-
-def serialize_agent_error(obj):
-    if isinstance(obj, AgentError):
-        return {"error_type": obj.__class__.__name__, "message": obj.message}
-    else:
-        return str(obj)
-
-
-def get_image_description(file_name: str, question: str, visual_inspection_tool) -> str:
-    prompt = f"""Write a caption of 5 sentences for this image. Pay special attention to any details that might be useful for someone answering the following question:
-{question}. But do not try to answer the question directly!
-Do not add any information that is not present in the image."""
-    return visual_inspection_tool(image_path=file_name, question=prompt)
-
-
-def get_document_description(file_path: str, question: str, document_inspection_tool) -> str:
-    prompt = f"""Write a caption of 5 sentences for this document. Pay special attention to any details that might be useful for someone answering the following question:
-{question}. But do not try to answer the question directly!
-Do not add any information that is not present in the document."""
-    return document_inspection_tool.forward_initial_exam_mode(file_path=file_path, question=prompt)
-
-
-def get_single_file_description(file_path: str, question: str, visual_inspection_tool, document_inspection_tool):
-    file_extension = file_path.split(".")[-1]
-    if file_extension in ["png", "jpg", "jpeg"]:
-        file_description = f" - Attached image: {file_path}"
-        file_description += (
-            f"\n -> Image description: {get_image_description(file_path, question, visual_inspection_tool)}"
-        )
-        return file_description
-    elif file_extension in ["pdf", "xls", "xlsx", "docx", "doc", "xml"]:
-        file_description = f" - Attached document: {file_path}"
-        image_path = file_path.split(".")[0] + ".png"
-        if os.path.exists(image_path):
-            description = get_image_description(image_path, question, visual_inspection_tool)
-        else:
-            description = get_document_description(file_path, question, document_inspection_tool)
-        file_description += f"\n -> File description: {description}"
-        return file_description
-    elif file_extension in ["mp3", "m4a", "wav"]:
-        return f" - Attached audio: {file_path}"
-    else:
-        return f" - Attached file: {file_path}"
-
-
-def get_zip_description(file_path: str, question: str, visual_inspection_tool, document_inspection_tool):
-    folder_path = file_path.replace(".zip", "")
-    os.makedirs(folder_path, exist_ok=True)
-    shutil.unpack_archive(file_path, folder_path)
-
-    prompt_use_files = ""
-    for root, dirs, files in os.walk(folder_path):
-        for file in files:
-            file_path = os.path.join(root, file)
-            prompt_use_files += "\n" + textwrap.indent(
-                get_single_file_description(file_path, question, visual_inspection_tool, document_inspection_tool),
-                prefix=" ",
-            )
-    return prompt_use_files
-
-
-def get_tasks_to_run(data, total: int, base_filename: Path, tasks_ids: list[int]):
-    f = base_filename.parent / f"{base_filename.stem}_answers.jsonl"
-    done = set()
-    if f.exists():
-        with open(f, encoding="utf-8") as fh:
-            done = {json.loads(line)["task_id"] for line in fh if line.strip()}
-
-    tasks = []
-    for i in range(total):
-        task_id = int(data[i]["task_id"])
-        if task_id not in done:
-            if tasks_ids is not None:
-                if task_id in tasks_ids:
-                    tasks.append(data[i])
-            else:
-                tasks.append(data[i])
-    return tasks
scripts/text_cleaner_tool.py
CHANGED

@@ -1,3 +1,6 @@
   1   """
   2   Text cleaning tool for smolagents.
   3   

@@ -7,32 +10,26 @@ text content with handling for various text transformation options.
   7   
   8   # Standard library imports
   9   import logging
  10 - from typing import
  11   
  12   # Third-party imports
  13   from smolagents import Tool
  14   
  15 - # Try to import cleantext - handle gracefully if not installed
  16 - try:
  17 -     from cleantext import clean
  18 - 
  19 -     CLEANTEXT_AVAILABLE = True
  20 - except ImportError:
  21 -     CLEANTEXT_AVAILABLE = False
  22 - 
  23   # Configure module logger
  24   logger = logging.getLogger(__name__)
  25   
  26   
  27   # pylint: disable=too-few-public-methods
  28   class TextCleanerTool(Tool):
  29 -     """A
  30   
  31       name = "clean_text"
  32 -     description =
  33 - 
  34 - 
  35 - 
  36       inputs = {
  37           "text": {"type": "string", "description": "The input text to clean"},
  38           "options": {

@@ -76,7 +73,7 @@
  76       `clean-text` uses ftfy, unidecode and numerous hand-crafted rules,
  77       i.e., RegEx.
  78   
  79 - 
  80       clean("some input",
  81           fix_unicode=True,  # fix various unicode errors
  82           to_ascii=True,  # transliterate to closest ASCII

@@ -110,14 +107,6 @@
 110               logger.error("Failed to convert input to string: %s", e)
 111               return f"Error: Could not process input of type {type(text)}"
 112   
 113 -         # Check if cleantext is available
 114 -         if not CLEANTEXT_AVAILABLE:
 115 -             logger.error(
 116 -                 "cleantext package not installed. "
 117 -                 "Install with: pip install clean-text"
 118 -             )
 119 -             return "Error: Required dependency 'clean-text' is not installed."
 120 - 
 121           # Default replacement tokens
 122           replacements = {
 123               "replace_with_url": "<URL>",

@@ -159,3 +148,6 @@
 159           except (ValueError, TypeError, AttributeError) as e:
 160               logger.error("Error cleaning text: %s", e)
 161               return f"Error during text cleaning: {str(e)}"

   1 + #!/usr/bin/env python
   2 + # coding=utf-8
   3 + # Copyright 2025 The Footscray Coding Collective. All rights reserved.
   4   """
   5   Text cleaning tool for smolagents.
   6   

  10   
  11   # Standard library imports
  12   import logging
  13 + from typing import Any, Dict, Optional
  14   
  15   # Third-party imports
  16 + from cleantext import clean
  17   from smolagents import Tool
  18   
  19   # Configure module logger
  20   logger = logging.getLogger(__name__)
  21   
  22   
  23   # pylint: disable=too-few-public-methods
  24   class TextCleanerTool(Tool):
  25 +     """A simple text cleaner tool."""
  26   
  27       name = "clean_text"
  28 +     description = """This tool can be used to process messy user-generated content into
  29 + normalized text. It handles a variety of text transformation options,
  30 + such as fixing unicode errors, transliterating to closest ASCII,
  31 + lowercasing text, normalizing line breaks, removing punctuation,
  32 + replacing numbers with a token, and more."""
  33       inputs = {
  34           "text": {"type": "string", "description": "The input text to clean"},
  35           "options": {

  73       `clean-text` uses ftfy, unidecode and numerous hand-crafted rules,
  74       i.e., RegEx.
  75   
  76 +     Usage of the cleantext API:
  77       clean("some input",
  78           fix_unicode=True,  # fix various unicode errors
  79           to_ascii=True,  # transliterate to closest ASCII

 107               logger.error("Failed to convert input to string: %s", e)
 108               return f"Error: Could not process input of type {type(text)}"
 109   
 110           # Default replacement tokens
 111           replacements = {
 112               "replace_with_url": "<URL>",

 148           except (ValueError, TypeError, AttributeError) as e:
 149               logger.error("Error cleaning text: %s", e)
 150               return f"Error during text cleaning: {str(e)}"
 151 + 
 152 + 
 153 + __all__ = ["TextCleanerTool"]
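With the try/except import guard removed, clean-text becomes a hard dependency: the module no longer imports at all without it. A minimal usage sketch; the options keys here are assumptions mirroring the cleantext arguments quoted in the description string, and the authoritative schema is the tool's inputs dict:

from scripts.text_cleaner_tool import TextCleanerTool

cleaner = TextCleanerTool()
messy = "Héllo!!   Visit https://example.com   NOW"
cleaned = cleaner.forward(
    text=messy,
    # Hypothetical options payload, following the cleantext API shown above.
    options={"fix_unicode": True, "to_ascii": True, "lower": True, "no_urls": True},
)
print(cleaned)  # URLs should come back as the <URL> replacement token
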
scripts/text_inspector_tool.py
CHANGED

@@ -1,3 +1,5 @@
   1   from typing import Optional
   2   
   3   from smolagents import Tool

@@ -7,10 +9,24 @@ from .mdconvert import MarkdownConverter
   7   
   8   
   9   class TextInspectorTool(Tool):
  10 - 
  11       description = """
  12 - You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
  13 - This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.
  14   
  15       inputs = {
  16           "file_path": {

@@ -27,15 +43,23 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
  27       md_converter = MarkdownConverter()
  28   
  29       def __init__(self, model: Model, text_limit: int):
  30           super().__init__()
  31           self.model = model
  32           self.text_limit = text_limit
  33   
  34       def forward_initial_exam_mode(self, file_path, question):
  35           result = self.md_converter.convert(file_path)
  36   
  37 -         if file_path[-4:] in [".png", ".jpg"]:
  38 -             raise Exception(
  39   
  40           if ".zip" in file_path:
  41               return result.text_content

@@ -73,11 +97,28 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
  73           ]
  74           return self.model(messages).content
  75   
  76 -     def forward(self, file_path, question: Optional[str] = None) -> str:
  77           result = self.md_converter.convert(file_path)
  78   
  79           if file_path[-4:] in [".png", ".jpg"]:
  80 -             raise Exception(
  81   
  82           if ".zip" in file_path:
  83               return result.text_content

@@ -120,3 +161,6 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
 120               },
 121           ]
 122           return self.model(messages).content

   1 + #!/usr/bin/env python
   2 + # coding=utf-8
   3   from typing import Optional
   4   
   5   from smolagents import Tool

   9   
  10   
  11   class TextInspectorTool(Tool):
  12 +     """
  13 +     Tool for converting various file types to text and answering questions about their contents.
  14 + 
  15 +     Supported file types include:
  16 +     - Text documents (.txt, .md)
  17 +     - Web documents (.html, .htm)
  18 +     - Office documents (.docx, .xlsx, .pptx)
  19 +     - Audio files (.wav, .mp3, .flac)
  20 +     - PDF documents (.pdf)
  21 + 
  22 +     Images are not supported and should be processed with a visualizer tool instead.
  23 +     """
  24 + 
  25 +     name = "view_file"
  26       description = """
  27 + You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
  28 + This tool handles the following file extensions: [".html", ".htm", ".md", ".txt", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.
  29 + """
  30   
  31       inputs = {
  32           "file_path": {

  43       md_converter = MarkdownConverter()
  44   
  45       def __init__(self, model: Model, text_limit: int):
  46 +         """
  47 +         Initialize the TextInspectorTool with a model to use for generating text and a limit for the amount of text to generate.
  48 +         """
  49           super().__init__()
  50           self.model = model
  51           self.text_limit = text_limit
  52   
  53       def forward_initial_exam_mode(self, file_path, question):
  54 +         """
  55 +         This is used for generating code for the initial exam, and is not used for the final exam.
  56 +         """
  57           result = self.md_converter.convert(file_path)
  58   
  59 +         if file_path[-4:] in [".png", ".jpg", ".webp"]:
  60 +             raise Exception(
  61 +                 "Cannot use inspect_file_as_text tool with images: use visualizer instead!"
  62 +             )
  63   
  64           if ".zip" in file_path:
  65               return result.text_content

  97           ]
  98           return self.model(messages).content
  99   
 100 +     def forward(self, file_path: str, question: Optional[str] = None) -> str:
 101 +         """
 102 +         Process a file and optionally answer a question about its contents.
 103 + 
 104 +         Args:
 105 +             file_path: Path to the file to be processed. Must be a supported file type.
 106 +             question: Optional question to answer about the file contents.
 107 +                 If None, returns the raw file content.
 108 + 
 109 +         Returns:
 110 +             Either the raw file content if no question is provided, or the model's
 111 +             response to the question based on the file contents.
 112 + 
 113 +         Raises:
 114 +             Exception: If the file is an image file or has an unsupported format.
 115 +         """
 116           result = self.md_converter.convert(file_path)
 117   
 118           if file_path[-4:] in [".png", ".jpg"]:
 119 +             raise Exception(
 120 +                 "Cannot use inspect_file_as_text tool with images: use visualizer instead!"
 121 +             )
 122   
 123           if ".zip" in file_path:
 124               return result.text_content

 161               },
 162           ]
 163           return self.model(messages).content
 164 + 
 165 + 
 166 + __all__ = ["TextInspectorTool"]
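A usage sketch for the renamed view_file tool. Here model stands in for any smolagents-compatible chat model instance (an assumption; this diff only shows that __init__ takes a Model), and the file path is a placeholder:

from scripts.text_inspector_tool import TextInspectorTool

inspector = TextInspectorTool(model=model, text_limit=100_000)

# Without a question, forward() returns the file converted to markdown text.
raw_text = inspector.forward("data/report.pdf")

# With a question, the model answers from the converted contents.
answer = inspector.forward("data/report.pdf", question="What are the key findings?")
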
scripts/text_web_browser.py
CHANGED

@@ -1,3 +1,6 @@
   1   # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
   2   # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
   3   import mimetypes

@@ -12,11 +15,11 @@ from urllib.parse import unquote, urljoin, urlparse
  12   import pathvalidate
  13   import requests
  14   from serpapi import GoogleSearch
  15 - 
  16   from smolagents import Tool
  17   
  18   from .cookies import COOKIES
  19 - from .mdconvert import FileConversionException, MarkdownConverter,
  20   
  21   
  22   class SimpleTextBrowser:

@@ -45,7 +48,9 @@
  45           self._page_content: str = ""
  46   
  47           self._find_on_page_query: Union[str, None] = None
  48 -         self._find_on_page_last_result: Union[int, None] =
  49   
  50       @property
  51       def address(self) -> str:

@@ -60,7 +65,9 @@
  60           if uri_or_path == "about:blank":
  61               self._set_page_content("")
  62           elif uri_or_path.startswith("google:"):
  63 -             self._serpapi_search(
  64           else:
  65               if (
  66                   not uri_or_path.startswith("http:")

@@ -97,7 +104,9 @@
  97           self.viewport_current_page = len(self.viewport_pages) - 1
  98   
  99       def page_down(self) -> None:
 100 -         self.viewport_current_page = min(
 101   
 102       def page_up(self) -> None:
 103           self.viewport_current_page = max(self.viewport_current_page - 1, 0)

@@ -107,7 +116,10 @@
 107   
 108           # Did we get here via a previous find_on_page search with the same query?
 109           # If so, map to find_next
 110 -         if
 111               return self.find_next()
 112   
 113           # Ok it's a new search start from the current viewport

@@ -135,7 +147,9 @@
 135           if starting_viewport >= len(self.viewport_pages):
 136               starting_viewport = 0
 137   
 138 -         viewport_match = self._find_next_viewport(
 139           if viewport_match is None:
 140               self._find_on_page_last_result = None
 141               return None

@@ -144,7 +158,9 @@
 144           self._find_on_page_last_result = viewport_match
 145           return self.viewport
 146   
 147 -     def _find_next_viewport(
 148           """Search for matches between the starting viewport looping when reaching the end."""
 149   
 150           if query is None:

@@ -153,7 +169,9 @@
 153           # Normalize the query, and convert to a regular expression
 154           nquery = re.sub(r"\*", "__STAR__", query)
 155           nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
 156 -         nquery = nquery.replace(
 157           nquery = nquery.replace("__STAR__", ".*").lower()
 158   
 159           if nquery.strip() == "":

@@ -196,7 +214,9 @@
 196           while start_idx < len(self._page_content):
 197               end_idx = min(start_idx + self.viewport_size, len(self._page_content))  # type: ignore[operator]
 198               # Adjust to end on a space
 199 -             while end_idx < len(self._page_content) and self._page_content[
 200                   end_idx += 1
 201               self.viewport_pages.append((start_idx, end_idx))
 202               start_idx = end_idx

@@ -211,15 +231,21 @@
 211               "api_key": self.serpapi_key,
 212           }
 213           if filter_year is not None:
 214 -             params["tbs"] =
 215   
 216           search = GoogleSearch(params)
 217           results = search.get_dict()
 218           self.page_title = f"{query} - Search"
 219           if "organic_results" not in results.keys():
 220 -             raise Exception(
 221           if len(results["organic_results"]) == 0:
 222 -             year_filter_message =
 223               self._set_page_content(
 224                   f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
 225               )

@@ -250,7 +276,9 @@
 250   
 251           redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{_prev_visit(page['link'])}{snippet}"
 252   
 253 -         redacted_version = redacted_version.replace(
 254           web_snippets.append(redacted_version)
 255   
 256       content = (

@@ -270,7 +298,11 @@
 270               self._set_page_content(res.text_content)
 271           else:
 272               # Prepare the request parameters
 273 -             request_kwargs =
 274               request_kwargs["stream"] = True
 275   
 276               # Send a HTTP request to the URL

@@ -291,15 +323,21 @@
 291               fname = None
 292               download_path = None
 293               try:
 294 -                 fname = pathvalidate.sanitize_filename(
 295 - 
 296   
 297                   suffix = 0
 298                   while os.path.exists(download_path) and suffix < 1000:
 299                       suffix += 1
 300                       base, ext = os.path.splitext(fname)
 301                       new_fname = f"{base}__{suffix}{ext}"
 302 -                     download_path = os.path.abspath(
 303   
 304               except NameError:
 305                   pass

@@ -310,7 +348,9 @@
 310               if extension is None:
 311                   extension = ".download"
 312               fname = str(uuid.uuid4()) + extension
 313 -             download_path = os.path.abspath(
 314   
 315               # Open a file for writing
 316               with open(download_path, "wb") as fh:

@@ -324,11 +364,15 @@
 324           except UnsupportedFormatException as e:
 325               print(e)
 326               self.page_title = ("Download complete.",)
 327 -             self._set_page_content(
 328           except FileConversionException as e:
 329               print(e)
 330               self.page_title = ("Download complete.",)
 331 -             self._set_page_content(
 332           except FileNotFoundError:
 333               self.page_title = "Error 404"
 334               self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")

@@ -341,10 +385,14 @@
 341           if content_type is not None and "text/html" in content_type.lower():
 342               res = self._mdconvert.convert(response)
 343               self.page_title = f"Error {response.status_code}"
 344 -             self._set_page_content(
 345           else:
 346               text = ""
 347 -             for chunk in response.iter_content(
 348                   text += chunk
 349               self.page_title = f"Error {response.status_code}"
 350               self._set_page_content(f"## Error {response.status_code}\n\n{text}")

@@ -366,14 +414,18 @@
 366                   header += f"You previously visited this page {round(time.time() - self.history[i][1])} seconds ago.\n"
 367                   break
 368   
 369 -         header +=
 370           return (header, self.viewport)
 371   
 372   
 373   class SearchInformationTool(Tool):
 374       name = "web_search"
 375       description = "Perform a web search query (think a google search) and returns the search results."
 376 -     inputs = {
 377       inputs["filter_year"] = {
 378           "type": "string",
 379           "description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!",

@@ -394,7 +446,12 @@ class SearchInformationTool(Tool):
 394   class VisitTool(Tool):
 395       name = "visit_page"
 396       description = "Visit a webpage at a given URL and return its text. Given a url to a YouTube video, this returns the transcript."
 397 -     inputs = {
 398       output_type = "string"
 399   
 400       def __init__(self, browser):

@@ -413,7 +470,12 @@
 413       Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
 414       After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
 415       DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
 416 -     inputs = {
 417       output_type = "string"
 418   
 419       def __init__(self, browser):

@@ -435,7 +497,9 @@
 435                   f.write(response.content)
 436   
 437               if "pdf" in extension or "txt" in extension or "htm" in extension:
 438 -                 raise Exception(
 439   
 440               return f"File was downloaded and saved under path {new_path}."
 441   

@@ -461,15 +525,23 @@ class ArchiveSearchTool(Tool):
 461           archive_url = no_timestamp_url + f"&timestamp={date}"
 462           response = requests.get(archive_url).json()
 463           response_notimestamp = requests.get(no_timestamp_url).json()
 464 -         if
 465               closest = response["archived_snapshots"]["closest"]
 466               print("Archive found!", closest)
 467   
 468 -         elif
 469               closest = response_notimestamp["archived_snapshots"]["closest"]
 470               print("Archive found!", closest)
 471           else:
 472 -             raise Exception(
 473           target_url = closest["url"]
 474           self.browser.visit_page(target_url)
 475           header, content = self.browser._state()

@@ -499,9 +571,7 @@ class PageUpTool(Tool):
 499   
 500   class PageDownTool(Tool):
 501       name = "page_down"
 502 -     description = (
 503 -         "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
 504 -     )
 505       inputs = {}
 506       output_type = "string"
 507   

@@ -558,6 +628,20 @@ class FindNextTool(Tool):
 558           header, content = self.browser._state()
 559   
 560           if find_result is None:
 561 -             return
 562           else:
 563               return header.strip() + "\n=======================\n" + content

   1 + #!/usr/bin/env python
   2 + # coding=utf-8
   3 + # TODO: REMOVE REDUNDANT SERPAPI CODE AND IMPORT/EXTEND DEFAULT GoogleSearchTool FROM SMOLAGENTS
   4   # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
   5   # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
   6   import mimetypes

  15   import pathvalidate
  16   import requests
  17   from serpapi import GoogleSearch
  18   from smolagents import Tool
  19   
  20   from .cookies import COOKIES
  21 + from .mdconvert import (FileConversionException, MarkdownConverter,
  22 +                         UnsupportedFormatException)
  23   
  24   
  25   class SimpleTextBrowser:

  48           self._page_content: str = ""
  49   
  50           self._find_on_page_query: Union[str, None] = None
  51 +         self._find_on_page_last_result: Union[int, None] = (
  52 +             None  # Location of the last result
  53 +         )
  54   
  55       @property
  56       def address(self) -> str:

  65           if uri_or_path == "about:blank":
  66               self._set_page_content("")
  67           elif uri_or_path.startswith("google:"):
  68 +             self._serpapi_search(
  69 +                 uri_or_path[len("google:") :].strip(), filter_year=filter_year
  70 +             )
  71           else:
  72               if (
  73                   not uri_or_path.startswith("http:")

 104           self.viewport_current_page = len(self.viewport_pages) - 1
 105   
 106       def page_down(self) -> None:
 107 +         self.viewport_current_page = min(
 108 +             self.viewport_current_page + 1, len(self.viewport_pages) - 1
 109 +         )
 110   
 111       def page_up(self) -> None:
 112           self.viewport_current_page = max(self.viewport_current_page - 1, 0)

 116   
 117           # Did we get here via a previous find_on_page search with the same query?
 118           # If so, map to find_next
 119 +         if (
 120 +             query == self._find_on_page_query
 121 +             and self.viewport_current_page == self._find_on_page_last_result
 122 +         ):
 123               return self.find_next()
 124   
 125           # Ok it's a new search start from the current viewport

 147           if starting_viewport >= len(self.viewport_pages):
 148               starting_viewport = 0
 149   
 150 +         viewport_match = self._find_next_viewport(
 151 +             self._find_on_page_query, starting_viewport
 152 +         )
 153           if viewport_match is None:
 154               self._find_on_page_last_result = None
 155               return None

 158           self._find_on_page_last_result = viewport_match
 159           return self.viewport
 160   
 161 +     def _find_next_viewport(
 162 +         self, query: str, starting_viewport: int
 163 +     ) -> Union[int, None]:
 164           """Search for matches between the starting viewport looping when reaching the end."""
 165   
 166           if query is None:

 169           # Normalize the query, and convert to a regular expression
 170           nquery = re.sub(r"\*", "__STAR__", query)
 171           nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
 172 +         nquery = nquery.replace(
 173 +             " __STAR__ ", "__STAR__ "
 174 +         )  # Merge isolated stars with prior word
 175           nquery = nquery.replace("__STAR__", ".*").lower()
 176   
 177           if nquery.strip() == "":

 214           while start_idx < len(self._page_content):
 215               end_idx = min(start_idx + self.viewport_size, len(self._page_content))  # type: ignore[operator]
 216               # Adjust to end on a space
 217 +             while end_idx < len(self._page_content) and self._page_content[
 218 +                 end_idx - 1
 219 +             ] not in [" ", "\t", "\r", "\n"]:
 220                   end_idx += 1
 221               self.viewport_pages.append((start_idx, end_idx))
 222               start_idx = end_idx

 231               "api_key": self.serpapi_key,
 232           }
 233           if filter_year is not None:
 234 +             params["tbs"] = (
 235 +                 f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
 236 +             )
 237   
 238           search = GoogleSearch(params)
 239           results = search.get_dict()
 240           self.page_title = f"{query} - Search"
 241           if "organic_results" not in results.keys():
 242 +             raise Exception(
 243 +                 f"No results found for query: '{query}'. Use a less specific query."
 244 +             )
 245           if len(results["organic_results"]) == 0:
 246 +             year_filter_message = (
 247 +                 f" with filter year={filter_year}" if filter_year is not None else ""
 248 +             )
 249               self._set_page_content(
 250                   f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
 251               )

 276   
 277           redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{_prev_visit(page['link'])}{snippet}"
 278   
 279 +         redacted_version = redacted_version.replace(
 280 +             "Your browser can't play this video.", ""
 281 +         )
 282           web_snippets.append(redacted_version)
 283   
 284       content = (

 298               self._set_page_content(res.text_content)
 299           else:
 300               # Prepare the request parameters
 301 +             request_kwargs = (
 302 +                 self.request_kwargs.copy()
 303 +                 if self.request_kwargs is not None
 304 +                 else {}
 305 +             )
 306               request_kwargs["stream"] = True
 307   
 308               # Send a HTTP request to the URL

 323               fname = None
 324               download_path = None
 325               try:
 326 +                 fname = pathvalidate.sanitize_filename(
 327 +                     os.path.basename(urlparse(url).path)
 328 +                 ).strip()
 329 +                 download_path = os.path.abspath(
 330 +                     os.path.join(self.downloads_folder, fname)
 331 +                 )
 332   
 333                   suffix = 0
 334                   while os.path.exists(download_path) and suffix < 1000:
 335                       suffix += 1
 336                       base, ext = os.path.splitext(fname)
 337                       new_fname = f"{base}__{suffix}{ext}"
 338 +                     download_path = os.path.abspath(
 339 +                         os.path.join(self.downloads_folder, new_fname)
 340 +                     )
 341   
 342               except NameError:
 343                   pass

 348               if extension is None:
 349                   extension = ".download"
 350               fname = str(uuid.uuid4()) + extension
 351 +             download_path = os.path.abspath(
 352 +                 os.path.join(self.downloads_folder, fname)
 353 +             )
 354   
 355               # Open a file for writing
 356               with open(download_path, "wb") as fh:

 364           except UnsupportedFormatException as e:
 365               print(e)
 366               self.page_title = ("Download complete.",)
 367 +             self._set_page_content(
 368 +                 f"# Download complete\n\nSaved file to '{download_path}'"
 369 +             )
 370           except FileConversionException as e:
 371               print(e)
 372               self.page_title = ("Download complete.",)
 373 +             self._set_page_content(
 374 +                 f"# Download complete\n\nSaved file to '{download_path}'"
 375 +             )
 376           except FileNotFoundError:
 377               self.page_title = "Error 404"
 378               self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")

 385           if content_type is not None and "text/html" in content_type.lower():
 386               res = self._mdconvert.convert(response)
 387               self.page_title = f"Error {response.status_code}"
 388 +             self._set_page_content(
 389 +                 f"## Error {response.status_code}\n\n{res.text_content}"
 390 +             )
 391           else:
 392               text = ""
 393 +             for chunk in response.iter_content(
 394 +                 chunk_size=512, decode_unicode=True
 395 +             ):
 396                   text += chunk
 397               self.page_title = f"Error {response.status_code}"
 398               self._set_page_content(f"## Error {response.status_code}\n\n{text}")

 414                   header += f"You previously visited this page {round(time.time() - self.history[i][1])} seconds ago.\n"
 415                   break
 416   
 417 +         header += (
 418 +             f"Viewport position: Showing page {current_page + 1} of {total_pages}.\n"
 419 +         )
 420           return (header, self.viewport)
 421   
 422   
 423   class SearchInformationTool(Tool):
 424       name = "web_search"
 425       description = "Perform a web search query (think a google search) and returns the search results."
 426 +     inputs = {
 427 +         "query": {"type": "string", "description": "The web search query to perform."}
 428 +     }
 429       inputs["filter_year"] = {
 430           "type": "string",
 431           "description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!",

 446   class VisitTool(Tool):
 447       name = "visit_page"
 448       description = "Visit a webpage at a given URL and return its text. Given a url to a YouTube video, this returns the transcript."
 449 +     inputs = {
 450 +         "url": {
 451 +             "type": "string",
 452 +             "description": "The relative or absolute url of the webapge to visit.",
 453 +         }
 454 +     }
 455       output_type = "string"
 456   
 457       def __init__(self, browser):

 470       Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
 471       After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
 472       DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
 473 +     inputs = {
 474 +         "url": {
 475 +             "type": "string",
 476 +             "description": "The relative or absolute url of the file to be downloaded.",
 477 +         }
 478 +     }
 479       output_type = "string"
 480   
 481       def __init__(self, browser):

 497                   f.write(response.content)
 498   
 499               if "pdf" in extension or "txt" in extension or "htm" in extension:
 500 +                 raise Exception(
 501 +                     "Do not use this tool for pdf or txt or html files: use visit_page instead."
 502 +                 )
 503   
 504               return f"File was downloaded and saved under path {new_path}."
 505   

 525           archive_url = no_timestamp_url + f"&timestamp={date}"
 526           response = requests.get(archive_url).json()
 527           response_notimestamp = requests.get(no_timestamp_url).json()
 528 +         if (
 529 +             "archived_snapshots" in response
 530 +             and "closest" in response["archived_snapshots"]
 531 +         ):
 532               closest = response["archived_snapshots"]["closest"]
 533               print("Archive found!", closest)
 534   
 535 +         elif (
 536 +             "archived_snapshots" in response_notimestamp
 537 +             and "closest" in response_notimestamp["archived_snapshots"]
 538 +         ):
 539               closest = response_notimestamp["archived_snapshots"]["closest"]
 540               print("Archive found!", closest)
 541           else:
 542 +             raise Exception(
 543 +                 f"Your {url=} was not archived on Wayback Machine, try a different url."
 544 +             )
 545           target_url = closest["url"]
 546           self.browser.visit_page(target_url)
 547           header, content = self.browser._state()

 571   
 572   class PageDownTool(Tool):
 573       name = "page_down"
 574 +     description = "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
 575       inputs = {}
 576       output_type = "string"
 577   

 628           header, content = self.browser._state()
 629   
 630           if find_result is None:
 631 +             return (
 632 +                 header.strip()
 633 +                 + "\n=======================\nThe search string was not found on this page."
 634 +             )
 635           else:
 636               return header.strip() + "\n=======================\n" + content
 637 + 
 638 + 
 639 + __all__ = [
 640 +     "DownloadTool",
 641 +     "VisitTool",
 642 +     "PageUpTool",
 643 +     "PageDownTool",
 644 +     "FinderTool",
 645 +     "FindNextTool",
 646 +     "ArchiveSearchTool",
 647 + ]
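All of these tool classes wrap one shared SimpleTextBrowser instance. A wiring sketch; the constructor keywords are assumptions inferred from the attributes used above (viewport_size, downloads_folder, serpapi_key, request_kwargs):

from scripts.text_web_browser import PageDownTool, SimpleTextBrowser, VisitTool

# Keyword names assumed from the attribute accesses shown in this diff.
browser = SimpleTextBrowser(viewport_size=8192, downloads_folder="downloads")
visit = VisitTool(browser)
page_down = PageDownTool(browser)

print(visit.forward(url="https://example.com"))  # first viewport of the page
print(page_down.forward())  # advance to the next viewport of the same page

Note that web_search additionally requires a SerpAPI key, since _serpapi_search reads self.serpapi_key when it calls GoogleSearch.
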
scripts/time_tools.py
ADDED

@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The Footscray Coding Collective. All rights reserved.
+from datetime import datetime
+from typing import Optional
+
+import pytz
+from smolagents import tool
+
+
+@tool
+def get_temporal_context(
+    timezone_str: str = "US/Eastern", market: str = "US", date_str: Optional[str] = None
+) -> str:
+    """
+    Provides a concise overview of the current temporal context, including date, time, and market status.
+
+    Args:
+        timezone_str: The timezone to display time in (default: US/Eastern)
+        market: Market identifier (US, EU, ASIA) (default: US)
+        date_str: Date in YYYY-MM-DD format (optional, defaults to current date if not provided)
+
+    Returns:
+        A formatted string containing the current date, time, year, trading day status, and market hours status.
+    """
+
+    try:
+        # Get current time information using pytz
+        try:
+            tz = pytz.timezone(timezone_str)
+        except pytz.exceptions.UnknownTimeZoneError:
+            return f"Error: Unknown timezone '{timezone_str}'. Try using standard timezone names like 'US/Eastern'."
+
+        now = datetime.now(tz)
+        current_date = now.strftime("%Y-%m-%d")
+        current_time = now.strftime("%H:%M:%S")
+        current_year = now.year
+        weekday_name = now.strftime("%A")
+        time_info = f"""Current Time Information:
+- Date: {current_date} ({weekday_name})
+- Time: {current_time} ({timezone_str})
+- Year: {current_year}
+"""
+
+        # Get Market hours Information
+        if market == "US":
+            # Convert time to US/Eastern for US market check
+            eastern_tz = pytz.timezone("US/Eastern")
+            eastern_now = now.astimezone(eastern_tz)
+
+            is_weekday_us = eastern_now.weekday() < 5
+            us_minutes = eastern_now.hour * 60 + eastern_now.minute
+            us_market_open = 9 * 60 + 30  # 9:30 AM ET
+            us_market_close = 16 * 60  # 4:00 PM ET
+
+            if is_weekday_us and us_market_open <= us_minutes < us_market_close:
+                market_status = "Open"
+            else:
+                market_status = "Closed"
+
+            market_hours_info = f"US Markets (NYSE, NASDAQ): {market_status}"
+
+        elif market == "EU":
+            # Convert time to London for EU market check
+            london_tz = pytz.timezone("Europe/London")
+            london_now = now.astimezone(london_tz)
+
+            is_weekday_eu = london_now.weekday() < 5
+            eu_minutes = london_now.hour * 60 + london_now.minute
+            eu_market_open = 8 * 60  # 8:00 AM London
+            eu_market_close = 16 * 60 + 30  # 4:30 PM London
+
+            if is_weekday_eu and eu_market_open <= eu_minutes < eu_market_close:
+                market_status = "Open"
+            else:
+                market_status = "Closed"
+
+            market_hours_info = f"European Markets (LSE, Euronext): {market_status}"
+
+        elif market == "ASIA":
+            # Convert time to Tokyo for Asian market check
+            tokyo_tz = pytz.timezone("Asia/Tokyo")
+            tokyo_now = now.astimezone(tokyo_tz)
+
+            is_weekday_tokyo = tokyo_now.weekday() < 5
+            tokyo_minutes = tokyo_now.hour * 60 + tokyo_now.minute
+            tokyo_morning_open = 9 * 60  # 9:00 AM Tokyo
+            tokyo_morning_close = 11 * 60 + 30  # 11:30 AM Tokyo
+            tokyo_afternoon_open = 12 * 60 + 30  # 12:30 PM Tokyo
+            tokyo_afternoon_close = 15 * 60  # 3:00 PM Tokyo
+
+            is_tokyo_session = (
+                tokyo_morning_open <= tokyo_minutes < tokyo_morning_close
+            ) or (tokyo_afternoon_open <= tokyo_minutes < tokyo_afternoon_close)
+
+            if is_weekday_tokyo and is_tokyo_session:
+                market_status = "Open"
+            else:
+                market_status = "Closed"
+
+            market_hours_info = (
+                "Asian Markets (Tokyo Stock Exchange, Shanghai Stock Exchange, "
+                f"Australian Securities Exchange): {market_status}"
+            )
+
+        else:
+            return f"Error: Invalid market '{market}'. Supported markets are 'US', 'EU', and 'ASIA'."
+
+        # Get Trading Day Information
+        if date_str:
+            try:
+                date_obj = datetime.strptime(date_str, "%Y-%m-%d")
+                # Apply timezone to date_obj
+                date_obj = tz.localize(date_obj)
+            except ValueError:
+                return (
+                    f"Error: Invalid date format '{date_str}'. Use YYYY-MM-DD format."
+                )
+        else:
+            date_obj = now
+            date_str = now.strftime("%Y-%m-%d")
+
+        is_weekend = date_obj.weekday() > 4
+        trading_day = "No" if is_weekend else "Yes"
+        trading_info = f"Trading Day: {trading_day}"
+
+        # Combine all information
+        final_result = f"""{time_info}
+{market_hours_info}
+- {trading_info}
+"""
+
+        return final_result
+
+    except Exception as e:
+        return f"Error retrieving temporal context: {str(e)}"
+
+
+__all__ = ["get_temporal_context"]
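Because get_temporal_context is declared with @tool, the module exports a ready-made smolagents tool; it should also remain directly callable for quick checks (smolagents tool wrappers dispatch calls to the underlying function):

from scripts.time_tools import get_temporal_context

# Live session status for European markets:
print(get_temporal_context(timezone_str="Europe/London", market="EU"))

# Trading-day flag for a fixed date (2024-12-29 fell on a Sunday, so "Trading Day: No"):
print(get_temporal_context(market="US", date_str="2024-12-29"))
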
scripts/visual_qa.py
CHANGED

@@ -1,3 +1,5 @@
   1   import base64
   2   import json
   3   import mimetypes

@@ -10,10 +12,8 @@
  10   from dotenv import load_dotenv
  11   from huggingface_hub import InferenceClient
  12   from PIL import Image
  13 - from transformers import AutoProcessor
  14 - 
  15   from smolagents import Tool, tool
  16 - 
  17   
  18   load_dotenv(override=True)
  19   

@@ -31,7 +31,9 @@ def process_images_and_text(image_path, query, client):
  31           },
  32       ]
  33   
  34 -     prompt_with_template = idefics_processor.apply_chat_template(
  35   
  36       # load images from local directory
  37   

@@ -42,7 +44,9 @@
  42   
  43           # Convert the image to a base64 string
  44           buffer = BytesIO()
  45 -         image.save(
  46           base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
  47   
  48           # add string formatting required by the endpoint

@@ -51,7 +55,9 @@
  51           return image_string
  52   
  53       image_string = encode_local_image(image_path)
  54 -     prompt_with_images = prompt_with_template.replace("<image>", " ").format(
  55   
  56       payload = {
  57           "inputs": prompt_with_images,

@@ -95,7 +101,10 @@ def encode_image(image_path):
  95       return base64.b64encode(image_file.read()).decode("utf-8")
  96   
  97   
  98 - headers = {
  99   
 100   
 101   def resize_image(image_path):

@@ -115,7 +124,11 @@
 115           "description": "The path to the image on which to answer the question",
 116           "type": "string",
 117       },
 118 -     "question": {
 119   }
 120   output_type = "string"
 121   # try use the same model with two different endpoints

@@ -136,9 +149,7 @@
 136           output = process_images_and_text(new_image_path, question, self.client)
 137   
 138           if add_note:
 139 -             output = (
 140 -                 f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
 141 -             )
 142   
 143           return output
 144   

@@ -156,7 +167,9 @@ def visualizer(image_path: str, question: Optional[str] = None) -> str:
 156           add_note = True
 157           question = "Please write a detailed caption for this image."
 158       if not isinstance(image_path, str):
 159 -         raise Exception(
 160   
 161       mime_type, _ = mimetypes.guess_type(image_path)
 162       base64_image = encode_image(image_path)

@@ -168,13 +181,18 @@
 168                   "role": "user",
 169                   "content": [
 170                       {"type": "text", "text": "what is in this image" + question},
 171 -                     {
 172                   ],
 173               }
 174           ],
 175           "max_tokens": 1000,
 176       }
 177 -     response = requests.post(
 178       try:
 179           output = response.json()["choices"][0]["message"]["content"]
 180       except Exception:

@@ -184,5 +202,5 @@
 184           output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
 185   
 186       # TO DO: write to yaml or chromadb -> HF Dataset in due course...
 187 - 
 188       return output

   1 + #!/usr/bin/env python
   2 + # coding=utf-8
   3   import base64
   4   import json
   5   import mimetypes

  12   from dotenv import load_dotenv
  13   from huggingface_hub import InferenceClient
  14   from PIL import Image
  15   from smolagents import Tool, tool
  16 + from transformers import AutoProcessor
  17   
  18   load_dotenv(override=True)
  19   

  31           },
  32       ]
  33   
  34 +     prompt_with_template = idefics_processor.apply_chat_template(
  35 +         messages, add_generation_prompt=True
  36 +     )
  37   
  38       # load images from local directory
  39   

  44   
  45           # Convert the image to a base64 string
  46           buffer = BytesIO()
  47 +         image.save(
  48 +             buffer, format="JPEG"
  49 +         )  # Use the appropriate format (e.g., JPEG, PNG)
  50           base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
  51   
  52           # add string formatting required by the endpoint

  55           return image_string
  56   
  57       image_string = encode_local_image(image_path)
  58 +     prompt_with_images = prompt_with_template.replace("<image>", " ").format(
  59 +         image_string
  60 +     )
  61   
  62       payload = {
  63           "inputs": prompt_with_images,

 101       return base64.b64encode(image_file.read()).decode("utf-8")
 102   
 103   
 104 + headers = {
 105 +     "Content-Type": "application/json",
 106 +     "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
 107 + }
 108   
 109   
 110   def resize_image(image_path):

 124           "description": "The path to the image on which to answer the question",
 125           "type": "string",
 126       },
 127 +     "question": {
 128 +         "description": "the question to answer",
 129 +         "type": "string",
 130 +         "nullable": True,
 131 +     },
 132   }
 133   output_type = "string"
 134   # try use the same model with two different endpoints

 149           output = process_images_and_text(new_image_path, question, self.client)
 150   
 151           if add_note:
 152 +             output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
 153   
 154           return output
 155   

 167           add_note = True
 168           question = "Please write a detailed caption for this image."
 169       if not isinstance(image_path, str):
 170 +         raise Exception(
 171 +             "You should provide at least `image_path` string argument to this tool!"
 172 +         )
 173   
 174       mime_type, _ = mimetypes.guess_type(image_path)
 175       base64_image = encode_image(image_path)

 181                   "role": "user",
 182                   "content": [
 183                       {"type": "text", "text": "what is in this image" + question},
 184 +                     {
 185 +                         "type": "image_url",
 186 +                         "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
 187 +                     },
 188                   ],
 189               }
 190           ],
 191           "max_tokens": 1000,
 192       }
 193 +     response = requests.post(
 194 +         "https://openrouter.ai/api/v1", headers=headers, json=payload
 195 +     )
 196       try:
 197           output = response.json()["choices"][0]["message"]["content"]
 198       except Exception:

 202           output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
 203   
 204       # TO DO: write to yaml or chromadb -> HF Dataset in due course...
 205 + 
 206       return output
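Finally, a sketch of calling the module-level visualizer tool, whose signature is shown above. It assumes the OPENAI_API_KEY environment variable and the endpoint hard-coded in this diff are valid at runtime, and the image paths are placeholders:

from scripts.visual_qa import visualizer

# With no question, the tool substitutes a captioning prompt and flags that in its output.
caption = visualizer(image_path="figures/chart.png")

# With a question, it answers from the image.
answer = visualizer(image_path="figures/chart.png", question="How many bars are in the chart?")
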