Leonardo commited on
Commit
eaaf050
·
verified ·
1 Parent(s): 24730c9

Sync local Space with Hub

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ logs/
3
+ data
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
  title: ODR
3
  emoji: 🏆
4
- colorFrom: purple
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.23.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: OpenAI's Deep Research, but open. Forked m-ric repo!
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: ODR
3
  emoji: 🏆
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.14.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: OpenAI's Deep Research, but open
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,44 +1,131 @@
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
  # Copyright 2024 The Footscray Coding Collective. All rights reserved.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import mimetypes
5
  import os
6
  import re
7
  import shutil
8
- from typing import Optional
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import gradio as gr
 
11
  from dotenv import load_dotenv
12
  from huggingface_hub import login
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from scripts.flux_lora_tool import FluxLoRATool
14
  from scripts.text_cleaner_tool import TextCleanerTool
15
  from scripts.text_inspector_tool import TextInspectorTool
16
- from scripts.text_web_browser import (
17
- ArchiveSearchTool,
18
- FinderTool,
19
- FindNextTool,
20
- PageDownTool,
21
- PageUpTool,
22
- SimpleTextBrowser,
23
- VisitTool,
24
- )
25
  from scripts.visual_qa import visualizer
26
- from smolagents import (
27
- CodeAgent,
28
- GoogleSearchTool,
29
- HfApiModel,
30
- LiteLLMModel,
31
- OpenAIServerModel,
32
- Tool,
33
- TransformersModel,
34
  )
35
- from smolagents.agent_types import AgentAudio, AgentImage, AgentText
36
- from smolagents.gradio_ui import handle_agent_output_types, pull_messages_from_step
37
 
38
  # ------------------------ Configuration and Setup ------------------------
39
  # Constants and configurations
40
  AUTHORIZED_IMPORTS = [
41
  "requests", # Web requests (fetching data from the internet)
 
42
  "zipfile", # Working with ZIP archives
43
  "pandas", # Data manipulation and analysis (DataFrames)
44
  "numpy", # Numerical computing (arrays, linear algebra)
@@ -48,7 +135,7 @@ AUTHORIZED_IMPORTS = [
48
  "pubchempy", # Accessing PubChem chemical database
49
  "yaml",
50
  "xml", # XML processing
51
- "yahoo_finance", # Fetching stock data
52
  "Bio", # Bioinformatics tools (e.g., sequence analysis)
53
  "sklearn", # Scikit-learn for machine learning
54
  "scipy", # Scientific computing (stats, optimization)
@@ -74,7 +161,7 @@ AUTHORIZED_IMPORTS = [
74
  "time", # Measuring time
75
  "tempfile", # Creating temporary files and directories
76
  # Data Visualization (if needed) - Consider security implications carefully
77
- "matplotlib", # Plotting library (basic charts)
78
  "seaborn", # Statistical data visualization (more advanced)
79
  # Web Scraping (more specific/controlled) - Consider ethical implications
80
  "lxml", # Faster XML/HTML processing (alternative to bs4)
@@ -85,6 +172,7 @@ AUTHORIZED_IMPORTS = [
85
  "schedule", # Allow the agent to schedule tasks
86
  "uuid",
87
  "base64",
 
88
  ]
89
 
90
  USER_AGENT = (
@@ -93,7 +181,7 @@ USER_AGENT = (
93
  )
94
  BROWSER_CONFIG = {
95
  "viewport_size": 1024 * 5,
96
- "downloads_folder": "downloads_folder",
97
  "request_kwargs": {
98
  "headers": {"User-Agent": USER_AGENT},
99
  "timeout": 300,
@@ -103,7 +191,6 @@ BROWSER_CONFIG = {
103
 
104
  CUSTOM_ROLE_CONVERSIONS = {"tool-call": "assistant", "tool-response": "user"}
105
 
106
-
107
  ALLOWED_FILE_TYPES = [
108
  "application/pdf",
109
  "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -121,23 +208,108 @@ ALLOWED_FILE_TYPES = [
121
  ]
122
 
123
 
124
- def setup_environment():
125
- """Initialize environment variables and authentication."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  load_dotenv(override=True)
 
 
127
  if os.getenv("HF_TOKEN"): # Check if token is actually set
128
  login(os.getenv("HF_TOKEN"))
129
- print("HF_TOKEN (last 10 characters):", os.getenv("HF_TOKEN")[-10:])
130
  else:
131
- print("HF_TOKEN not found in environment variables.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  # ------------------------ Model and Tool Management ------------------------
135
  class ModelManager:
136
- """Manages model loading and initialization."""
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  @staticmethod
139
  def load_model(chosen_inference: str, model_id: str, key_manager=None):
140
- """Load the specified model with appropriate configuration."""
 
 
 
 
 
 
 
 
 
 
 
141
  try:
142
  if chosen_inference == "hf_api":
143
  return HfApiModel(model_id=model_id)
@@ -156,7 +328,7 @@ class ModelManager:
156
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
157
  )
158
 
159
- elif chosen_inference == "transformers":
160
  return TransformersModel(
161
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
162
  device_map="auto",
@@ -167,41 +339,114 @@ class ModelManager:
167
  raise ValueError(f"Invalid inference type: {chosen_inference}")
168
 
169
  except Exception as e:
170
- print(f"✗ Couldn't load model: {e}")
171
  raise
172
 
173
 
 
174
  class ToolRegistry:
175
- """Manages tool initialization and organization."""
176
 
177
  @staticmethod
178
- def load_web_tools(model, browser, text_limit=20000):
179
- """Initialize and return web-related tools."""
 
 
 
 
 
 
 
 
 
 
 
 
180
  return [
181
- GoogleSearchTool(provider="serper"),
182
- VisitTool(browser),
183
- PageUpTool(browser),
184
- PageDownTool(browser),
185
- FinderTool(browser),
186
- FindNextTool(browser),
187
- ArchiveSearchTool(browser),
188
  TextInspectorTool(model, text_limit),
189
  ]
190
 
191
  @staticmethod
192
- def load_document_tools():
193
  """
194
- Initialize and return document processing, i.e. sanitisation and indexing, tools.
 
195
  Returns:
196
- List of document tools
197
  """
198
  return [
199
  TextCleanerTool(),
200
  ]
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  @staticmethod
203
  def load_image_generation_tools():
204
- """Initialize and return image generation tools."""
 
 
 
 
 
205
  try:
206
  return Tool.from_space(
207
  space_id="xkerser/FLUX.1-dev",
@@ -209,95 +454,219 @@ class ToolRegistry:
209
  description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
210
  )
211
  except Exception as e:
212
- print(f"✗ Couldn't initialize image generation tool: {e}")
213
- return FluxLoRATool
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
 
216
- # ------------------------ Agent Creation and Execution ------------------------
217
- def create_agent():
218
  """
219
- Creates a fresh agent instance with properly configured tools.
 
 
 
 
 
 
 
220
  Returns:
221
- CodeAgent: Configured agent ready for use
 
222
  Raises:
223
- ValueError: If tool validation fails
224
  RuntimeError: If agent creation fails
225
  """
226
  try:
227
- # Initialize model
228
- model = LiteLLMModel(
229
- custom_role_conversions=CUSTOM_ROLE_CONVERSIONS,
230
- model_id="openrouter/google/gemini-2.0-flash-001",
231
- )
232
 
233
  # Initialize tools
234
  text_limit = 30000
235
  browser = SimpleTextBrowser(**BROWSER_CONFIG)
236
 
237
- # Collect all tools in a single list
238
- web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
239
- doc_tools = ToolRegistry.load_document_tools() # New document tools
 
 
 
240
  image_generator = ToolRegistry.load_image_generation_tools()
241
-
242
- # Combine all tools into a single list
243
- all_tools = [visualizer] + web_tools + doc_tools + [image_generator]
 
 
 
 
 
 
 
 
 
 
244
 
245
  # Validate tools before creating agent
246
- for tool in all_tools:
247
- if not isinstance(tool, Tool):
248
- raise ValueError(
249
- f"Invalid tool type: {type(tool)}. "
250
- f"All tools must be instances of Tool class."
251
- )
252
 
253
  return CodeAgent(
254
  model=model,
255
  tools=all_tools,
256
- max_steps=12,
257
  verbosity_level=2,
258
  additional_authorized_imports=AUTHORIZED_IMPORTS,
259
- planning_interval=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  )
261
- except (ValueError, RuntimeError) as e:
262
- print(f"Failed to create agent: {e}")
263
  raise RuntimeError(f"Agent creation failed: {e}")
264
 
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  def stream_to_gradio(
267
  agent,
268
  task: str,
269
  reset_agent_memory: bool = False,
270
  additional_args: Optional[dict] = None,
271
  ):
272
- """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
273
- for step_log in agent.run(
274
- task, stream=True, reset=reset_agent_memory, additional_args=additional_args
275
- ):
276
- for message in pull_messages_from_step(step_log):
277
- yield message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
- # Process final answer : Use a more comprehensive media output
280
- final_answer = step_log # Last log is the run's final_answer
281
- final_answer = handle_agent_output_types(final_answer)
 
 
282
 
283
- if isinstance(final_answer, AgentText):
284
- yield gr.ChatMessage(
285
- role="assistant",
286
- content=f"**Final answer:**\n{final_answer.to_string()}\n",
287
- )
288
- if isinstance(final_answer, AgentImage):
289
- yield gr.ChatMessage(
290
- role="assistant",
291
- content={"image": final_answer.to_string(), "type": "file"},
292
- ) # Send as Gradio-compatible file object:
293
- if isinstance(final_answer, AgentAudio):
 
 
 
 
 
 
 
 
 
 
294
  yield gr.ChatMessage(
295
  role="assistant",
296
- content={"audio": final_answer.to_string(), "type": "file"},
297
- ) # Send as Gradio-compatible file object
298
- else:
299
- yield gr.ChatMessage(
300
- role="assistant", content=f"**Final answer:** {str(final_answer)}"
301
  )
302
 
303
 
@@ -313,20 +682,37 @@ class GradioUI:
313
  if not os.path.exists(file_upload_folder):
314
  os.mkdir(file_upload_folder)
315
 
316
- def interact_with_agent(self, prompt, messages, session_state):
317
- """Main interaction handler with the agent."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
  # Get or create session-specific agent
320
  if "agent" not in session_state:
321
- session_state["agent"] = create_agent()
 
 
 
322
 
323
  # Adding monitoring
324
  try:
325
  # Log the existence of agent memory
326
  has_memory = hasattr(session_state["agent"], "memory")
327
- print(f"Agent has memory: {has_memory}")
328
  if has_memory:
329
- print(f"Memory type: {type(session_state['agent'].memory)}")
330
 
331
  messages.append(gr.ChatMessage(role="user", content=prompt))
332
  yield messages
@@ -339,7 +725,7 @@ class GradioUI:
339
  yield messages # Yield messages one last time
340
 
341
  except Exception as e:
342
- print(f"Error in interaction: {str(e)}")
343
  raise
344
 
345
  def upload_file(
@@ -448,7 +834,7 @@ class GradioUI:
448
  @gr.render()
449
  def layout(request: gr.Request):
450
  device = self.detect_device(request)
451
- print(f"device - {device}")
452
  # Render layout with sidebar
453
  if device == "Desktop":
454
  return self._create_desktop_layout()
@@ -464,7 +850,7 @@ class GradioUI:
464
  with gr.Sidebar():
465
  gr.Markdown(
466
  """#OpenDeepResearch - 3theSmolagents!
467
- Model_id: google/gemini-2.0-flash-001"""
468
  )
469
  with gr.Group():
470
  gr.Markdown("**What's on your mind mate?**", container=True)
@@ -635,18 +1021,75 @@ class GradioUI:
635
  )
636
 
637
 
638
- # ------------------------ Execution ------------------------
639
- def main():
640
- """Main entry point for the application."""
641
- # Initialize environment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
  setup_environment()
643
 
644
- # Ensure downloads folder exists
645
- os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
- # Launch UI
648
- GradioUI(file_upload_folder="uploaded_files").launch()
 
 
 
 
 
 
 
 
 
 
 
649
 
650
 
 
651
  if __name__ == "__main__":
652
- main()
 
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
  # Copyright 2024 The Footscray Coding Collective. All rights reserved.
4
+ """
5
+ Financial Research Agent: Advanced Market Analysis and Data Access
6
+
7
+ This script implements a comprehensive financial research agent capable of performing market analysis,
8
+ retrieving financial data, and providing interactive research capabilities through either a GUI or
9
+ command-line interface.
10
+
11
+ The agent leverages the Smolagents framework to create an autonomous system that can:
12
+ 1. Access and analyze real-time market data through Alpha Vantage API integration
13
+ 2. Process financial documents and extract relevant information
14
+ 3. Perform web searches and analyze webpage content
15
+ 4. Create visualizations of financial data
16
+ 5. Generate comprehensive financial analysis reports
17
+ 6. Handle user uploads of various document types
18
+
19
+ Key Components:
20
+ -------------
21
+ - ModelManager: Handles loading and configuration of various LLM models
22
+ - ToolRegistry: Manages initialization and organization of tools available to the agent
23
+ - GradioUI: Provides a user-friendly interface with responsive design for desktop/mobile
24
+ - A robust set of financial tools for retrieving stock data, financial statements, and market sentiment
25
+ - Web browsing capabilities with text extraction and analysis
26
+ - Document processing for PDFs, spreadsheets, and other common file formats
27
+ - Visualization tools for creating charts and graphs from financial data
28
+
29
+ Usage:
30
+ -----
31
+ Run in UI mode (default):
32
+ python app.py
33
+
34
+ Run in headless mode with a specific query:
35
+ python app.py --mode headless --query "Analyze Tesla's financial performance for 2023"
36
+
37
+ Configuration:
38
+ ------------
39
+ The script uses environment variables for API keys and other configuration settings.
40
+ Required environment variables:
41
+ - ALPHA_VANTAGE_API_KEY: For accessing financial data APIs
42
+ - HF_TOKEN: For accessing Hugging Face models (optional)
43
+
44
+ The agent also maintains detailed logs in the logs/ directory for debugging and auditing.
45
+
46
+ Dependencies:
47
+ -----------
48
+ - smolagents: Core framework for agent capabilities
49
+ - gradio: For the web interface
50
+ - Alpha Vantage API integration: For financial data
51
+ - Various data processing libraries: For handling and analyzing financial information
52
+
53
+ Technical Notes:
54
+ --------------
55
+ - The agent runs with a configurable number of maximum steps (default: 20)
56
+ - Planning occurs at regular intervals (default: every 4 steps)
57
+ - The agent has access to a curated list of authorized Python imports for security
58
+ - All file uploads are validated for type and size before processing
59
+
60
+ Created by the Footscray Coding Collective
61
+ Copyright 2024, All rights reserved
62
+ """
63
+ import contextlib
64
+ import datetime
65
+ import logging
66
  import mimetypes
67
  import os
68
  import re
69
  import shutil
70
+ from typing import Any, Dict, Generator, List, Optional, Tuple
71
 
72
+ # Typer for CLI functionality
73
+ import typer
74
+
75
+ # Telemetry imports (optional)
76
+ with contextlib.suppress(ImportError):
77
+ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
78
+ from phoenix.otel import register
79
+
80
+ # Initialize telemetry for observability and tracing
81
+ register()
82
+ SmolagentsInstrumentor().instrument()
83
+
84
+ # third-party
85
  import gradio as gr
86
+ import pytz
87
  from dotenv import load_dotenv
88
  from huggingface_hub import login
89
+ from rich.console import Console
90
+ from rich.logging import RichHandler
91
+ from smolagents import FinalAnswerTool # smolagents
92
+ from smolagents import (CodeAgent, GoogleSearchTool, HfApiModel, LiteLLMModel,
93
+ OpenAIServerModel, Tool, TransformersModel)
94
+ from smolagents.agent_types import AgentText
95
+ from smolagents.gradio_ui import (handle_agent_output_types,
96
+ pull_messages_from_step)
97
+
98
+ # local
99
+ from scripts.finance_tools import (DataVisualizationTool,
100
+ FinancialCalculatorTool, TrendAnalysisTool,
101
+ get_balance_sheet_data, get_cash_flow_data,
102
+ get_company_overview_data,
103
+ get_earnings_data,
104
+ get_income_statement_data,
105
+ get_market_news_sentiment,
106
+ get_stock_quote_data, get_time_series_daily,
107
+ search_symbols)
108
  from scripts.flux_lora_tool import FluxLoRATool
109
  from scripts.text_cleaner_tool import TextCleanerTool
110
  from scripts.text_inspector_tool import TextInspectorTool
111
+ from scripts.text_web_browser import (ArchiveSearchTool, DownloadTool,
112
+ FinderTool, FindNextTool, PageDownTool,
113
+ PageUpTool, SimpleTextBrowser, VisitTool)
114
+ from scripts.time_tools import get_temporal_context
 
 
 
 
 
115
  from scripts.visual_qa import visualizer
116
+
117
+ # Initialize console and app
118
+ console = Console()
119
+ app = typer.Typer(
120
+ help="Financial Research Agent - Access market data and analysis through a CLI or UI",
121
+ add_completion=False,
 
 
122
  )
 
 
123
 
124
  # ------------------------ Configuration and Setup ------------------------
125
  # Constants and configurations
126
  AUTHORIZED_IMPORTS = [
127
  "requests", # Web requests (fetching data from the internet)
128
+ "pytz", # Timezone handling
129
  "zipfile", # Working with ZIP archives
130
  "pandas", # Data manipulation and analysis (DataFrames)
131
  "numpy", # Numerical computing (arrays, linear algebra)
 
135
  "pubchempy", # Accessing PubChem chemical database
136
  "yaml",
137
  "xml", # XML processing
138
+ "yahoo_finance", # Fetching stock data
139
  "Bio", # Bioinformatics tools (e.g., sequence analysis)
140
  "sklearn", # Scikit-learn for machine learning
141
  "scipy", # Scientific computing (stats, optimization)
 
161
  "time", # Measuring time
162
  "tempfile", # Creating temporary files and directories
163
  # Data Visualization (if needed) - Consider security implications carefully
164
+ "matplotlib.plt", # Plotting library
165
  "seaborn", # Statistical data visualization (more advanced)
166
  # Web Scraping (more specific/controlled) - Consider ethical implications
167
  "lxml", # Faster XML/HTML processing (alternative to bs4)
 
172
  "schedule", # Allow the agent to schedule tasks
173
  "uuid",
174
  "base64",
175
+ "smolagents", # smolagents package to be able to create smolagents tools
176
  ]
177
 
178
  USER_AGENT = (
 
181
  )
182
  BROWSER_CONFIG = {
183
  "viewport_size": 1024 * 5,
184
+ "downloads_folder": "data/downloads_folder",
185
  "request_kwargs": {
186
  "headers": {"User-Agent": USER_AGENT},
187
  "timeout": 300,
 
191
 
192
  CUSTOM_ROLE_CONVERSIONS = {"tool-call": "assistant", "tool-response": "user"}
193
 
 
194
  ALLOWED_FILE_TYPES = [
195
  "application/pdf",
196
  "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
 
208
  ]
209
 
210
 
211
+ # Set up logging configuration
212
+ def setup_logging() -> Tuple[str, logging.Logger]:
213
+ """
214
+ Configure logging with structured output and file storage.
215
+
216
+ The function creates logs directory and timestamped log filename, sets up
217
+ logging with Rich integration and creates and returns logger.
218
+
219
+ Returns:
220
+ Tuple containing the log file path and configured logger
221
+ """
222
+ # Create logs directory
223
+ current_dir = os.path.dirname(os.path.abspath(__file__))
224
+ logs_dir = os.path.join(current_dir, "logs")
225
+ os.makedirs(logs_dir, exist_ok=True)
226
+
227
+ # Generate timestamped log filename
228
+ melbourne_timezone = pytz.timezone("Australia/Melbourne")
229
+ log_filename = f'smolagents_{datetime.datetime.now(melbourne_timezone).strftime("%Y%m%d_%H%M%S")}.log'
230
+ log_file = os.path.join(logs_dir, log_filename)
231
+
232
+ # Set up logging with Rich integration
233
+ logging.basicConfig(
234
+ level=logging.INFO,
235
+ format="%(asctime)s [%(levelname)s] - %(message)s",
236
+ datefmt="%Y-%m-%d %H:%M:%S",
237
+ handlers=[
238
+ RichHandler(rich_tracebacks=True, show_time=True),
239
+ logging.FileHandler(log_file),
240
+ ],
241
+ )
242
+
243
+ # Create and return logger
244
+ logger = logging.getLogger(__name__)
245
+ return log_file, logger
246
+
247
+
248
+ LOG_FILE, logger = setup_logging()
249
+
250
+
251
+ def setup_environment() -> None:
252
+ """Initialize environment variables and authentication.
253
+
254
+ This function ensures that required environment variables are set and
255
+ attempts to authenticate with Hugging Face and Alpha Vantage services.
256
+ """
257
  load_dotenv(override=True)
258
+
259
+ # Check Hugging Face token
260
  if os.getenv("HF_TOKEN"): # Check if token is actually set
261
  login(os.getenv("HF_TOKEN"))
262
+ console.print("HF_TOKEN loaded successfully")
263
  else:
264
+ console.print(
265
+ "[yellow]HF_TOKEN not found in environment variables. "
266
+ "Some features may not work properly.[/yellow]"
267
+ )
268
+
269
+ # Check Alpha Vantage API key
270
+ try:
271
+ # Ensure Alpha Vantage API key is available
272
+ api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
273
+ if not api_key:
274
+ console.print(
275
+ "[yellow]⚠️ Warning: ALPHA_VANTAGE_API_KEY not found. "
276
+ "Finance tools may not work properly.[/yellow]"
277
+ )
278
+ else:
279
+ console.print("[green]✓ ALPHA_VANTAGE_API_KEY loaded successfully[/green]")
280
+ except Exception as e:
281
+ console.print(f"[red]Error checking ALPHA_VANTAGE_API_KEY: {e}[/red]")
282
 
283
 
284
  # ------------------------ Model and Tool Management ------------------------
285
  class ModelManager:
286
+ """Manages model loading and initialization.
287
+
288
+ This class provides a static method to load the specified model with the
289
+ appropriate configuration. It supports the following inference types:
290
+ - hf_api: Use the Hugging Face API to load the model.
291
+ - hf_api_provider: Use the Hugging Face API to load the model with the
292
+ 'together' provider.
293
+ - litellm: Load the LiteLLM model with the specified model ID.
294
+ - openai: Load the OpenAI model with the specified model ID and API key.
295
+ - transformers: Load the Hugging Face transformers model with the
296
+ specified model ID and configuration.
297
+ """
298
 
299
  @staticmethod
300
  def load_model(chosen_inference: str, model_id: str, key_manager=None):
301
+ """Load the specified model with appropriate configuration.
302
+
303
+ Args:
304
+ chosen_inference (str): The inference type to use.
305
+ model_id (str): The model ID to load.
306
+ key_manager (Optional[KeyManager]): The key manager to use for
307
+ loading the model. Required for OpenAI models.
308
+
309
+ Raises:
310
+ ValueError: If the chosen inference type is invalid.
311
+ Exception: If an error occurs while loading the model.
312
+ """
313
  try:
314
  if chosen_inference == "hf_api":
315
  return HfApiModel(model_id=model_id)
 
328
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
329
  )
330
 
331
+ if chosen_inference == "transformers":
332
  return TransformersModel(
333
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
334
  device_map="auto",
 
339
  raise ValueError(f"Invalid inference type: {chosen_inference}")
340
 
341
  except Exception as e:
342
+ console.print(f"[red]✗ Couldn't load model: {e}[/red]")
343
  raise
344
 
345
 
346
+ # ------------------------ Tool Registration ------------------------
347
  class ToolRegistry:
348
+ """Manages tool initialization and organization using Zhou Protocol priorities."""
349
 
350
  @staticmethod
351
+ def load_information_tools(model, text_limit=30000):
352
+ """
353
+ Initialize and return information analysis tools.
354
+
355
+ This method creates tools for analyzing text from documents, and other sources.
356
+ The information tools should be prioritized first in the agent's toolset.
357
+
358
+ Args:
359
+ model: Language model to use for analysis
360
+ text_limit: Maximum character length for text summaries
361
+
362
+ Returns:
363
+ List of information analysis tools
364
+ """
365
  return [
 
 
 
 
 
 
 
366
  TextInspectorTool(model, text_limit),
367
  ]
368
 
369
  @staticmethod
370
+ def load_utility_tools():
371
  """
372
+ Initialize and return utility tools for text cleaning and normalization.
373
+
374
  Returns:
375
+ List of utility tools
376
  """
377
  return [
378
  TextCleanerTool(),
379
  ]
380
 
381
+ @staticmethod
382
+ def load_time_tools():
383
+ """
384
+ Initialize and return time-related tools.
385
+
386
+ Returns:
387
+ List of time-related tools
388
+ """
389
+ return [get_temporal_context]
390
+
391
+ @staticmethod
392
+ def load_finance_tools():
393
+ """
394
+ Initialize and return financial analysis tools.
395
+
396
+ Returns:
397
+ List of financial tools in priority order
398
+ """
399
+ return [
400
+ # Analysis tools first (higher priority)
401
+ DataVisualizationTool(),
402
+ FinancialCalculatorTool(),
403
+ TrendAnalysisTool(),
404
+ # Data retrieval tools next
405
+ search_symbols,
406
+ get_stock_quote_data,
407
+ get_company_overview_data,
408
+ get_earnings_data,
409
+ get_income_statement_data,
410
+ get_balance_sheet_data,
411
+ get_cash_flow_data,
412
+ get_time_series_daily,
413
+ get_market_news_sentiment,
414
+ ]
415
+
416
+ @staticmethod
417
+ def load_web_tools(browser, text_limit=20000):
418
+ """
419
+ Initialize and return web interaction tools.
420
+
421
+ Args:
422
+ browser: Browser instance for web navigation
423
+ text_limit: Maximum character length for text processing
424
+
425
+ Returns:
426
+ List of web tools in priority order
427
+ """
428
+ return [
429
+ # Search tools first
430
+ GoogleSearchTool(provider="serper"),
431
+ # Navigation tools next
432
+ VisitTool(browser),
433
+ DownloadTool(browser),
434
+ # Page interaction tools last
435
+ PageUpTool(browser),
436
+ PageDownTool(browser),
437
+ FinderTool(browser),
438
+ FindNextTool(browser),
439
+ ArchiveSearchTool(browser),
440
+ ]
441
+
442
  @staticmethod
443
  def load_image_generation_tools():
444
+ """
445
+ Initialize and return image generation tools.
446
+
447
+ Returns:
448
+ Image generation tool or fallback
449
+ """
450
  try:
451
  return Tool.from_space(
452
  space_id="xkerser/FLUX.1-dev",
 
454
  description="Generates high-quality AgentImage using the FLUX.1-dev model based on text prompts.",
455
  )
456
  except Exception as e:
457
+ console.print(
458
+ f"[yellow]✗ Couldn't initialize image generation tool: {e}[/yellow]"
459
+ )
460
+ return FluxLoRATool()
461
+
462
+ @staticmethod
463
+ def load_final_answer_tool():
464
+ """
465
+ Return the final answer tool for providing conclusive responses.
466
+
467
+ Returns:
468
+ List containing the final answer tool
469
+ """
470
+ return [FinalAnswerTool()]
471
 
472
 
473
+ def create_agent(model_id: str = "openrouter/google/gemini-2.0-flash-001"):
 
474
  """
475
+ Create a fresh agent instance with properly configured tools.
476
+
477
+ This function creates a CodeAgent with tools organized by the Zhou Protocol
478
+ priority system, ensuring the most relevant tools are considered first.
479
+
480
+ Args:
481
+ model_id: The ID of the model to use for the agent
482
+
483
  Returns:
484
+ A configured CodeAgent instance
485
+
486
  Raises:
 
487
  RuntimeError: If agent creation fails
488
  """
489
  try:
490
+ # Initialize model with fallback system
491
+ model = _load_model_with_fallback(model_id)
 
 
 
492
 
493
  # Initialize tools
494
  text_limit = 30000
495
  browser = SimpleTextBrowser(**BROWSER_CONFIG)
496
 
497
+ # Collect all tools with proper Zhou Protocol prioritization
498
+ information_tools = ToolRegistry.load_information_tools(model, text_limit)
499
+ utility_tools = ToolRegistry.load_utility_tools()
500
+ finance_tools = ToolRegistry.load_finance_tools()
501
+ web_tools = ToolRegistry.load_web_tools(browser)
502
+ time_tools = ToolRegistry.load_time_tools()
503
  image_generator = ToolRegistry.load_image_generation_tools()
504
+ final_answer = ToolRegistry.load_final_answer_tool()
505
+
506
+ # Combine all tools with information tools prioritized first
507
+ all_tools = (
508
+ information_tools # Critical information extraction (highest priority)
509
+ + utility_tools # General utility functions
510
+ + finance_tools # Financial analysis capabilities
511
+ + web_tools # Web search and navigation
512
+ + time_tools # Time context tools
513
+ + [visualizer] # Image analysis
514
+ + [image_generator] # Image generation
515
+ + final_answer # Task completion (always last)
516
+ )
517
 
518
  # Validate tools before creating agent
519
+ _validate_tools(all_tools)
 
 
 
 
 
520
 
521
  return CodeAgent(
522
  model=model,
523
  tools=all_tools,
524
+ max_steps=20,
525
  verbosity_level=2,
526
  additional_authorized_imports=AUTHORIZED_IMPORTS,
527
+ planning_interval=4,
528
+ description="""
529
+ This agent assists with comprehensive research and financial analysis. It first analyzes
530
+ any provided documents or text, then leverages specialized financial tools and web search
531
+ capabilities to provide thorough insights.
532
+
533
+ QUERY COMPREHENSION FRAMEWORK
534
+ Before answering any complex question, apply the Zhou Comprehension Pattern:
535
+ 1. **Initial Parse**: What is literally being asked?
536
+ 2. **Intent Detection**: What is the user actually trying to accomplish?
537
+ 3. **Knowledge Assessment**: What information is needed to address this properly?
538
+ 4. **Tool Selection**: Which tools provide the most direct path to a solution?
539
+ 5. **Execution Planning**: What sequence of operations will yield the best result?
540
+
541
+ CLARIFICATION CHECKLIST
542
+ When faced with ambiguous queries, the agent should systematically clarify:
543
+ * **Scope**: "How comprehensive should this analysis be?"
544
+ * **Format**: "What form would you like the results in?"
545
+ * **Technical Level**: "Should I explain technical details or focus on practical applications?"
546
+ * **Time Horizon**: "Are you interested in historical data, current status, or future projections?"
547
+ * **Priority**: "Which aspect of this question is most important to you?"
548
+ """.strip(),
549
  )
550
+ except Exception as e:
551
+ console.print(f"[red]✗ Agent creation failed: {e}[/red]")
552
  raise RuntimeError(f"Agent creation failed: {e}")
553
 
554
 
555
def _load_model_with_fallback(model_id: str) -> Any:
    """
    Attempt to load the specified model, falling back to known-good models.

    The requested model is tried first, followed by a static chain ordered
    from most capable to most reliable, ending with a small local model.

    Args:
        model_id: Primary model ID to try loading

    Returns:
        Loaded ``LiteLLMModel`` instance for the first model that loads

    Raises:
        RuntimeError: If all model loading attempts fail
    """
    # Fallback model chain from most capable to most reliable
    fallback_models = [
        model_id,  # Try the requested model first
        "openrouter/anthropic/claude-3.7-sonnet",
        "openai/gpt-4o-mini",
        "anthropic/claude-3.7-sonnet",
        "HuggingFaceTB/SmolLM2-1.7B-Instruct",  # Last resort local option
    ]
    # De-duplicate while preserving order: if the requested model is already in
    # the static chain, don't waste a second (identical) load attempt on it.
    candidates = list(dict.fromkeys(fallback_models))

    last_error = None
    for candidate in candidates:
        try:
            return LiteLLMModel(
                custom_role_conversions=CUSTOM_ROLE_CONVERSIONS,
                model_id=candidate,
            )
        except Exception as e:
            last_error = e
            console.print(f"[yellow]Failed to load model {candidate}: {e}[/yellow]")

    # If we get here, all models failed
    raise RuntimeError(f"All model loading attempts failed. Last error: {last_error}")
590
+
591
+
592
+ def _validate_tools(tools):
593
+ """
594
+ Validate that all tools are proper Tool instances.
595
+
596
+ Args:
597
+ tools: List of tools to validate
598
+
599
+ Raises:
600
+ ValueError: If any tool is not a Tool instance
601
+ """
602
+ for tool in tools:
603
+ if not isinstance(tool, Tool):
604
+ raise ValueError(
605
+ f"Invalid tool type: {type(tool)}. "
606
+ f"All tools must be instances of Tool class."
607
+ )
608
+
609
+
610
+ # ------------------------ Gradio UI Components ------------------------
611
+
612
+
613
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Stream agent responses to the chat, decorating steps with status icons."""
    try:
        # Immediate placeholder so the user sees activity while the agent starts.
        yield gr.ChatMessage(role="assistant", content="⏳ Processing your request...")

        # Tracks whether we have emitted a real message yet, so the very first
        # one can scrub the placeholder text if it leaked into its content.
        seen_first = False

        for log_entry in agent.run(
            task, stream=True, reset=reset_agent_memory, additional_args=additional_args
        ):
            # pull_messages_from_step is itself a generator: walk every chat
            # message it produces for this step.
            for chat_msg in pull_messages_from_step(log_entry):
                if not seen_first:
                    seen_first = True
                    chat_msg.content = chat_msg.content.replace(
                        "⏳ Processing your request...", ""
                    )

                # Classify the operation from the message text itself (there is
                # no 'status' attribute on these messages to inspect).
                lowered = (
                    chat_msg.content.lower() if hasattr(chat_msg, "content") else ""
                )

                if "document analysis" in lowered:
                    chat_msg.content = f"📄 **Document Analysis:** {chat_msg.content}"
                elif "search" in lowered:
                    chat_msg.content = f"🔍 **Search:** {chat_msg.content}"

                yield chat_msg

        # After the loop, log_entry holds the final step; convert it into the
        # agent's final answer with enhanced formatting.
        final_answer = handle_agent_output_types(log_entry)

        if isinstance(final_answer, AgentText):
            yield gr.ChatMessage(
                role="assistant",
                content=f"✅ **Final Answer:**\n\n{final_answer.to_string()}",
            )
        else:
            yield gr.ChatMessage(
                role="assistant", content=f"✅ **Final Answer:** {str(final_answer)}"
            )

    except Exception as e:
        yield gr.ChatMessage(
            role="assistant",
            content=f" **Error:** {str(e)}\n\nPlease try again with a different query.",
        )
671
 
672
 
 
682
  if not os.path.exists(file_upload_folder):
683
  os.mkdir(file_upload_folder)
684
 
685
+ def interact_with_agent(
686
+ self,
687
+ prompt: str,
688
+ messages: List[gr.ChatMessage],
689
+ session_state: Dict[str, Any],
690
+ ) -> Generator[List[gr.ChatMessage], None, None]:
691
+ """Main interaction handler with the agent.
692
+
693
+ Args:
694
+ prompt: The user's input prompt
695
+ messages: The list of messages so far (including the user's prompt)
696
+ session_state: The current state of the user's session
697
+
698
+ Yields:
699
+ A list of messages after each step (including the user's prompt)
700
+ """
701
 
702
  # Get or create session-specific agent
703
  if "agent" not in session_state:
704
+ model_id = session_state.get(
705
+ "model_id", "openrouter/google/gemini-2.0-flash-001"
706
+ )
707
+ session_state["agent"] = create_agent(model_id)
708
 
709
  # Adding monitoring
710
  try:
711
  # Log the existence of agent memory
712
  has_memory = hasattr(session_state["agent"], "memory")
713
+ console.print(f"Agent has memory: {has_memory}")
714
  if has_memory:
715
+ console.print(f"Memory type: {type(session_state['agent'].memory)}")
716
 
717
  messages.append(gr.ChatMessage(role="user", content=prompt))
718
  yield messages
 
725
  yield messages # Yield messages one last time
726
 
727
  except Exception as e:
728
+ console.print(f"[red]Error in interaction: {str(e)}[/red]")
729
  raise
730
 
731
  def upload_file(
 
834
  @gr.render()
835
  def layout(request: gr.Request):
836
  device = self.detect_device(request)
837
+ console.print(f"device - {device}")
838
  # Render layout with sidebar
839
  if device == "Desktop":
840
  return self._create_desktop_layout()
 
850
  with gr.Sidebar():
851
  gr.Markdown(
852
  """#OpenDeepResearch - 3theSmolagents!
853
+ Model_id: deepseek/deepseek-r1"""
854
  )
855
  with gr.Group():
856
  gr.Markdown("**What's on your mind mate?**", container=True)
 
1021
  )
1022
 
1023
 
1024
+ # ------------------------ CLI Command ------------------------
1025
@app.command()
def run(
    mode: str = typer.Option(
        "ui",
        "--mode",
        "-m",
        help="Operating mode: 'ui' for Gradio interface or 'headless' for CLI mode",
    ),
    model_id: str = typer.Option(
        "openrouter/google/gemini-2.0-flash-001",
        "--model",
        help="Model ID to use for the agent",
    ),
    query: Optional[str] = typer.Option(
        None, "--query", "-q", help="Query to execute (required in headless mode)"
    ),
):
    """
    Run the financial research agent in either UI or headless mode.

    In UI mode, launches a Gradio interface for interactive use.
    In headless mode, processes a single query and outputs the result to the console.
    """
    # Fail fast on an unknown mode BEFORE any environment setup or status
    # output, so the user gets an immediate, actionable error.
    if mode not in ("ui", "headless"):
        console.print(
            f"[red]Error: Invalid mode '{mode}'. Use 'ui' or 'headless'[/red]"
        )
        raise typer.Exit(code=1)

    # Setup environment variables (API keys etc.)
    setup_environment()

    # Validate inputs for headless mode: a query is mandatory there.
    if mode == "headless" and not query:
        console.print("[red]Error: query parameter is required in headless mode[/red]")
        raise typer.Exit(code=1)

    console.print(f"[bold]Initializing agent with model:[/bold] {model_id}")

    if mode == "ui":
        console.print(
            "[bold green]Starting UI mode with Gradio interface...[/bold green]"
        )

        # Ensure downloads folder exists before any browser tool needs it.
        os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)

        # Launch UI (blocks until the server is stopped).
        GradioUI(file_upload_folder="data/uploaded_files").launch()

    else:  # mode == "headless" (guaranteed by the guard above)
        console.print(f"[bold]Processing query in headless mode:[/bold] {query}")

        # Create agent for headless mode with the requested model.
        agent = create_agent(model_id)

        # Show a simple spinner during processing.
        with console.status("[bold green]Processing query...[/bold green]"):
            result = agent.run(query)

        # Display the results.
        console.print("\n[bold green]Results:[/bold green]")
        console.print(result)
1090
 
1091
 
1092
+ # ------------------------ Main Entry Point ------------------------
1093
  if __name__ == "__main__":
1094
+ # Use the typer app as the entry point
1095
+ app()
flux_image.py DELETED
File without changes
requirements.txt CHANGED
@@ -1,13 +1,9 @@
 
1
  anthropic>=0.37.1
2
  beautifulsoup4>=4.12.3
3
- Bio
4
- chess
5
- clean-text[gpl]
6
  datasets>=2.21.0
7
  google_search_results>=2.4.2
8
  huggingface_hub>=0.23.4
9
- llama-index
10
- llama-index-embeddings-huggingface
11
  mammoth>=1.8.0
12
  markdownify>=0.13.1
13
  numexpr>=2.10.1
@@ -19,25 +15,26 @@ pathvalidate>=3.2.1
19
  pdfminer>=20191125
20
  pdfminer.six>=20240706
21
  Pillow>=11.0.0
22
- pubchempy
23
  puremagic>=1.28
24
- pydub
25
- PyPDF2
26
  python-dotenv>=1.0.1
27
  python_pptx>=1.0.2
28
- python-pptx
29
  Requests>=2.32.3
30
- scikit-learn
31
- scikit-learn
32
- scipy
33
  serpapi>=0.1.5
34
- smolagents[gradio, langchain, litellm, telemetry]
35
- SpeechRecognition
36
- sympy
37
  torch>=2.2.2
38
  torchvision>=0.17.2
39
- tqdm>=4.66.4
40
- tqdm
41
  transformers>=4.46.0
42
- xlrd
43
  youtube_transcript_api>=0.6.2
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ smolagents[litellm, telemetry]
2
  anthropic>=0.37.1
3
  beautifulsoup4>=4.12.3
 
 
 
4
  datasets>=2.21.0
5
  google_search_results>=2.4.2
6
  huggingface_hub>=0.23.4
 
 
7
  mammoth>=1.8.0
8
  markdownify>=0.13.1
9
  numexpr>=2.10.1
 
15
  pdfminer>=20191125
16
  pdfminer.six>=20240706
17
  Pillow>=11.0.0
 
18
  puremagic>=1.28
19
+ pypdf>=5.1.0
 
20
  python-dotenv>=1.0.1
21
  python_pptx>=1.0.2
 
22
  Requests>=2.32.3
 
 
 
23
  serpapi>=0.1.5
24
+ tqdm>=4.66.4
 
 
25
  torch>=2.2.2
26
  torchvision>=0.17.2
 
 
27
  transformers>=4.46.0
 
28
  youtube_transcript_api>=0.6.2
29
+ chess
30
+ sympy
31
+ pubchempy
32
+ Bio
33
+ scikit-learn
34
+ scipy
35
+ pydub
36
+ PyPDF2
37
+ python-pptx
38
+ torch
39
+ xlrd
40
+ SpeechRecognition
scripts/cookies.py CHANGED
@@ -1,6 +1,5 @@
1
  from requests.cookies import RequestsCookieJar
2
 
3
-
4
  COOKIES_LIST = [
5
  {
6
  "domain": ".youtube.com",
@@ -712,4 +711,6 @@ COOKIES = RequestsCookieJar()
712
 
713
  # Add cookies to the jar
714
  for cookie in COOKIES_LIST:
715
- COOKIES.set(cookie["name"], cookie["value"], domain=cookie["domain"], path=cookie["path"])
 
 
 
1
  from requests.cookies import RequestsCookieJar
2
 
 
3
  COOKIES_LIST = [
4
  {
5
  "domain": ".youtube.com",
 
711
 
712
# Load every predefined cookie into the shared jar.
for entry in COOKIES_LIST:
    COOKIES.set(
        entry["name"],
        entry["value"],
        domain=entry["domain"],
        path=entry["path"],
    )
scripts/finance_tools.py ADDED
@@ -0,0 +1,987 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The Footscray Coding Collective. All rights reserved.
4
+ """
5
+ Financial Data and Analysis Tools
6
+ --------------------------------------
7
+ A comprehensive suite of tools for retrieving financial market data through the Alpha Vantage API.
8
+ These tools enable accessing real-time stock quotes, company fundamentals, financial statements,
9
+ price history, market news, and sentiment analysis with proper error handling and caching.
10
+
11
+ The Alpha Vantage tools follow the Zhou Protocol for financial data retrieval:
12
+ - Singleton pattern for API client management
13
+ - Comprehensive error handling with failed request tracking
14
+ - In-memory request caching to minimize API usage
15
+ - Detailed docstrings with usage examples
16
+
17
+ Key Financial Tools:
18
+ - search_symbols: Find ticker symbols for companies by keywords
19
+ - get_stock_quote_data: Real-time stock quote information
20
+ - get_company_overview_data: Company profiles and fundamentals
21
+ - get_earnings_data: Quarterly and annual earnings information
22
+ - get_income_statement_data: Income statement analysis
23
+ - get_balance_sheet_data: Balance sheet information
24
+ - get_cash_flow_data: Cash flow statement analysis
25
+ - get_time_series_daily: Historical price and volume data
26
+ - get_market_news_sentiment: News and sentiment analysis
27
+
28
+ Financial Analysis Tools:
29
+ - FinancialCalculatorTool: Calculate financial metrics (growth rates, margins, CAGR)
30
+ - DataVisualizationTool: Generate visual representations of financial data
31
+ - TrendAnalysisTool: Perform year-over-year trend analysis on financial metrics
32
+ """
33
+
34
+ import io
35
+ import logging
36
+ import os
37
+ import traceback
38
+ from typing import Any, Dict, Optional, Set
39
+
40
+ # Third-party imports in alphabetical order with dotenv first
41
+ try:
42
+ from dotenv import load_dotenv
43
+
44
+ load_dotenv()
45
+ except ImportError:
46
+ pass
47
+
48
+ import matplotlib.pyplot as plt # Plot the chart
49
+ import pandas as pd # Store dataframe
50
+ import requests
51
+ from smolagents import Tool, tool
52
+
53
+
54
class AlphaVantageClient:
    """Centralized client for Alpha Vantage API requests with caching and error handling."""

    def __init__(self):
        """Initialize the client with empty caches."""
        self._api_key: Optional[str] = None  # lazily-loaded API key
        self._failed_requests: Set[str] = set()  # cache keys of failed requests
        self._data_cache: Dict[str, Dict[str, Any]] = {}  # successful responses

    def get_api_key(self) -> str:
        """
        Get Alpha Vantage API key from environment or cache.

        Returns:
            API key string, or a message starting with "Error:" if unset
        """
        if self._api_key:
            return self._api_key

        api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
        if not api_key:
            return "Error: No API key found. Set ALPHA_VANTAGE_API_KEY in your environment."

        self._api_key = api_key
        return api_key

    @staticmethod
    def _cache_key_matches(
        key: str, function: Optional[str], symbol: Optional[str]
    ) -> bool:
        """Return True when a ``function:symbol:hash`` cache key passes the filters."""
        parts = key.split(":")
        if function and parts[0] != function:
            return False
        if symbol and parts[1] != symbol:
            return False
        return True

    def make_request(self, function: str, symbol: str, **params: Any) -> Dict[str, Any]:
        """
        Make a request to Alpha Vantage API with error handling and caching.

        Args:
            function (str): API function name
            symbol (str): Stock symbol
            **params (Any): Additional parameters for the request, excluding
                'function' and 'symbol'

        Returns:
            Dict[str, Any]: Raw JSON response data
        """
        # Guard against callers duplicating the positional arguments.
        if "function" in params or "symbol" in params:
            raise ValueError("function and symbol should not be included in params")

        # One key per (function, symbol, extra-params) combination.
        cache_key = f"{function}:{symbol}:{hash(frozenset(params.items()))}"

        # Return cached data if available.
        if cache_key in self._data_cache:
            return self._data_cache[cache_key]

        # Short-circuit requests that already failed this session.
        if cache_key in self._failed_requests:
            return {
                "Error": f"Previously failed request for {symbol} with function {function}"
            }

        api_key = self.get_api_key()
        if api_key.startswith("Error:"):
            return {"Error Message": api_key}

        url = "https://www.alphavantage.co/query"
        request_params = {
            "function": function,
            "symbol": symbol,
            "apikey": api_key,
            **params,
        }

        try:
            # Timeout keeps the caller responsive when the API hangs.
            response = requests.get(url, params=request_params, timeout=10)
            response.raise_for_status()
            data = response.json()

            # Alpha Vantage reports errors/rate-limits inside the JSON body.
            if "Error Message" in data or "Information" in data or not data:
                self._failed_requests.add(cache_key)
                return data

            # Cache successful response.
            self._data_cache[cache_key] = data
            return data

        except requests.RequestException as e:
            self._failed_requests.add(cache_key)
            return {"Error Message": f"API request failed: {str(e)}"}
        except ValueError as e:
            self._failed_requests.add(cache_key)
            return {"Error Message": f"Failed to parse response: {str(e)}"}

    def clear_cache(
        self, function: Optional[str] = None, symbol: Optional[str] = None
    ) -> None:
        """
        Clear cached data, optionally filtering by function and/or symbol.

        Fix: matching entries are also removed from the failed-request
        record. Previously only the data cache was cleared, so requests
        that had failed once were permanently short-circuited and a
        "cleared" cache could never actually force a retry.

        Args:
            function: Optional function name to filter cache entries
            symbol: Optional symbol to filter cache entries
        """
        if not function and not symbol:
            self._data_cache.clear()
            self._failed_requests.clear()
            return

        for key in [
            k for k in self._data_cache if self._cache_key_matches(k, function, symbol)
        ]:
            del self._data_cache[key]

        self._failed_requests = {
            k
            for k in self._failed_requests
            if not self._cache_key_matches(k, function, symbol)
        }
172
+
173
+
174
+ # Create a singleton instance of the client
175
+ _client = AlphaVantageClient()
176
+
177
+
178
@tool
def get_stock_quote_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw real-time stock quote information from Alpha Vantage.

    This tool fetches current market data for a specified stock ticker,
    returning the raw data for custom processing and analysis.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - Global Quote object with price, volume, and trading information
        - Error information if the request failed

    Example:
        ```python
        # Get raw quote data
        data = get_stock_quote_data("MSFT")

        # Extract price
        if "Global Quote" in data:
            quote = data["Global Quote"]
            price = float(quote.get("05. price", 0))
            change = float(quote.get("09. change", 0))
            print(f"MSFT: ${price:.2f} ({change:+.2f})")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("GLOBAL_QUOTE", symbol)
208
+
209
+
210
@tool
def get_company_overview_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw company information and metrics from Alpha Vantage.

    This tool provides comprehensive information about a company, returning
    raw data for custom analysis and presentation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - Company profile (name, sector, industry)
        - Financial metrics (market cap, P/E ratio, etc.)
        - Performance indicators (ROE, ROA, etc.)
        - Company description
        - Error information if the request failed

    Example:
        ```python
        # Get company data
        data = get_company_overview_data("AAPL")

        # Create custom analysis
        if "Sector" in data:
            sector = data.get("Sector")
            market_cap = float(data.get("MarketCapitalization", 0))
            pe_ratio = float(data.get("PERatio", 0))

            print(f"AAPL is in the {sector} sector")
            print(f"Market Cap: ${market_cap/1e9:.2f}B")
            print(f"P/E Ratio: {pe_ratio:.2f}")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("OVERVIEW", symbol)
246
+
247
+
248
@tool
def get_earnings_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw earnings data for a company from Alpha Vantage.

    This tool fetches quarterly and annual earnings data, returning
    raw information for custom analysis and trend evaluation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - quarterlyEarnings array with fiscal dates, reported EPS, and surprises
        - annualEarnings array with yearly EPS figures
        - Error information if the request failed

    Example:
        ```python
        # Get earnings data
        data = get_earnings_data("MSFT")

        # Analyze earnings surprises
        if "quarterlyEarnings" in data:
            quarterly = data["quarterlyEarnings"]

            # Calculate average earnings surprise percentage
            surprises = [float(q.get("surprisePercentage", 0)) for q in quarterly[:4]]
            avg_surprise = sum(surprises) / len(surprises)

            print(f"Average earnings surprise (last 4Q): {avg_surprise:.2f}%")

            # Find biggest positive surprise
            max_surprise = max(surprises)
            print(f"Largest positive surprise: {max_surprise:.2f}%")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("EARNINGS", symbol)
286
+
287
+
288
@tool
def get_income_statement_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw income statement data for a company from Alpha Vantage.

    This tool fetches annual and quarterly income statements, returning
    raw financial data for custom analysis and profit trend evaluation.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly income statements
        - quarterlyReports array with quarterly income statements
        - Error information if the request failed

    Example:
        ```python
        # Get income statement data
        data = get_income_statement_data("AAPL")

        # Analyze profitability trends
        if "annualReports" in data and len(data["annualReports"]) >= 3:
            reports = data["annualReports"][:3]  # Last 3 years

            # Extract revenue and profit
            revenues = [float(r.get("totalRevenue", 0)) for r in reports]
            net_incomes = [float(r.get("netIncome", 0)) for r in reports]

            # Calculate profit margins
            margins = [ni/rev*100 if rev else 0 for ni, rev in zip(net_incomes, revenues)]

            for i, margin in enumerate(margins):
                year = reports[i].get("fiscalDateEnding", "Unknown")
                print(f"{year}: Profit margin = {margin:.2f}%")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("INCOME_STATEMENT", symbol)
327
+
328
+
329
@tool
def get_balance_sheet_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw balance sheet data for a company from Alpha Vantage.

    This tool fetches annual and quarterly balance sheets, returning
    raw financial data for custom analysis of a company's financial position.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly balance sheets
        - quarterlyReports array with quarterly balance sheets
        - Error information if the request failed

    Example:
        ```python
        # Get balance sheet data
        data = get_balance_sheet_data("MSFT")

        # Calculate debt-to-equity ratio
        if "annualReports" in data and data["annualReports"]:
            latest = data["annualReports"][0]

            total_debt = float(latest.get("shortTermDebt", 0)) + float(latest.get("longTermDebt", 0))
            equity = float(latest.get("totalShareholderEquity", 0))

            if equity:
                debt_to_equity = total_debt / equity
                print(f"Debt-to-Equity Ratio: {debt_to_equity:.2f}")

            # Calculate current ratio
            current_assets = float(latest.get("totalCurrentAssets", 0))
            current_liabilities = float(latest.get("totalCurrentLiabilities", 0))

            if current_liabilities:
                current_ratio = current_assets / current_liabilities
                print(f"Current Ratio: {current_ratio:.2f}")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("BALANCE_SHEET", symbol)
372
+
373
+
374
@tool
def get_cash_flow_data(symbol: str) -> Dict[str, Any]:
    """
    Retrieve raw cash flow statement data for a company from Alpha Vantage.

    This tool fetches annual and quarterly cash flow statements, returning
    raw financial data for analyzing a company's cash generation and usage.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')

    Returns:
        Raw JSON data containing:
        - annualReports array with yearly cash flow statements
        - quarterlyReports array with quarterly cash flow statements
        - Error information if the request failed

    Example:
        ```python
        # Get cash flow data
        data = get_cash_flow_data("AMZN")

        # Analyze free cash flow
        if "annualReports" in data and data["annualReports"]:
            reports = data["annualReports"][:3]  # Last 3 years

            for report in reports:
                year = report.get("fiscalDateEnding", "Unknown")
                operating_cf = float(report.get("operatingCashflow", 0))
                capex = float(report.get("capitalExpenditures", 0))

                # Free cash flow = Operating cash flow - Capital expenditures
                free_cf = operating_cf - abs(capex)

                print(f"{year}: Free Cash Flow = ${free_cf/1e9:.2f}B")
        ```
    """
    # Delegate to the shared singleton client (caching + failure tracking).
    return _client.make_request("CASH_FLOW", symbol)
412
+
413
+
414
@tool
def get_time_series_daily(symbol: str, outputsize: str = "compact") -> Dict[str, Any]:
    """
    Retrieve daily time series stock price data from Alpha Vantage.

    This tool fetches historical daily OHLCV (Open, High, Low, Close, Volume) data
    for specified ticker symbols, supporting both compact (100 data points) and
    full (20+ years) history.

    Args:
        symbol: The stock ticker symbol (e.g., 'AAPL', 'MSFT', 'IBM')
        outputsize: Data size, either 'compact' (last 100 points) or 'full' (20+ years)

    Returns:
        Raw JSON data containing:
        - "Meta Data" object with information about the data series
        - "Time Series (Daily)" object with date-keyed OHLCV data points
        - Error information if the request failed

    Example:
        ```python
        # Get daily prices (compact = last 100 days)
        data = get_time_series_daily("TSLA")

        # Calculate moving averages
        if "Time Series (Daily)" in data:
            time_series = data["Time Series (Daily)"]
            dates = sorted(time_series.keys())

            # Extract closing prices
            prices = [float(time_series[date]["4. close"]) for date in dates]

            # Calculate 20-day moving average
            if len(prices) >= 20:
                ma_20 = sum(prices[-20:]) / 20
                print(f"20-day Moving Average: ${ma_20:.2f}")

                # Get latest price
                latest_price = prices[-1]
                print(f"Latest price: ${latest_price:.2f}")

                # Compare to moving average
                diff_pct = (latest_price / ma_20 - 1) * 100
                print(f"Price is {diff_pct:+.2f}% from 20-day MA")
        ```
    """
    # 'outputsize' is forwarded as an extra query parameter to the API.
    return _client.make_request("TIME_SERIES_DAILY", symbol, outputsize=outputsize)
461
+
462
+
463
+ # Ensure that the default value IS specified
464
+ @tool
465
+ def search_symbols(keywords: str) -> Dict[str, Any]:
466
+ """
467
+ [FINANCIAL DISCOVERY] Search for stock symbols matching the provided keywords.
468
+
469
+ WHEN TO USE: ALWAYS use this tool FIRST when you don't know the exact stock symbol for a company.
470
+
471
+ This tool helps find relevant ticker symbols when you don't know the exact symbol,
472
+ matching companies by name, description, or partial symbols.
473
+
474
+ Args:
475
+ keywords: Search term (e.g., 'microsoft', 'tech', 'MSFT')
476
+
477
+ Returns:
478
+ Raw JSON data containing:
479
+ - bestMatches array with matching companies (symbol, name, type, region)
480
+ - Error information if the request failed
481
+
482
+ Example:
483
+ ```python
484
+ # Search for companies related to "electric vehicles"
485
+ results = search_symbols("electric vehicles")
486
+
487
+ # Print matched symbols and names
488
+ if "bestMatches" in results:
489
+ matches = results["bestMatches"]
490
+
491
+ print(f"Found {len(matches)} matches:")
492
+ for match in matches:
493
+ symbol = match.get("1. symbol", "")
494
+ name = match.get("2. name", "")
495
+ market = match.get("4. region", "")
496
+
497
+ print(f"{symbol} - {name} ({market})")
498
+ ```
499
+ """
500
+ return _client.make_request("SYMBOL_SEARCH", "", keywords=keywords)
501
+
502
+
503
@tool
def clear_api_cache() -> str:
    """
    Clear all cached API data to force fresh requests.

    Returns:
        Confirmation message
    """
    # Clear BOTH the successful-response cache and the failed-request record.
    # Previously only the data cache was cleared, so requests that had failed
    # once kept short-circuiting and were never actually retried.
    _client._data_cache.clear()
    _client._failed_requests.clear()
    return "API cache cleared successfully."
513
+
514
+
515
@tool
def get_market_news_sentiment(
    tickers: Optional[str] = None,
    topics: Optional[str] = None,
    time_from: Optional[str] = None,
    time_to: Optional[str] = None,
    sort: str = "LATEST",
    limit: int = 50,
) -> Dict[str, Any]:
    """
    Retrieve market news and sentiment data from Alpha Vantage.

    This tool fetches live and historical market news with sentiment analysis from premier
    news outlets worldwide, covering stocks, cryptocurrencies, forex, and various market topics.

    Args:
        tickers: Optional comma-separated list of symbols (e.g., 'AAPL,MSFT' or 'COIN,CRYPTO:BTC,FOREX:USD')
        topics: Optional comma-separated list of news topics (e.g., 'technology,ipo')
            Available topics: blockchain, earnings, ipo, mergers_and_acquisitions, financial_markets,
            economy_fiscal, economy_monetary, economy_macro, energy_transportation, finance,
            life_sciences, manufacturing, real_estate, retail_wholesale, technology
        time_from: Optional start time in YYYYMMDDTHHMM format (e.g., '20220410T0130')
        time_to: Optional end time in YYYYMMDDTHHMM format
        sort: Sorting order - 'LATEST' (default), 'EARLIEST', or 'RELEVANCE'
        limit: Maximum number of results to return (default: 50, max: 1000)

    Returns:
        Raw JSON data containing:
        - feed: Array of news articles with title, summary, url, time_published, authors, and more
        - sentiment scores for each article (if available)
        - Error information if the request failed

    Example:
        ```python
        # Get latest news about Apple
        apple_news = get_market_news_sentiment(tickers="AAPL")

        # Get news articles at the intersection of technology and IPOs
        tech_ipo_news = get_market_news_sentiment(topics="technology,ipo")

        # Get Bitcoin news from a specific time period
        btc_news = get_market_news_sentiment(
            tickers="CRYPTO:BTC",
            time_from="20230101T0000",
            time_to="20230201T0000"
        )

        # Process the sentiment data
        if "feed" in apple_news:
            for article in apple_news["feed"]:
                title = article.get("title", "No title")
                sentiment = article.get("overall_sentiment_score", "N/A")
                print(f"Article: {title} | Sentiment: {sentiment}")
        ```
    """
    # BUG FIX: 'function' must NOT be placed into params. make_request()
    # supplies it from its positional argument and raises
    # ValueError("function and symbol should not be included in params")
    # when it also arrives via **params — the previous implementation
    # therefore failed on every single call.
    params: Dict[str, Any] = {}

    # Add optional parameters only when provided.
    if tickers:
        params["tickers"] = tickers
    if topics:
        params["topics"] = topics
    if time_from:
        params["time_from"] = time_from
    if time_to:
        params["time_to"] = time_to
    if sort:
        params["sort"] = sort
    if limit:
        params["limit"] = limit

    return _client.make_request("NEWS_SENTIMENT", "", **params)
589
+
590
+
591
+ """Example functions to be used in the tools and called by the agent"""
592
+
593
+
594
class FinancialCalculatorTool(Tool):
    """
    Performs various financial calculations, given structured data from a table.
    Useful for calculating growth rates, financial ratios, and other key metrics.
    The tool can directly perform calculations on the data for numerical answers.
    """

    name = "financial_calculator"
    description = """
    Performs various financial calculations, given structured data from a table.
    Useful for calculating growth rates, financial ratios, and other key metrics.
    The tool can directly perform calculations on the data for numerical answers.

    Input:
    - `data` (str): A string representing table data (e.g., CSV, markdown table).
    - `calculation_type` (str): The type of calculation to perform, such as 'growth_rate', 'profit_margin', 'debt_to_equity'.
    - `year1`, `year2`, `metric` (str): Parameters for "growth", e.g., "2020", "2021", "Revenue".
    - `year`, `revenue`, `netIncome` (str): Parameters for 'Profit_Margin', e.g. "2023", "10000", "1000".
    - `year`, `totalDebt`, `totalEquity` (str): Parameters for 'Debt_To_Equity', e.g. "2023", "5000", "10000".
    - `startYear`, `endYear`, `metric` (str): Parameters for "CAGR", e.g. "2020", "2025", "Revenue"

    Output:
    - `calculation_result` (str): The result of the financial calculation as a string, to two decimal points.
    This ensures the agent can understand and utilize the output effectively.
    """

    inputs = {
        "data": {
            "type": "string",
            "description": "A string representing table data. Must be in CSV format with a header row.",
        },
        "calculation_type": {
            "type": "string",
            "description": "The type of calculation to perform. Must be one of the following exactly: 'growth_rate', 'profit_margin', 'debt_to_equity', 'CAGR'.",
        },
        "year1": {
            "type": "string",
            "description": "Year 1 for growth rate calculation, as a string.",
            "nullable": True,
        },
        "metric": {
            "type": "string",
            "description": "Valid CSV Header to compare, for growth. MUST correspond to the appropriate header in dataset.",
            "nullable": True,
        },
        "year2": {
            "type": "string",
            "description": "Year 2 for growth rate calculation, as a string. Make sure that is a valid CSV Header.",
            "nullable": True,
        },
        "revenue": {
            "type": "string",
            "description": "Revenue for the fiscal year profit calculation (as a string).",
            "nullable": True,
        },
        "netIncome": {
            "type": "string",
            "description": "Valid net income for the fiscal year profit margin calculation, in string format",
            "nullable": True,
        },
        "endYear": {
            "type": "string",
            "description": "Year 2 string for the CAGR function",
            "nullable": True,
        },
        "year": {
            "type": "string",
            "description": "Valid Year",
            "nullable": True,
        },
        "startYear": {
            "type": "string",
            "description": "Year 1, string for the CAGR function",
            "nullable": True,
        },
        "totalAssets": {
            "type": "string",
            "description": "The Total assets data in string format",
            "nullable": True,
        },
        "totalDebt": {
            "type": "string",
            "description": "The total debt data in string.",
            "nullable": True,
        },
        "totalEquity": {
            "type": "string",
            "description": "The Total Shareholders Equity in string format",
            "nullable": True,
        },
    }
    output_type = "string"

    @staticmethod
    def _year_value(df, year_value, metric):
        """Return df[metric] (as float) for the row whose 'Year' equals year_value.

        Bug fix: the original compared the string year argument directly against
        the 'Year' column, which pandas usually parses as integers, so lookups
        like df.loc[df["Year"] == "2020"] matched nothing.  Compare as stripped
        strings instead, and raise ValueError with an actionable message when
        the year or metric is missing.
        """
        if "Year" not in df.columns:
            raise ValueError(
                f"CSV must contain a 'Year' column. Columns found: {list(df.columns)}"
            )
        if metric not in df.columns:
            raise ValueError(
                f"Metric {metric!r} is not a CSV header. Columns found: {list(df.columns)}"
            )
        rows = df.loc[df["Year"].astype(str).str.strip() == str(year_value).strip()]
        if rows.empty:
            raise ValueError(f"Year {year_value!r} not found in the 'Year' column.")
        return float(rows[metric].values[0])

    def forward(
        self,
        data: str,  # A string representing the data. Must be a valid CSV
        calculation_type: str,  # type of calculation you'd like to do with the data
        year1: Optional[str] = None,  # Year1, all string types
        metric: Optional[str] = None,  # metric, all string types
        year2: Optional[str] = None,  # Year2, all string types
        revenue: Optional[str] = None,  # Revenue, all string types
        netIncome: Optional[str] = None,  # Net income, all string types
        endYear: Optional[str] = None,  # Year 2 string for the CAGR function
        year: Optional[str] = None,  # Valid Year
        startYear: Optional[str] = None,  # Year 1, string for the CAGR function
        totalAssets: Optional[str] = None,  # The Total assets data in string format
        totalDebt: Optional[str] = None,  # The total debt data in string.
        totalEquity: Optional[str] = None,  # The Total Shareholders Equity in string format
    ) -> str:
        """
        Performs the specified financial calculation.

        Args:
            data: A string representing the data. Must be a valid CSV with headers.
            calculation_type: One of 'growth_rate', 'profit_margin', 'debt_to_equity', 'CAGR'.
            year1: Start year for growth_rate, as a string.
            year2: End year for growth_rate, as a string.
            metric: CSV column name used by growth_rate and CAGR.

        Returns:
            A string representing the result of the calculation. If an error
            occurs, the string starts with "Error: " (or a descriptive message).
        """
        try:
            df = pd.read_csv(io.StringIO(data))
        except Exception as e:
            return f"Error reading data: {e}. Ensure that the input provided is a valid csv, AND has headers (no comments or empty rows)."

        try:
            if calculation_type == "growth_rate":
                if not (year1 and year2 and metric):
                    return "Error: Missing year1, year2, or metric for growth_rate calculation."

                value1 = self._year_value(df, year1, metric)
                value2 = self._year_value(df, year2, metric)
                if value1 == 0:
                    # Division by a zero base value is undefined, not an exception.
                    return "Error: Base-year value is zero; growth rate is undefined."

                growth_rate = ((value2 - value1) / value1) * 100
                return f"{growth_rate:.2f}%"

            elif calculation_type == "profit_margin":
                if not year or not revenue or not netIncome:
                    return "Error: Missing year for profit_margin calculation"

                # revenue / netIncome are passed in directly as strings rather
                # than looked up in the table.
                profit_margin = (float(netIncome) / float(revenue)) * 100
                return f"{profit_margin:.2f}%"

            elif calculation_type == "debt_to_equity":
                if not year or not totalDebt or not totalEquity:
                    return "Error: Missing year for debt_to_equity calculation"

                debt_to_equity = float(totalDebt) / float(totalEquity)
                return f"{debt_to_equity:.2f}"

            elif calculation_type == "CAGR":
                if not (startYear and endYear and metric):
                    return "Error: Missing startYear, endYear, or metric for CAGR calculation."

                try:
                    start_value = self._year_value(df, startYear, metric)
                    end_value = self._year_value(df, endYear, metric)
                except Exception as exception:
                    return f"Could not read start/end values for metric {metric!r}. Ensure CSV headers are present and valid. Original exception: {exception}"

                try:
                    n = int(endYear) - int(startYear)
                    if n <= 0:
                        # Guards the 1/n exponent against ZeroDivisionError and
                        # nonsensical negative periods.
                        return f"Error: endYear ({endYear}) must be after startYear ({startYear}) for CAGR."
                    cagr = (end_value / start_value) ** (1 / n) - 1
                    return f"{cagr:.2f}"
                except Exception:
                    return f"start year {startYear} end year {endYear} Startvalue {start_value} end value {end_value}. Year calcs invalid! Invalid CSV"

            else:
                return f"Error: Unsupported Calculation Type: {calculation_type}. Consider growth_rate, profit_margin, debt_to_equity, CAGR."
        except Exception as e:
            return f"Error performing calculation: {e}"
777
+
778
+
779
class DataVisualizationTool(Tool):
    """
    Generates visualizations (charts, graphs) from structured data to help identify trends.
    Be thoughtful about the data AND type of graph: they must match.
    You CANNOT import things other than csv, so make sure to follow the instructions.
    """

    name = "data_visualization"
    description = """
    Generates visualizations (charts, graphs) from structured data to help identify trends. Be thoughtful about the data AND type of graph: they must match. You CANNOT import things other than csv, so make sure to follow the instructions.

    Input:
    - `data` (str): A valid CSV string, that represents values to graph: MUST start with a HEADER row, then be followed by valid csv syntax
    - `chart_type` (str): The type of chart/graph to generate, MUST be one of: 'line', 'bar', 'scatter'.
    - `x_axis_label` (str): Label for the x axis. If unsure, set as "years"
    - `y_axis_label` (str): Label for the y axis. If unsure, set as "net income"

    Output:
    - `plot_string` (str): A verbal description of the plot, especially its overall trend. A short trend is sufficient.

    """
    inputs = {
        "data": {
            "type": "string",
            "description": "CSV data representing a time series: Start this with headers followed by values!!",
        },
        "chart_type": {
            "type": "string",
            "description": "Type of chart to generate (e.g., MUST be one of 'line', 'bar', 'scatter').",
        },
        "x_axis_label": {
            "type": "string",
            "description": "Label of x-axis, such as 'years' or 'quarters'",
        },
        "y_axis_label": {
            "type": "string",
            "description": "Label of y-axis, such as 'net income' or 'revenue'",
        },
    }
    output_type = "string"

    def forward(
        self, data: str, chart_type: str, x_axis_label: str, y_axis_label: str
    ) -> str:
        """
        Perform chart visuals.

        Args:
            data (str): string CSV in the correct format
            chart_type (str): one of scatter, line, bar
            x_axis_label (str): label
            y_axis_label (str): label

        Returns:
            str: A verbal description of the plot, especially its overall trend.
        """
        # Guard clauses: every argument is required.
        if not data:
            return "Error: No data provided."
        if not chart_type:
            return "Error: No chart."
        if not x_axis_label:
            return "Error: No x-axis label provided."
        if not y_axis_label:
            return "Error: No y-axis label provided."
        try:
            df = pd.read_csv(io.StringIO(data))
        except Exception as e:
            return f"Problem building data {data}: {e}"
        if len(df.columns) < 2:
            return "Error: Data must have at least two columns."

        fig = None
        try:
            fig = plt.figure(figsize=(10, 6))  # larger figure for readability
            # Labels are identical in every branch, so set them once.
            plt.xlabel(x_axis_label)
            plt.ylabel(y_axis_label)
            x_series = df[df.columns[0]]
            y_series = df[df.columns[1]]
            if chart_type == "line":
                plt.plot(x_series, y_series)
            elif chart_type == "bar":
                plt.bar(x_series, y_series)
            elif chart_type == "scatter":
                plt.scatter(x_series, y_series)
            else:
                # Caught below and reported as a "Problem with chart plotting" string.
                raise ValueError(f"Unsupported chart type: {chart_type}")
            chart_summary = f"Chart generated, which shows the {chart_type} of {df.columns[1]} with respect to {df.columns[0]}. "
            plt.title(y_axis_label + " vs. " + x_axis_label)  # What we're graphing
            plt.show()  # actually show the chart to the user (no-op on headless backends)
            return chart_summary
        except Exception as e:
            return f"Problem with chart plotting: {e}"
        finally:
            # Bug fix: figures were never closed, so repeated tool calls leaked
            # matplotlib figure state/memory (matplotlib warns after 20 open
            # figures). Closing after show/return releases the figure.
            if fig is not None:
                plt.close(fig)
874
+
875
+
876
class TrendAnalysisTool(Tool):
    """
    You can retrieve year over year increase percentages for a specific category by setting the category.
    Please provide a valid CSV. MAKE SURE headers = columns, and that is in the correct format.
    """

    name = "trend_analysis"
    description = """
    You can retrieve year over year increase percentages for a specific category by setting the category. Please provide a valid CSV. MAKE SURE headers = columns, and that is in the correct format.
    """
    inputs = {
        "data": {
            "type": "string",
            "description": "A string representing the data (e.g., CSV format) - MUST HAVE HEADERS. MUST specify all colums",
        },
        "category": {
            "type": "string",
            "description": "The category we want to compare, such as revenue. Check to know WHAT the name is!!",
        },
    }
    output_type = "string"

    def forward(self, data: str, category: str) -> str:
        """Compute year-over-year percentage changes for a given CSV column.

        Args:
            data: the full table as a CSV string (header row required)
            category: the column to compare, such as revenue

        Returns:
            The table rendered as a string with an extra 'YoY Change' column,
            or an error message string describing what went wrong.
        """
        try:
            df = pd.read_csv(io.StringIO(data))
        except Exception as e:
            # Bug fix: the original message interpolated {e} twice.
            return f"Error reading data: {e}. Ensure valid CSV, and headers are present!!"
        try:
            # pd.to_numeric raises on non-numeric values, which is reported
            # below instead of producing a confusing pct_change failure.
            df["YoY Change"] = pd.to_numeric(df[category]).pct_change() * 100
            df["YoY Change"] = df["YoY Change"].map("{:.2f}%".format)
            return df.to_string()
        except Exception as e:
            return f"Error with trend analysis: {e}. Check the name or data!!"
915
+
916
+
917
+ # ###########################
918
+ # # Example loading the tools:
919
+ # ###########################
920
+
921
+ # # def load_finance_tools():
922
+ # # finance_tools = [
923
+ # # get_stock_quote_data,
924
+ # # get_company_overview_data,
925
+ # # get_earnings_data,
926
+ # # get_income_statement_data,
927
+ # # get_balance_sheet_data,
928
+ # # get_cash_flow_data,
929
+ # # get_time_series_daily,
930
+ # # search_symbols,
931
+ # # DataVisualizationTool(),
932
+ # # FinancialCalculatorTool(),
933
+ # # TrendAnalysisTool()
934
+ # # ]
935
+ # # return finance_tools
936
+
937
+
938
def load_finance_tools():
    """Initialize and return finance tools for data retrieval and analysis.

    Returns:
        A list of tools ready to hand to an agent. Tool classes are
        instantiated here; a tool that fails to load is logged and skipped
        instead of aborting the whole loader.
    """

    finance_tools = []

    def safe_tool_load(tool_or_class, tool_name):
        """Helper to safely load and append a finance tool.

        Accepts either an already-usable tool (the @tool-decorated functions)
        or a Tool class, which is instantiated INSIDE the try block.

        Bug fix: the original received pre-constructed instances
        (e.g. DataVisualizationTool()), so constructor exceptions were raised
        at the call site, outside this guard, defeating its purpose.
        """
        try:
            tool = tool_or_class() if isinstance(tool_or_class, type) else tool_or_class
            finance_tools.append(tool)
            logging.info(f"Loaded {tool_name} tool successfully")
        except Exception as e:
            logging.error(f"Failed to load tool {tool_name}: {e}")
            logging.error(traceback.format_exc())  # Print the stack trace

    # Financial calculation tools first (classes: instantiated inside the guard)
    safe_tool_load(DataVisualizationTool, "DataVisualizationTool")
    safe_tool_load(FinancialCalculatorTool, "FinancialCalculatorTool")
    safe_tool_load(TrendAnalysisTool, "TrendAnalysisTool")
    # Raw data retrieval tools last
    safe_tool_load(get_stock_quote_data, "get_stock_quote_data")
    safe_tool_load(get_company_overview_data, "get_company_overview_data")
    safe_tool_load(get_earnings_data, "get_earnings_data")
    safe_tool_load(get_income_statement_data, "get_income_statement_data")
    safe_tool_load(get_balance_sheet_data, "get_balance_sheet_data")
    safe_tool_load(get_cash_flow_data, "get_cash_flow_data")
    safe_tool_load(get_time_series_daily, "get_time_series_daily")
    safe_tool_load(search_symbols, "search_symbols")
    safe_tool_load(get_market_news_sentiment, "get_market_news_sentiment")

    return finance_tools
972
+
973
+
974
# Public API of this module: the Alpha Vantage data-retrieval functions plus
# the three analysis Tool classes consumed by load_finance_tools().
__all__ = [
    "get_stock_quote_data",
    "get_company_overview_data",
    "get_earnings_data",
    "get_income_statement_data",
    "get_balance_sheet_data",
    "get_cash_flow_data",
    "get_time_series_daily",
    "search_symbols",
    "get_market_news_sentiment",
    "DataVisualizationTool",
    "FinancialCalculatorTool",
    "TrendAnalysisTool",
]
scripts/flux_lora_tool.py CHANGED
@@ -12,30 +12,28 @@ Usage:
12
  agent = CodeAgent(tools=[flux_tool], ...)
13
  """
14
 
 
15
  import os
16
- import uuid
17
  import tempfile
18
- import logging
19
- from typing import Dict, Any, Optional, List, Union, Tuple
20
  from dataclasses import dataclass
21
- import contextlib
22
- from pathlib import Path
23
 
24
  # Third-party
25
  import requests
26
- from PIL import Image
27
  from gradio_client import Client
28
-
29
- # Smolagents
30
  from smolagents import Tool
31
 
32
  # -----------------------------------------------------------------------------
33
  # CONSTANTS AND TYPE DEFINITIONS
34
  # -----------------------------------------------------------------------------
35
 
 
36
  @dataclass
37
  class LoRAModelInfo:
38
  """Value object representing LoRA model information."""
 
39
  name: str
40
  description: Optional[str] = None
41
  example_image_url: Optional[str] = None
@@ -44,6 +42,7 @@ class LoRAModelInfo:
44
  @dataclass
45
  class ImageGenerationResult:
46
  """Value object representing a generated image result."""
 
47
  image_path: str
48
  seed: int
49
  metadata: Optional[Dict[str, Any]] = None
@@ -53,14 +52,15 @@ class ImageGenerationResult:
53
  # CORE TOOL IMPLEMENTATION
54
  # -----------------------------------------------------------------------------
55
 
 
56
  class FluxLoRATool(Tool):
57
  """
58
  Tool for generating images using FLUX-LoRA-DLC API.
59
-
60
  This tool implements the Zhou Protocol integration patterns to provide
61
  a clean, efficient interface for image generation using LoRA models.
62
  """
63
-
64
  name = "flux_lora_generator"
65
  description = """
66
  Generates high-quality images using FLUX-LoRA models.
@@ -68,74 +68,74 @@ class FluxLoRATool(Tool):
68
  """
69
  inputs = {
70
  "prompt": {
71
- "type": "string",
72
- "description": "Detailed description of the desired image."
73
  },
74
  "image_input": {
75
- "type": "string",
76
  "description": "Optional URL or file path to input image for img2img generation.",
77
- "optional": True
78
  },
79
  "image_strength": {
80
  "type": "float",
81
  "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
82
  "optional": True,
83
- "default": 0.75
84
  },
85
  "cfg_scale": {
86
  "type": "float",
87
  "description": "Guidance scale for prompt adherence (1.0-30.0).",
88
  "optional": True,
89
- "default": 3.5
90
  },
91
  "steps": {
92
  "type": "integer",
93
  "description": "Number of sampling steps (10-100).",
94
  "optional": True,
95
- "default": 28
96
  },
97
  "seed": {
98
  "type": "integer",
99
  "description": "Random seed for reproducibility. Use -1 for random seed.",
100
  "optional": True,
101
- "default": -1
102
  },
103
  "width": {
104
  "type": "integer",
105
  "description": "Image width in pixels.",
106
  "optional": True,
107
- "default": 1024
108
  },
109
  "height": {
110
  "type": "integer",
111
  "description": "Image height in pixels.",
112
  "optional": True,
113
- "default": 1024
114
  },
115
  "lora_scale": {
116
  "type": "float",
117
  "description": "LoRA influence scale (0.0-1.0).",
118
  "optional": True,
119
- "default": 0.95
120
  },
121
  "custom_lora": {
122
  "type": "string",
123
  "description": "Custom LoRA model to use. Leave empty for default.",
124
- "optional": True
125
- }
126
  }
127
  output_type = "string"
128
-
129
  def __init__(
130
- self,
131
  api_url: str = "xkerser/FLUX-LoRA-DLC",
132
  image_save_dir: Optional[str] = None,
133
  connection_timeout: int = 60,
134
- verbose: bool = False
135
  ):
136
  """
137
  Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.
138
-
139
  Args:
140
  api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
141
  image_save_dir: Directory to save generated images (created if doesn't exist)
@@ -143,66 +143,67 @@ class FluxLoRATool(Tool):
143
  verbose: Enable detailed logging
144
  """
145
  super().__init__()
146
-
147
  # Initialize logging
148
  self.logger = logging.getLogger("flux_lora_tool")
149
  self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
150
-
151
  # Set up client and storage directories
152
  self.api_url = api_url
153
  self.connection_timeout = connection_timeout
154
  self._client = None # Lazy initialization
155
-
156
  # Set up image storage directory
157
- self.image_save_dir = image_save_dir or os.path.join(tempfile.gettempdir(), "flux_lora_images")
 
 
158
  os.makedirs(self.image_save_dir, exist_ok=True)
159
- self.logger.info(f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}")
160
-
 
 
161
  @property
162
  def client(self) -> Client:
163
  """
164
  Get or initialize the Gradio client with proper connection handling.
165
-
166
  Returns:
167
  Initialized Gradio client
168
-
169
  Raises:
170
  ConnectionError: If client initialization fails
171
  """
172
  if self._client is None:
173
  try:
174
- self._client = Client(
175
- self.api_url,
176
- timeout=self.connection_timeout
177
- )
178
  self.logger.debug(f"Gradio client initialized for: {self.api_url}")
179
  except Exception as e:
180
  error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
181
  self.logger.error(error_msg)
182
  raise ConnectionError(error_msg) from e
183
-
184
  return self._client
185
-
186
  def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
187
  """
188
  Validate and normalize input parameters with Zhou Protocol validation patterns.
189
-
190
  Args:
191
  **kwargs: Input parameters
192
-
193
  Returns:
194
  Validated and normalized parameters
195
-
196
  Raises:
197
  ValueError: If input validation fails
198
  """
199
  validated = {}
200
-
201
  # Required parameter: prompt
202
  if not kwargs.get("prompt"):
203
  raise ValueError("Prompt is required for image generation")
204
  validated["prompt"] = kwargs["prompt"]
205
-
206
  # Image input handling
207
  if "image_input" in kwargs and kwargs["image_input"]:
208
  input_image = kwargs["image_input"]
@@ -215,7 +216,7 @@ class FluxLoRATool(Tool):
215
  if not os.path.exists(input_image):
216
  raise ValueError(f"Image file not found: {input_image}")
217
  validated["image_input"] = input_image
218
-
219
  # Numeric parameter validation with constraints
220
  numeric_params = {
221
  "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
@@ -223,13 +224,13 @@ class FluxLoRATool(Tool):
223
  "steps": {"min": 10, "max": 100, "default": 28},
224
  "width": {"min": 128, "max": 2048, "default": 1024},
225
  "height": {"min": 128, "max": 2048, "default": 1024},
226
- "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95}
227
  }
228
-
229
  for param, constraints in numeric_params.items():
230
  if param in kwargs and kwargs[param] is not None:
231
  value = kwargs[param]
232
-
233
  # Type conversion if needed
234
  if param in ["steps", "width", "height"]:
235
  try:
@@ -241,17 +242,17 @@ class FluxLoRATool(Tool):
241
  value = float(value)
242
  except (ValueError, TypeError):
243
  raise ValueError(f"Parameter '{param}' must be a number")
244
-
245
  # Range validation
246
  if value < constraints["min"] or value > constraints["max"]:
247
  raise ValueError(
248
  f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
249
  )
250
-
251
  validated[param] = value
252
  else:
253
  validated[param] = constraints["default"]
254
-
255
  # Special handling for seed
256
  if "seed" in kwargs and kwargs["seed"] is not None:
257
  try:
@@ -264,6 +265,7 @@ class FluxLoRATool(Tool):
264
  self.logger.warning(f"Failed to get random seed from API: {e}")
265
  # Fallback to Python's random
266
  import random
 
267
  seed = random.randint(0, 2**32 - 1)
268
  validated["seed"] = seed
269
  except (ValueError, TypeError):
@@ -271,57 +273,56 @@ class FluxLoRATool(Tool):
271
  else:
272
  # Default to random seed
273
  validated["seed"] = self._get_random_seed()
274
-
275
  # Custom LoRA handling
276
  if "custom_lora" in kwargs and kwargs["custom_lora"]:
277
  validated["custom_lora"] = kwargs["custom_lora"]
278
-
279
  return validated
280
-
281
  def _download_image(self, url: str) -> str:
282
  """
283
  Download image from URL and save to local file.
284
-
285
  Args:
286
  url: Image URL
287
-
288
  Returns:
289
  Local file path
290
-
291
  Raises:
292
  ConnectionError: If download fails
293
  """
294
  try:
295
  response = requests.get(url, stream=True, timeout=30)
296
  response.raise_for_status()
297
-
298
  # Generate temporary file path
299
  file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
300
  temp_path = os.path.join(
301
- self.image_save_dir,
302
- f"input_{uuid.uuid4().hex}{file_ext}"
303
  )
304
-
305
  # Save image
306
  with open(temp_path, "wb") as f:
307
  for chunk in response.iter_content(chunk_size=8192):
308
  f.write(chunk)
309
-
310
  self.logger.debug(f"Downloaded image from {url} to {temp_path}")
311
  return temp_path
312
-
313
  except Exception as e:
314
  error_msg = f"Failed to download image from {url}: {str(e)}"
315
  self.logger.error(error_msg)
316
  raise ConnectionError(error_msg) from e
317
-
318
  def _guess_extension(self, content_type: str) -> str:
319
  """
320
  Guess file extension from content type.
321
-
322
  Args:
323
  content_type: HTTP Content-Type header
324
-
325
  Returns:
326
  File extension (with dot)
327
  """
@@ -336,14 +337,14 @@ class FluxLoRATool(Tool):
336
  return ".gif"
337
  else:
338
  return ".png" # Default to PNG
339
-
340
  def _get_random_seed(self) -> int:
341
  """
342
  Get a random seed from the API.
343
-
344
  Returns:
345
  Random seed value
346
-
347
  Raises:
348
  RuntimeError: If random seed retrieval fails
349
  """
@@ -357,14 +358,14 @@ class FluxLoRATool(Tool):
357
  # Just log and re-raise as we have fallback in the validation method
358
  self.logger.warning(f"Failed to get random seed: {e}")
359
  raise
360
-
361
  def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
362
  """
363
  Add or remove custom LoRA model.
364
-
365
  Args:
366
  custom_lora: Custom LoRA model string
367
-
368
  Raises:
369
  RuntimeError: If LoRA handling fails
370
  """
@@ -381,15 +382,14 @@ class FluxLoRATool(Tool):
381
  # Add custom LoRA
382
  try:
383
  self.client.predict(
384
- custom_lora=custom_lora,
385
- api_name="/add_custom_lora"
386
  )
387
  self.logger.debug(f"Added custom LoRA: {custom_lora}")
388
  except Exception as e:
389
  error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
390
  self.logger.error(error_msg)
391
  raise RuntimeError(error_msg) from e
392
-
393
  def forward(
394
  self,
395
  prompt: str,
@@ -401,11 +401,11 @@ class FluxLoRATool(Tool):
401
  width: Optional[int] = None,
402
  height: Optional[int] = None,
403
  lora_scale: Optional[float] = None,
404
- custom_lora: Optional[str] = None
405
  ) -> str:
406
  """
407
  Generate an image with FLUX-LoRA.
408
-
409
  Args:
410
  prompt: Text description of the desired image
411
  image_input: Optional path or URL to input image for img2img
@@ -417,10 +417,10 @@ class FluxLoRATool(Tool):
417
  height: Image height in pixels (128-2048)
418
  lora_scale: LoRA influence scale (0.0-1.0)
419
  custom_lora: Custom LoRA model to use
420
-
421
  Returns:
422
  Formatted string with image generation results
423
-
424
  Raises:
425
  ValueError: If input validation fails
426
  ConnectionError: If API communication fails
@@ -438,12 +438,12 @@ class FluxLoRATool(Tool):
438
  width=width,
439
  height=height,
440
  lora_scale=lora_scale,
441
- custom_lora=custom_lora
442
  )
443
  self.logger.debug(f"Validated parameters: {params}")
444
  except ValueError as e:
445
  return f"Parameter validation failed: {str(e)}"
446
-
447
  # Step 2: Handle custom LoRA if specified
448
  if "custom_lora" in params:
449
  try:
@@ -451,15 +451,16 @@ class FluxLoRATool(Tool):
451
  self._handle_custom_lora(custom_lora_value)
452
  except RuntimeError as e:
453
  return f"Custom LoRA setup failed: {str(e)}"
454
-
455
  # Step 3: Generate image
456
  try:
457
  # Prepare image input if provided
458
  img_param = None
459
  if "image_input" in params and params["image_input"]:
460
  from gradio_client import handle_file
 
461
  img_param = handle_file(params.pop("image_input"))
462
-
463
  # Call the API
464
  generation_args = {
465
  "prompt": params["prompt"],
@@ -472,27 +473,23 @@ class FluxLoRATool(Tool):
472
  "height": params["height"],
473
  "lora_scale": params["lora_scale"],
474
  }
475
-
476
  # Add image input if available
477
  if img_param:
478
  generation_args["image_input"] = img_param
479
-
480
  self.logger.info(f"Generating image with params: {generation_args}")
481
- result = self.client.predict(
482
- api_name="/run_lora",
483
- **generation_args
484
- )
485
-
486
  # Process result
487
  if isinstance(result, tuple) and len(result) >= 2:
488
  image_path, actual_seed = result[0], result[1]
489
-
490
  # Save image to our directory
491
  try:
492
  output_path = self._save_image(image_path)
493
  image_result = ImageGenerationResult(
494
- image_path=output_path,
495
- seed=int(actual_seed)
496
  )
497
  return self._format_result(image_result, params["prompt"])
498
  except Exception as e:
@@ -500,69 +497,69 @@ class FluxLoRATool(Tool):
500
  return f"Image generated but failed to save: {str(e)}"
501
  else:
502
  raise ValueError(f"Unexpected API response format: {result}")
503
-
504
  except Exception as e:
505
  error_msg = f"Image generation failed: {str(e)}"
506
  self.logger.error(error_msg)
507
  return error_msg
508
-
509
  def _save_image(self, image_path: str) -> str:
510
  """
511
  Save generated image to specified directory.
512
-
513
  Args:
514
  image_path: Path to generated image from API
515
-
516
  Returns:
517
  Path to saved image
518
-
519
  Raises:
520
  IOError: If image saving fails
521
  """
522
  try:
523
  # Load the image
524
  img = Image.open(image_path)
525
-
526
  # Generate timestamp-based filename
527
  timestamp = uuid.uuid4().hex[:8]
528
  output_filename = f"flux_lora_{timestamp}.png"
529
  output_path = os.path.join(self.image_save_dir, output_filename)
530
-
531
  # Save to our directory
532
  img.save(output_path)
533
  self.logger.debug(f"Saved image to {output_path}")
534
-
535
  return output_path
536
-
537
  except Exception as e:
538
  error_msg = f"Failed to save image: {str(e)}"
539
  self.logger.error(error_msg)
540
  raise IOError(error_msg) from e
541
-
542
  def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
543
  """
544
  Format the image generation result as a string.
545
-
546
  Args:
547
  result: Image generation result
548
  prompt: Original prompt
549
-
550
  Returns:
551
  Formatted string with generation details
552
  """
553
  lines = [
554
- f"📷 Image generated successfully!",
555
  f"🖼️ Image saved to: {result.image_path}",
556
  f"🌱 Seed used: {result.seed}",
557
  f"📝 Original prompt: {prompt}",
558
  ]
559
-
560
  # Add metadata if available
561
  if result.metadata:
562
  lines.append("📊 Additional metadata:")
563
  for key, value in result.metadata.items():
564
  lines.append(f" - {key}: {value}")
565
-
566
  return "\n".join(lines)
567
 
568
 
@@ -570,17 +567,18 @@ class FluxLoRATool(Tool):
570
  # UTILITY FUNCTIONS
571
  # -----------------------------------------------------------------------------
572
 
 
573
  def download_image(url: str, output_dir: Optional[str] = None) -> str:
574
  """
575
  Standalone utility to download an image from a URL.
576
-
577
  Args:
578
  url: Image URL
579
  output_dir: Directory to save image (created if doesn't exist)
580
-
581
  Returns:
582
  Path to downloaded image
583
-
584
  Raises:
585
  ValueError: If URL is invalid
586
  ConnectionError: If download fails
@@ -588,31 +586,30 @@ def download_image(url: str, output_dir: Optional[str] = None) -> str:
588
  """
589
  if not url.startswith(("http://", "https://")):
590
  raise ValueError(f"Invalid URL: {url}")
591
-
592
  # Setup output directory
593
  if output_dir is None:
594
  output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
595
  os.makedirs(output_dir, exist_ok=True)
596
-
597
  try:
598
  # Download image
599
  response = requests.get(url, stream=True, timeout=30)
600
  response.raise_for_status()
601
-
602
  # Determine file extension
603
  content_type = response.headers.get("Content-Type", "")
604
  ext = ".jpg" if "jpeg" in content_type.lower() else ".png"
605
-
606
  # Save image
607
  output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
608
  with open(output_path, "wb") as f:
609
  for chunk in response.iter_content(chunk_size=8192):
610
  f.write(chunk)
611
-
612
  return output_path
613
-
614
  except requests.RequestException as e:
615
  raise ConnectionError(f"Failed to download image: {str(e)}")
616
  except IOError as e:
617
  raise IOError(f"Failed to save image: {str(e)}")
618
-
 
12
  agent = CodeAgent(tools=[flux_tool], ...)
13
  """
14
 
15
+ import logging
16
  import os
 
17
  import tempfile
18
+ import uuid
 
19
  from dataclasses import dataclass
20
+ from typing import Any, Dict, Optional
 
21
 
22
  # Third-party
23
  import requests
 
24
  from gradio_client import Client
25
+ from PIL import Image
 
26
  from smolagents import Tool
27
 
28
  # -----------------------------------------------------------------------------
29
  # CONSTANTS AND TYPE DEFINITIONS
30
  # -----------------------------------------------------------------------------
31
 
32
+
33
  @dataclass
34
  class LoRAModelInfo:
35
  """Value object representing LoRA model information."""
36
+
37
  name: str
38
  description: Optional[str] = None
39
  example_image_url: Optional[str] = None
 
42
  @dataclass
43
  class ImageGenerationResult:
44
  """Value object representing a generated image result."""
45
+
46
  image_path: str
47
  seed: int
48
  metadata: Optional[Dict[str, Any]] = None
 
52
  # CORE TOOL IMPLEMENTATION
53
  # -----------------------------------------------------------------------------
54
 
55
+
56
  class FluxLoRATool(Tool):
57
  """
58
  Tool for generating images using FLUX-LoRA-DLC API.
59
+
60
  This tool implements the Zhou Protocol integration patterns to provide
61
  a clean, efficient interface for image generation using LoRA models.
62
  """
63
+
64
  name = "flux_lora_generator"
65
  description = """
66
  Generates high-quality images using FLUX-LoRA models.
 
68
  """
69
  inputs = {
70
  "prompt": {
71
+ "type": "string",
72
+ "description": "Detailed description of the desired image.",
73
  },
74
  "image_input": {
75
+ "type": "string",
76
  "description": "Optional URL or file path to input image for img2img generation.",
77
+ "optional": True,
78
  },
79
  "image_strength": {
80
  "type": "float",
81
  "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
82
  "optional": True,
83
+ "default": 0.75,
84
  },
85
  "cfg_scale": {
86
  "type": "float",
87
  "description": "Guidance scale for prompt adherence (1.0-30.0).",
88
  "optional": True,
89
+ "default": 3.5,
90
  },
91
  "steps": {
92
  "type": "integer",
93
  "description": "Number of sampling steps (10-100).",
94
  "optional": True,
95
+ "default": 28,
96
  },
97
  "seed": {
98
  "type": "integer",
99
  "description": "Random seed for reproducibility. Use -1 for random seed.",
100
  "optional": True,
101
+ "default": -1,
102
  },
103
  "width": {
104
  "type": "integer",
105
  "description": "Image width in pixels.",
106
  "optional": True,
107
+ "default": 1024,
108
  },
109
  "height": {
110
  "type": "integer",
111
  "description": "Image height in pixels.",
112
  "optional": True,
113
+ "default": 1024,
114
  },
115
  "lora_scale": {
116
  "type": "float",
117
  "description": "LoRA influence scale (0.0-1.0).",
118
  "optional": True,
119
+ "default": 0.95,
120
  },
121
  "custom_lora": {
122
  "type": "string",
123
  "description": "Custom LoRA model to use. Leave empty for default.",
124
+ "optional": True,
125
+ },
126
  }
127
  output_type = "string"
128
+
129
  def __init__(
130
+ self,
131
  api_url: str = "xkerser/FLUX-LoRA-DLC",
132
  image_save_dir: Optional[str] = None,
133
  connection_timeout: int = 60,
134
+ verbose: bool = False,
135
  ):
136
  """
137
  Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.
138
+
139
  Args:
140
  api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
141
  image_save_dir: Directory to save generated images (created if doesn't exist)
 
143
  verbose: Enable detailed logging
144
  """
145
  super().__init__()
146
+
147
  # Initialize logging
148
  self.logger = logging.getLogger("flux_lora_tool")
149
  self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
150
+
151
  # Set up client and storage directories
152
  self.api_url = api_url
153
  self.connection_timeout = connection_timeout
154
  self._client = None # Lazy initialization
155
+
156
  # Set up image storage directory
157
+ self.image_save_dir = image_save_dir or os.path.join(
158
+ tempfile.gettempdir(), "flux_lora_images"
159
+ )
160
  os.makedirs(self.image_save_dir, exist_ok=True)
161
+ self.logger.info(
162
+ f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}"
163
+ )
164
+
165
  @property
166
  def client(self) -> Client:
167
  """
168
  Get or initialize the Gradio client with proper connection handling.
169
+
170
  Returns:
171
  Initialized Gradio client
172
+
173
  Raises:
174
  ConnectionError: If client initialization fails
175
  """
176
  if self._client is None:
177
  try:
178
+ self._client = Client(self.api_url, timeout=self.connection_timeout)
 
 
 
179
  self.logger.debug(f"Gradio client initialized for: {self.api_url}")
180
  except Exception as e:
181
  error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
182
  self.logger.error(error_msg)
183
  raise ConnectionError(error_msg) from e
184
+
185
  return self._client
186
+
187
  def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
188
  """
189
  Validate and normalize input parameters with Zhou Protocol validation patterns.
190
+
191
  Args:
192
  **kwargs: Input parameters
193
+
194
  Returns:
195
  Validated and normalized parameters
196
+
197
  Raises:
198
  ValueError: If input validation fails
199
  """
200
  validated = {}
201
+
202
  # Required parameter: prompt
203
  if not kwargs.get("prompt"):
204
  raise ValueError("Prompt is required for image generation")
205
  validated["prompt"] = kwargs["prompt"]
206
+
207
  # Image input handling
208
  if "image_input" in kwargs and kwargs["image_input"]:
209
  input_image = kwargs["image_input"]
 
216
  if not os.path.exists(input_image):
217
  raise ValueError(f"Image file not found: {input_image}")
218
  validated["image_input"] = input_image
219
+
220
  # Numeric parameter validation with constraints
221
  numeric_params = {
222
  "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
 
224
  "steps": {"min": 10, "max": 100, "default": 28},
225
  "width": {"min": 128, "max": 2048, "default": 1024},
226
  "height": {"min": 128, "max": 2048, "default": 1024},
227
+ "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95},
228
  }
229
+
230
  for param, constraints in numeric_params.items():
231
  if param in kwargs and kwargs[param] is not None:
232
  value = kwargs[param]
233
+
234
  # Type conversion if needed
235
  if param in ["steps", "width", "height"]:
236
  try:
 
242
  value = float(value)
243
  except (ValueError, TypeError):
244
  raise ValueError(f"Parameter '{param}' must be a number")
245
+
246
  # Range validation
247
  if value < constraints["min"] or value > constraints["max"]:
248
  raise ValueError(
249
  f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
250
  )
251
+
252
  validated[param] = value
253
  else:
254
  validated[param] = constraints["default"]
255
+
256
  # Special handling for seed
257
  if "seed" in kwargs and kwargs["seed"] is not None:
258
  try:
 
265
  self.logger.warning(f"Failed to get random seed from API: {e}")
266
  # Fallback to Python's random
267
  import random
268
+
269
  seed = random.randint(0, 2**32 - 1)
270
  validated["seed"] = seed
271
  except (ValueError, TypeError):
 
273
  else:
274
  # Default to random seed
275
  validated["seed"] = self._get_random_seed()
276
+
277
  # Custom LoRA handling
278
  if "custom_lora" in kwargs and kwargs["custom_lora"]:
279
  validated["custom_lora"] = kwargs["custom_lora"]
280
+
281
  return validated
282
+
283
  def _download_image(self, url: str) -> str:
284
  """
285
  Download image from URL and save to local file.
286
+
287
  Args:
288
  url: Image URL
289
+
290
  Returns:
291
  Local file path
292
+
293
  Raises:
294
  ConnectionError: If download fails
295
  """
296
  try:
297
  response = requests.get(url, stream=True, timeout=30)
298
  response.raise_for_status()
299
+
300
  # Generate temporary file path
301
  file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
302
  temp_path = os.path.join(
303
+ self.image_save_dir, f"input_{uuid.uuid4().hex}{file_ext}"
 
304
  )
305
+
306
  # Save image
307
  with open(temp_path, "wb") as f:
308
  for chunk in response.iter_content(chunk_size=8192):
309
  f.write(chunk)
310
+
311
  self.logger.debug(f"Downloaded image from {url} to {temp_path}")
312
  return temp_path
313
+
314
  except Exception as e:
315
  error_msg = f"Failed to download image from {url}: {str(e)}"
316
  self.logger.error(error_msg)
317
  raise ConnectionError(error_msg) from e
318
+
319
  def _guess_extension(self, content_type: str) -> str:
320
  """
321
  Guess file extension from content type.
322
+
323
  Args:
324
  content_type: HTTP Content-Type header
325
+
326
  Returns:
327
  File extension (with dot)
328
  """
 
337
  return ".gif"
338
  else:
339
  return ".png" # Default to PNG
340
+
341
  def _get_random_seed(self) -> int:
342
  """
343
  Get a random seed from the API.
344
+
345
  Returns:
346
  Random seed value
347
+
348
  Raises:
349
  RuntimeError: If random seed retrieval fails
350
  """
 
358
  # Just log and re-raise as we have fallback in the validation method
359
  self.logger.warning(f"Failed to get random seed: {e}")
360
  raise
361
+
362
  def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
363
  """
364
  Add or remove custom LoRA model.
365
+
366
  Args:
367
  custom_lora: Custom LoRA model string
368
+
369
  Raises:
370
  RuntimeError: If LoRA handling fails
371
  """
 
382
  # Add custom LoRA
383
  try:
384
  self.client.predict(
385
+ custom_lora=custom_lora, api_name="/add_custom_lora"
 
386
  )
387
  self.logger.debug(f"Added custom LoRA: {custom_lora}")
388
  except Exception as e:
389
  error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
390
  self.logger.error(error_msg)
391
  raise RuntimeError(error_msg) from e
392
+
393
  def forward(
394
  self,
395
  prompt: str,
 
401
  width: Optional[int] = None,
402
  height: Optional[int] = None,
403
  lora_scale: Optional[float] = None,
404
+ custom_lora: Optional[str] = None,
405
  ) -> str:
406
  """
407
  Generate an image with FLUX-LoRA.
408
+
409
  Args:
410
  prompt: Text description of the desired image
411
  image_input: Optional path or URL to input image for img2img
 
417
  height: Image height in pixels (128-2048)
418
  lora_scale: LoRA influence scale (0.0-1.0)
419
  custom_lora: Custom LoRA model to use
420
+
421
  Returns:
422
  Formatted string with image generation results
423
+
424
  Raises:
425
  ValueError: If input validation fails
426
  ConnectionError: If API communication fails
 
438
  width=width,
439
  height=height,
440
  lora_scale=lora_scale,
441
+ custom_lora=custom_lora,
442
  )
443
  self.logger.debug(f"Validated parameters: {params}")
444
  except ValueError as e:
445
  return f"Parameter validation failed: {str(e)}"
446
+
447
  # Step 2: Handle custom LoRA if specified
448
  if "custom_lora" in params:
449
  try:
 
451
  self._handle_custom_lora(custom_lora_value)
452
  except RuntimeError as e:
453
  return f"Custom LoRA setup failed: {str(e)}"
454
+
455
  # Step 3: Generate image
456
  try:
457
  # Prepare image input if provided
458
  img_param = None
459
  if "image_input" in params and params["image_input"]:
460
  from gradio_client import handle_file
461
+
462
  img_param = handle_file(params.pop("image_input"))
463
+
464
  # Call the API
465
  generation_args = {
466
  "prompt": params["prompt"],
 
473
  "height": params["height"],
474
  "lora_scale": params["lora_scale"],
475
  }
476
+
477
  # Add image input if available
478
  if img_param:
479
  generation_args["image_input"] = img_param
480
+
481
  self.logger.info(f"Generating image with params: {generation_args}")
482
+ result = self.client.predict(api_name="/run_lora", **generation_args)
483
+
 
 
 
484
  # Process result
485
  if isinstance(result, tuple) and len(result) >= 2:
486
  image_path, actual_seed = result[0], result[1]
487
+
488
  # Save image to our directory
489
  try:
490
  output_path = self._save_image(image_path)
491
  image_result = ImageGenerationResult(
492
+ image_path=output_path, seed=int(actual_seed)
 
493
  )
494
  return self._format_result(image_result, params["prompt"])
495
  except Exception as e:
 
497
  return f"Image generated but failed to save: {str(e)}"
498
  else:
499
  raise ValueError(f"Unexpected API response format: {result}")
500
+
501
  except Exception as e:
502
  error_msg = f"Image generation failed: {str(e)}"
503
  self.logger.error(error_msg)
504
  return error_msg
505
+
506
  def _save_image(self, image_path: str) -> str:
507
  """
508
  Save generated image to specified directory.
509
+
510
  Args:
511
  image_path: Path to generated image from API
512
+
513
  Returns:
514
  Path to saved image
515
+
516
  Raises:
517
  IOError: If image saving fails
518
  """
519
  try:
520
  # Load the image
521
  img = Image.open(image_path)
522
+
523
  # Generate timestamp-based filename
524
  timestamp = uuid.uuid4().hex[:8]
525
  output_filename = f"flux_lora_{timestamp}.png"
526
  output_path = os.path.join(self.image_save_dir, output_filename)
527
+
528
  # Save to our directory
529
  img.save(output_path)
530
  self.logger.debug(f"Saved image to {output_path}")
531
+
532
  return output_path
533
+
534
  except Exception as e:
535
  error_msg = f"Failed to save image: {str(e)}"
536
  self.logger.error(error_msg)
537
  raise IOError(error_msg) from e
538
+
539
  def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
540
  """
541
  Format the image generation result as a string.
542
+
543
  Args:
544
  result: Image generation result
545
  prompt: Original prompt
546
+
547
  Returns:
548
  Formatted string with generation details
549
  """
550
  lines = [
551
+ "📷 Image generated successfully!",
552
  f"🖼️ Image saved to: {result.image_path}",
553
  f"🌱 Seed used: {result.seed}",
554
  f"📝 Original prompt: {prompt}",
555
  ]
556
+
557
  # Add metadata if available
558
  if result.metadata:
559
  lines.append("📊 Additional metadata:")
560
  for key, value in result.metadata.items():
561
  lines.append(f" - {key}: {value}")
562
+
563
  return "\n".join(lines)
564
 
565
 
 
567
  # UTILITY FUNCTIONS
568
  # -----------------------------------------------------------------------------
569
 
570
+
571
  def download_image(url: str, output_dir: Optional[str] = None) -> str:
572
  """
573
  Standalone utility to download an image from a URL.
574
+
575
  Args:
576
  url: Image URL
577
  output_dir: Directory to save image (created if doesn't exist)
578
+
579
  Returns:
580
  Path to downloaded image
581
+
582
  Raises:
583
  ValueError: If URL is invalid
584
  ConnectionError: If download fails
 
586
  """
587
  if not url.startswith(("http://", "https://")):
588
  raise ValueError(f"Invalid URL: {url}")
589
+
590
  # Setup output directory
591
  if output_dir is None:
592
  output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
593
  os.makedirs(output_dir, exist_ok=True)
594
+
595
  try:
596
  # Download image
597
  response = requests.get(url, stream=True, timeout=30)
598
  response.raise_for_status()
599
+
600
  # Determine file extension
601
  content_type = response.headers.get("Content-Type", "")
602
  ext = ".jpg" if "jpeg" in content_type.lower() else ".png"
603
+
604
  # Save image
605
  output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
606
  with open(output_path, "wb") as f:
607
  for chunk in response.iter_content(chunk_size=8192):
608
  f.write(chunk)
609
+
610
  return output_path
611
+
612
  except requests.RequestException as e:
613
  raise ConnectionError(f"Failed to download image: {str(e)}")
614
  except IOError as e:
615
  raise IOError(f"Failed to save image: {str(e)}")
 
scripts/frontmatter_tool.py DELETED
@@ -1,402 +0,0 @@
1
- """
2
- Frontmatter Generator Tool for Smolagents
3
-
4
- This tool helps generate consistent YAML frontmatter for documents,
5
- useful for RAG systems, static site generators, and document organization.
6
- Integrates with TextInspectorTool and MarkdownConverter for a complete
7
- document processing pipeline.
8
- """
9
-
10
- import re
11
- import yaml
12
- import json
13
- from datetime import datetime
14
- from typing import Dict, List, Optional, Any, Union
15
- from smolagents import Tool
16
-
17
-
18
- class FrontmatterGeneratorTool(Tool):
19
- """Tool for generating and manipulating YAML frontmatter in documents."""
20
-
21
- name = "frontmatter_generator"
22
- description = """
23
- Generates or extracts YAML frontmatter for documents. Frontmatter provides structured
24
- metadata for documents including title, author, date, description, and tags.
25
- Useful for document organization, RAG systems, and static site generators.
26
- Works with content from the inspect_file_as_text tool to add metadata to documents.
27
- """
28
-
29
- inputs = {
30
- "content": {
31
- "type": "string",
32
- "description": "Document content (with or without existing frontmatter)",
33
- },
34
- "title": {"type": "string", "description": "Document title", "nullable": True},
35
- "author": {
36
- "type": "string",
37
- "description": "Document author(s)",
38
- "nullable": True,
39
- },
40
- "date": {
41
- "type": "string",
42
- "description": "Document date in YYYY-MM-DD format (defaults to today if not provided)",
43
- "nullable": True,
44
- },
45
- "date_format": {
46
- "type": "string",
47
- "description": "Format string for the document date (e.g., '%Y-%m-%d', '%d/%m/%Y'). Defaults to '%Y-%m-%d'",
48
- "nullable": True,
49
- "default": "%Y-%m-%d",
50
- },
51
- "description": {
52
- "type": "string",
53
- "description": "Brief description of the document",
54
- "nullable": True,
55
- },
56
- "tags": {
57
- "type": "string",
58
- "description": "Comma-separated list of tags",
59
- "nullable": True,
60
- },
61
- "additional_fields": {
62
- "type": "string",
63
- "description": "JSON string with additional frontmatter fields",
64
- "nullable": True,
65
- },
66
- "mode": {
67
- "type": "string",
68
- "description": "Operation mode: 'generate' (create new), 'extract' (get existing), 'update' (modify existing), or 'strip' (remove)",
69
- "default": "generate",
70
- },
71
- }
72
- output_type = "string"
73
-
74
- # Regular expression to detect and extract YAML frontmatter
75
- FRONTMATTER_PATTERN = r"^---\s*\n(.*?)\n---\s*\n"
76
-
77
- def forward(
78
- self,
79
- content: str,
80
- title: Optional[str] = None,
81
- author: Optional[str] = None,
82
- date: Optional[str] = None,
83
- date_format: Optional[str] = "%Y-%m-%d",
84
- description: Optional[str] = None,
85
- tags: Optional[str] = None,
86
- additional_fields: Optional[str] = None,
87
- mode: str = "generate",
88
- ) -> str:
89
- """
90
- Process document content based on specified mode.
91
-
92
- Args:
93
- content: Document content with or without frontmatter
94
- title: Document title
95
- author: Document author(s)
96
- date: Document date (YYYY-MM-DD)
97
- date_format: strftime format string
98
- description: Brief document description
99
- tags: Comma-separated list of tags
100
- additional_fields: JSON string with additional fields
101
- mode: Operation mode (generate, extract, update, strip)
102
-
103
- Returns:
104
- Processed document or extracted frontmatter
105
- """
106
- # Validate inputs
107
- if not isinstance(content, str):
108
- return "Error: Content must be a string"
109
- if title and not isinstance(title, str):
110
- return "Error: Title must be a string"
111
- if author and not isinstance(author, str):
112
- return "Error: Author must be a string"
113
- if date and not isinstance(date, str):
114
- return "Error: Date must be a string"
115
- if description and not isinstance(description, str):
116
- return "Error: Description must be a string"
117
- if tags and not isinstance(tags, str):
118
- return "Error: Tags must be a string"
119
- if additional_fields and not isinstance(additional_fields, str):
120
- return "Error: Additional_fields must be a string"
121
- if not isinstance(mode, str):
122
- return "Error: Mode must be a string"
123
-
124
- # Validate mode
125
- valid_modes = ["generate", "extract", "update", "strip"]
126
- if mode not in valid_modes:
127
- return f"Error: Invalid mode '{mode}'. Valid options are: {', '.join(valid_modes)}"
128
-
129
- # Handle empty content
130
- if not content or not content.strip():
131
- if mode == "generate":
132
- # We can still generate frontmatter from provided fields
133
- content = ""
134
- else:
135
- return "Error: Empty content provided"
136
-
137
- # Special handling for TextInspectorTool output
138
- if content.startswith("Document content:"):
139
- content = content[len("Document content:"):].strip()
140
-
141
- # Process based on mode
142
- try:
143
- if mode == "extract":
144
- return self._extract_frontmatter(content)
145
- elif mode == "strip":
146
- return self._strip_frontmatter(content)
147
- elif mode == "update":
148
- return self._update_frontmatter(
149
- content,
150
- title,
151
- author,
152
- date,
153
- description,
154
- tags,
155
- additional_fields,
156
- date_format,
157
- )
158
- else: # generate
159
- return self._generate_frontmatter(
160
- content,
161
- title,
162
- author,
163
- date,
164
- description,
165
- tags,
166
- additional_fields,
167
- date_format,
168
- )
169
- except Exception as e:
170
- return f"Error processing frontmatter: {str(e)}"
171
-
172
- def _extract_frontmatter(self, content: str) -> str:
173
- """Extract and return existing frontmatter as formatted YAML."""
174
- match = re.search(self.FRONTMATTER_PATTERN, content, re.DOTALL)
175
- if not match:
176
- return "No frontmatter found in the document"
177
-
178
- try:
179
- yaml_content = match.group(1)
180
- # Parse and reformat for consistency
181
- frontmatter_dict = yaml.safe_load(yaml_content)
182
- return f"Extracted frontmatter:\n\n```yaml\n{yaml.dump(frontmatter_dict, sort_keys=False, default_flow_style=False)}```"
183
- except yaml.YAMLError:
184
- return "Found frontmatter but failed to parse it as valid YAML"
185
-
186
- def _strip_frontmatter(self, content: str) -> str:
187
- """Remove frontmatter from document and return clean content."""
188
- result = re.sub(self.FRONTMATTER_PATTERN, "", content, count=1, flags=re.DOTALL)
189
-
190
- # Check if anything was actually removed
191
- if result == content:
192
- return "No frontmatter found to strip. Content unchanged."
193
-
194
- return result.strip()
195
-
196
- def _parse_additional_fields(self, additional_fields: str) -> Dict[str, Any]:
197
- """Parse the additional_fields JSON string into a dictionary."""
198
- if not additional_fields:
199
- return {}
200
-
201
- try:
202
- return json.loads(additional_fields)
203
- except json.JSONDecodeError:
204
- raise ValueError("additional_fields must be a valid JSON string")
205
-
206
- def _infer_title_from_content(self, content: str) -> Optional[str]:
207
- """Attempt to infer document title from content."""
208
- # Try to find the first heading
209
- heading_match = re.search(r"^#\s+(.+)$", content, re.MULTILINE)
210
- if heading_match:
211
- return heading_match.group(1).strip()
212
-
213
- # Try to find the first non-empty line
214
- lines = content.split("\n")
215
- for line in lines:
216
- if line.strip():
217
- # Limit to a reasonable title length
218
- return line.strip()[:100]
219
-
220
- return None
221
-
222
- def _parse_tags(self, tags_string: str) -> List[str]:
223
- """Parse comma-separated tags into a list."""
224
- if not tags_string:
225
- return []
226
-
227
- # Split by comma and clean each tag
228
- tag_list = [tag.strip() for tag in tags_string.split(",")]
229
- # Remove any empty tags
230
- return [tag for tag in tag_list if tag]
231
-
232
- def _parse_flexible_date(
233
- self, date_str: str, date_format: Optional[str] = None
234
- ) -> str:
235
- """
236
- Try to parse dates in various formats and convert to YYYY-MM-DD.
237
-
238
- Args:
239
- date_str: The date string to parse
240
- date_format: Optional preferred format to try first
241
-
242
- Returns:
243
- Formatted date as string (YYYY-MM-DD by default)
244
- """
245
- if not date_str:
246
- return datetime.now().strftime("%Y-%m-%d")
247
-
248
- # If a specific format is provided, try it first
249
- if date_format:
250
- try:
251
- parsed_date = datetime.strptime(date_str, date_format)
252
- return parsed_date.strftime("%Y-%m-%d")
253
- except ValueError:
254
- # If it fails, continue with other formats
255
- pass
256
-
257
- # Common formats to try
258
- formats = [
259
- "%Y-%m-%d", # 2013-03-13
260
- "%d %B %Y", # 13 March 2013
261
- "%B %Y", # September 2013
262
- "%Y", # 1958
263
- "%d/%m/%Y", # 13/03/2013
264
- "%m/%d/%Y", # 03/13/2013
265
- "%d-%m-%Y", # 13-03-2013
266
- "%m-%d-%Y", # 03-13-2013
267
- "%Y/%m/%d", # 2013/03/13
268
- ]
269
-
270
- for fmt in formats:
271
- try:
272
- parsed_date = datetime.strptime(date_str, fmt)
273
- return parsed_date.strftime("%Y-%m-%d")
274
- except ValueError:
275
- continue
276
-
277
- # If no format matched, return the original string
278
- return date_str
279
-
280
- def _update_frontmatter(
281
- self,
282
- content: str,
283
- title: Optional[str] = None,
284
- author: Optional[str] = None,
285
- date: Optional[str] = None,
286
- description: Optional[str] = None,
287
- tags: Optional[str] = None,
288
- additional_fields: Optional[str] = None,
289
- date_format: Optional[str] = None,
290
- ) -> str:
291
- """Update existing frontmatter with new values."""
292
- # Check if frontmatter exists
293
- match = re.search(self.FRONTMATTER_PATTERN, content, re.DOTALL)
294
- if not match:
295
- # If no frontmatter exists, generate new one
296
- return self._generate_frontmatter(
297
- content,
298
- title,
299
- author,
300
- date,
301
- description,
302
- tags,
303
- additional_fields,
304
- date_format,
305
- )
306
-
307
- # Parse existing frontmatter
308
- yaml_content = match.group(1)
309
- try:
310
- frontmatter_dict = yaml.safe_load(yaml_content) or {}
311
- except yaml.YAMLError:
312
- frontmatter_dict = {}
313
-
314
- # Update with new values if provided
315
- if title:
316
- frontmatter_dict["title"] = title
317
- if author:
318
- frontmatter_dict["author"] = author
319
- if date:
320
- # Try to parse the date with the flexible parser
321
- frontmatter_dict["date"] = self._parse_flexible_date(date, date_format)
322
- if description:
323
- frontmatter_dict["description"] = description
324
- if tags:
325
- frontmatter_dict["tags"] = self._parse_tags(tags)
326
-
327
- # Add additional fields
328
- if additional_fields:
329
- additional_dict = self._parse_additional_fields(additional_fields)
330
- frontmatter_dict.update(additional_dict)
331
-
332
- # Generate new frontmatter
333
- new_frontmatter = yaml.dump(
334
- frontmatter_dict, sort_keys=False, default_flow_style=False
335
- )
336
- new_frontmatter = f"---\n{new_frontmatter}---\n\n"
337
-
338
- # Replace old frontmatter with new one
339
- return re.sub(
340
- self.FRONTMATTER_PATTERN, new_frontmatter, content, count=1, flags=re.DOTALL
341
- )
342
-
343
- def _generate_frontmatter(
344
- self,
345
- content: str,
346
- title: Optional[str] = None,
347
- author: Optional[str] = None,
348
- date: Optional[str] = None,
349
- description: Optional[str] = None,
350
- tags: Optional[str] = None,
351
- additional_fields: Optional[str] = None,
352
- date_format: Optional[str] = None,
353
- ) -> str:
354
- """Generate new frontmatter and prepend to content."""
355
- # Strip any existing frontmatter
356
- clean_content = (
357
- self._strip_frontmatter(content) if isinstance(content, str) else ""
358
- )
359
-
360
- # Build frontmatter dictionary
361
- frontmatter_dict = {}
362
-
363
- # Try to infer title if not provided
364
- if title:
365
- frontmatter_dict["title"] = title
366
- else:
367
- inferred_title = self._infer_title_from_content(clean_content)
368
- if inferred_title:
369
- frontmatter_dict["title"] = inferred_title
370
-
371
- # Add other fields if provided
372
- if author:
373
- frontmatter_dict["author"] = author
374
-
375
- # Process date with flexible parser
376
- if date:
377
- frontmatter_dict["date"] = self._parse_flexible_date(date, date_format)
378
- else:
379
- # Use current date with provided format or default
380
- format_to_use = date_format or "%Y-%m-%d"
381
- frontmatter_dict["date"] = datetime.now().strftime(format_to_use)
382
-
383
- if description:
384
- frontmatter_dict["description"] = description
385
-
386
- if tags:
387
- frontmatter_dict["tags"] = self._parse_tags(tags)
388
-
389
- # Add additional fields
390
- if additional_fields:
391
- additional_dict = self._parse_additional_fields(additional_fields)
392
- frontmatter_dict.update(additional_dict)
393
-
394
- # Generate YAML frontmatter
395
- frontmatter_yaml = yaml.dump(
396
- frontmatter_dict, sort_keys=False, default_flow_style=False
397
- )
398
- frontmatter = f"---\n{frontmatter_yaml}---\n\n"
399
-
400
- # Combine frontmatter with content
401
- return frontmatter + clean_content
402
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/gaia_scorer.py DELETED
@@ -1,124 +0,0 @@
1
- import re
2
- import string
3
- import warnings
4
-
5
-
6
- def normalize_number_str(number_str: str) -> float:
7
- # we replace these common units and commas to allow
8
- # conversion to float
9
- for char in ["$", "%", ","]:
10
- number_str = number_str.replace(char, "")
11
- try:
12
- return float(number_str)
13
- except ValueError:
14
- print(f"String {number_str} cannot be normalized to number str.")
15
- return float("inf")
16
-
17
-
18
- def split_string(
19
- s: str,
20
- char_list: list[str] = [",", ";"],
21
- ) -> list[str]:
22
- pattern = f"[{''.join(char_list)}]"
23
- return re.split(pattern, s)
24
-
25
-
26
- def is_float(element: any) -> bool:
27
- try:
28
- float(element)
29
- return True
30
- except ValueError:
31
- return False
32
-
33
-
34
- def question_scorer(
35
- model_answer: str,
36
- ground_truth: str,
37
- ) -> bool:
38
- # if gt is a number
39
- if is_float(ground_truth):
40
- normalized_answer = normalize_number_str(str(model_answer))
41
- return normalized_answer == float(ground_truth)
42
-
43
- # if gt is a list
44
- elif any(char in ground_truth for char in [",", ";"]):
45
- # question with the fish: normalization removes punct
46
-
47
- gt_elems = split_string(ground_truth)
48
- ma_elems = split_string(model_answer)
49
-
50
- # check length is the same
51
- if len(gt_elems) != len(ma_elems):
52
- warnings.warn("Answer lists have different lengths, returning False.", UserWarning)
53
- return False
54
-
55
- # compare each element as float or str
56
- comparisons = []
57
- for ma_elem, gt_elem in zip(ma_elems, gt_elems):
58
- if is_float(gt_elem):
59
- normalized_ma_elem = normalize_number_str(ma_elem)
60
- comparisons.append(normalized_ma_elem == float(gt_elem))
61
- else:
62
- # we do not remove punct since comparisons can include punct
63
- comparisons.append(
64
- normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)
65
- )
66
- return all(comparisons)
67
-
68
- # if gt is a str
69
- else:
70
- return normalize_str(model_answer) == normalize_str(ground_truth)
71
-
72
-
73
- def check_prediction_contains_answer_letters_in_order(prediction, true_answer):
74
- prediction = prediction.lower()
75
- true_answer = true_answer.lower()
76
- if len(prediction) > len(true_answer) * 3:
77
- return False
78
- i = 0
79
- for letter in true_answer:
80
- if letter in prediction[i:]:
81
- i += prediction[i:].index(letter)
82
- else:
83
- return False
84
- return True
85
-
86
-
87
- def check_close_call(prediction, true_answer, is_correct):
88
- if is_correct:
89
- return True
90
- else:
91
- if is_float(true_answer):
92
- return is_correct
93
- else:
94
- if (
95
- check_prediction_contains_answer_letters_in_order(str(prediction), str(true_answer))
96
- and len(str(true_answer)) * 0.5 <= len(str(prediction)) <= len(str(true_answer)) * 2
97
- ):
98
- print(f"Close call: {prediction} vs {true_answer}")
99
- return True
100
- else:
101
- return False
102
-
103
-
104
- def normalize_str(input_str, remove_punct=True) -> str:
105
- """
106
- Normalize a string by:
107
- - Removing all white spaces
108
- - Optionally removing punctuation (if remove_punct is True)
109
- - Converting to lowercase
110
- Parameters:
111
- - input_str: str, the string to normalize
112
- - remove_punct: bool, whether to remove punctuation (default: True)
113
- Returns:
114
- - str, the normalized string
115
- """
116
- # Remove all white spaces. Required e.g for seagull vs. sea gull
117
- no_spaces = re.sub(r"\s", "", input_str)
118
-
119
- # Remove punctuation, if specified.
120
- if remove_punct:
121
- translator = str.maketrans("", "", string.punctuation)
122
- return no_spaces.lower().translate(translator)
123
- else:
124
- return no_spaces.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/mdconvert.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # This is copied from Magentic-one's great repo: https://github.com/microsoft/autogen/blob/v0.4.4/python/packages/autogen-magentic-one/src/autogen_magentic_one/markdown_browser/mdconvert.py
2
  # Thanks to Microsoft researchers for open-sourcing this!
3
  # type: ignore
@@ -22,7 +24,6 @@ import pandas as pd
22
  import pdfminer
23
  import pdfminer.high_level
24
  import pptx
25
-
26
  # File-format detection
27
  import puremagic
28
  import pydub
@@ -86,7 +87,11 @@ class _CustomMarkdownify(markdownify.MarkdownConverter):
86
  if self.options["default_title"] and not title:
87
  title = href
88
  title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
89
- return "%s[%s](%s%s)%s" % (prefix, text, href, title_part, suffix) if href else text
 
 
 
 
90
 
91
  def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:
92
  """Same as usual converter, but removes data URIs"""
@@ -95,7 +100,10 @@ class _CustomMarkdownify(markdownify.MarkdownConverter):
95
  src = el.attrs.get("src", None) or ""
96
  title = el.attrs.get("title", None) or ""
97
  title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
98
- if convert_as_inline and el.parent.name not in self.options["keep_inline_images_in"]:
 
 
 
99
  return alt
100
 
101
  # Remove dataURIs
@@ -119,16 +127,22 @@ class DocumentConverterResult:
119
  class DocumentConverter:
120
  """Abstract superclass of all DocumentConverters."""
121
 
122
- def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
 
 
123
  raise NotImplementedError()
124
 
125
 
126
  class PlainTextConverter(DocumentConverter):
127
  """Anything with content type text/plain"""
128
 
129
- def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
 
 
130
  # Guess the content type from any file extension that might be around
131
- content_type, _ = mimetypes.guess_type("__placeholder" + kwargs.get("file_extension", ""))
 
 
132
 
133
  # Only accept text files
134
  if content_type is None:
@@ -148,7 +162,9 @@ class PlainTextConverter(DocumentConverter):
148
  class HtmlConverter(DocumentConverter):
149
  """Anything with content type text/html"""
150
 
151
- def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
 
 
152
  # Bail if not html
153
  extension = kwargs.get("file_extension", "")
154
  if extension.lower() not in [".html", ".htm"]:
@@ -181,14 +197,17 @@ class HtmlConverter(DocumentConverter):
181
  assert isinstance(webpage_text, str)
182
 
183
  return DocumentConverterResult(
184
- title=None if soup.title is None else soup.title.string, text_content=webpage_text
 
185
  )
186
 
187
 
188
  class WikipediaConverter(DocumentConverter):
189
  """Handle Wikipedia pages separately, focusing only on the main document content."""
190
 
191
- def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
 
 
192
  # Bail if not Wikipedia
193
  extension = kwargs.get("file_extension", "")
194
  if extension.lower() not in [".html", ".htm"]:
@@ -220,7 +239,9 @@ class WikipediaConverter(DocumentConverter):
220
  assert isinstance(main_title, str)
221
 
222
  # Convert the page
223
- webpage_text = f"# {main_title}\n\n" + _CustomMarkdownify().convert_soup(body_elm)
 
 
224
  else:
225
  webpage_text = _CustomMarkdownify().convert_soup(soup)
226
 
@@ -233,7 +254,9 @@ class WikipediaConverter(DocumentConverter):
233
  class YouTubeConverter(DocumentConverter):
234
  """Handle YouTube specially, focusing on the video title, description, and transcript."""
235
 
236
- def convert(self, local_path: str, **kwargs: Any) -> Union[None, DocumentConverterResult]:
 
 
237
  # Bail if not YouTube
238
  extension = kwargs.get("file_extension", "")
239
  if extension.lower() not in [".html", ".htm"]:
@@ -327,7 +350,12 @@ class YouTubeConverter(DocumentConverter):
327
  text_content=webpage_text,
328
  )
329
 
330
- def _get(self, metadata: Dict[str, str], keys: List[str], default: Union[str, None] = None) -> Union[str, None]:
 
 
 
 
 
331
  for k in keys:
332
  if k in metadata:
333
  return metadata[k]
@@ -444,7 +472,13 @@ class PptxConverter(HtmlConverter):
444
 
445
  # A placeholder name
446
  filename = re.sub(r"\W", "", shape.name) + ".jpg"
447
- md_content += "\n![" + (alt_text if alt_text else shape.name) + "](" + filename + ")\n"
 
 
 
 
 
 
448
 
449
  # Tables
450
  if self._is_table(shape):
@@ -460,7 +494,9 @@ class PptxConverter(HtmlConverter):
460
  html_table += "</tr>"
461
  first_row = False
462
  html_table += "</table></body></html>"
463
- md_content += "\n" + self._convert(html_table).text_content.strip() + "\n"
 
 
464
 
465
  # Text areas
466
  elif shape.has_text_frame:
@@ -508,7 +544,9 @@ class MediaConverter(DocumentConverter):
508
  return None
509
  else:
510
  try:
511
- result = subprocess.run([exiftool, "-json", local_path], capture_output=True, text=True).stdout
 
 
512
  return json.loads(result)[0]
513
  except Exception:
514
  return None
@@ -548,9 +586,13 @@ class WavConverter(MediaConverter):
548
  # Transcribe
549
  try:
550
  transcript = self._transcribe_audio(local_path)
551
- md_content += "\n\n### Audio Transcript:\n" + ("[No speech detected]" if transcript == "" else transcript)
 
 
552
  except Exception:
553
- md_content += "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
 
 
554
 
555
  return DocumentConverterResult(
556
  title=None,
@@ -612,7 +654,9 @@ class Mp3Converter(WavConverter):
612
  "[No speech detected]" if transcript == "" else transcript
613
  )
614
  except Exception:
615
- md_content += "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
 
 
616
 
617
  finally:
618
  os.unlink(temp_path)
@@ -662,7 +706,11 @@ class ImageConverter(MediaConverter):
662
  md_content += (
663
  "\n# Description:\n"
664
  + self._get_mlm_description(
665
- local_path, extension, mlm_client, mlm_model, prompt=kwargs.get("mlm_prompt")
 
 
 
 
666
  ).strip()
667
  + "\n"
668
  )
@@ -759,7 +807,11 @@ class MarkdownConverter:
759
 
760
  # Local path or url
761
  if isinstance(source, str):
762
- if source.startswith("http://") or source.startswith("https://") or source.startswith("file://"):
 
 
 
 
763
  return self.convert_url(source, **kwargs)
764
  else:
765
  return self.convert_local(source, **kwargs)
@@ -767,7 +819,9 @@ class MarkdownConverter:
767
  elif isinstance(source, requests.Response):
768
  return self.convert_response(source, **kwargs)
769
 
770
- def convert_local(self, path: str, **kwargs: Any) -> DocumentConverterResult: # TODO: deal with kwargs
 
 
771
  # Prepare a list of extensions to try (in order of priority)
772
  ext = kwargs.get("file_extension")
773
  extensions = [ext] if ext is not None else []
@@ -781,7 +835,9 @@ class MarkdownConverter:
781
  return self._convert(path, extensions, **kwargs)
782
 
783
  # TODO what should stream's type be?
784
- def convert_stream(self, stream: Any, **kwargs: Any) -> DocumentConverterResult: # TODO: deal with kwargs
 
 
785
  # Prepare a list of extensions to try (in order of priority)
786
  ext = kwargs.get("file_extension")
787
  extensions = [ext] if ext is not None else []
@@ -814,10 +870,14 @@ class MarkdownConverter:
814
 
815
  return result
816
 
817
- def convert_url(self, url: str, **kwargs: Any) -> DocumentConverterResult: # TODO: fix kwargs type
 
 
818
  # Send a HTTP request to the URL
819
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
820
- response = self._requests_session.get(url, stream=True, headers={"User-Agent": user_agent})
 
 
821
  response.raise_for_status()
822
  return self.convert_response(response, **kwargs)
823
 
@@ -871,7 +931,9 @@ class MarkdownConverter:
871
 
872
  return result
873
 
874
- def _convert(self, local_path: str, extensions: List[Union[str, None]], **kwargs) -> DocumentConverterResult:
 
 
875
  error_trace = ""
876
  for ext in extensions + [None]: # Try last with no extension
877
  for converter in self._page_converters:
@@ -899,7 +961,9 @@ class MarkdownConverter:
899
 
900
  if res is not None:
901
  # Normalize the content
902
- res.text_content = "\n".join([line.rstrip() for line in re.split(r"\r?\n", res.text_content)])
 
 
903
  res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)
904
 
905
  # Todo
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
  # This is copied from Magentic-one's great repo: https://github.com/microsoft/autogen/blob/v0.4.4/python/packages/autogen-magentic-one/src/autogen_magentic_one/markdown_browser/mdconvert.py
4
  # Thanks to Microsoft researchers for open-sourcing this!
5
  # type: ignore
 
24
  import pdfminer
25
  import pdfminer.high_level
26
  import pptx
 
27
  # File-format detection
28
  import puremagic
29
  import pydub
 
87
  if self.options["default_title"] and not title:
88
  title = href
89
  title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
90
+ return (
91
+ "%s[%s](%s%s)%s" % (prefix, text, href, title_part, suffix)
92
+ if href
93
+ else text
94
+ )
95
 
96
  def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:
97
  """Same as usual converter, but removes data URIs"""
 
100
  src = el.attrs.get("src", None) or ""
101
  title = el.attrs.get("title", None) or ""
102
  title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
103
+ if (
104
+ convert_as_inline
105
+ and el.parent.name not in self.options["keep_inline_images_in"]
106
+ ):
107
  return alt
108
 
109
  # Remove dataURIs
 
127
  class DocumentConverter:
128
  """Abstract superclass of all DocumentConverters."""
129
 
130
+ def convert(
131
+ self, local_path: str, **kwargs: Any
132
+ ) -> Union[None, DocumentConverterResult]:
133
  raise NotImplementedError()
134
 
135
 
136
  class PlainTextConverter(DocumentConverter):
137
  """Anything with content type text/plain"""
138
 
139
+ def convert(
140
+ self, local_path: str, **kwargs: Any
141
+ ) -> Union[None, DocumentConverterResult]:
142
  # Guess the content type from any file extension that might be around
143
+ content_type, _ = mimetypes.guess_type(
144
+ "__placeholder" + kwargs.get("file_extension", "")
145
+ )
146
 
147
  # Only accept text files
148
  if content_type is None:
 
162
  class HtmlConverter(DocumentConverter):
163
  """Anything with content type text/html"""
164
 
165
+ def convert(
166
+ self, local_path: str, **kwargs: Any
167
+ ) -> Union[None, DocumentConverterResult]:
168
  # Bail if not html
169
  extension = kwargs.get("file_extension", "")
170
  if extension.lower() not in [".html", ".htm"]:
 
197
  assert isinstance(webpage_text, str)
198
 
199
  return DocumentConverterResult(
200
+ title=None if soup.title is None else soup.title.string,
201
+ text_content=webpage_text,
202
  )
203
 
204
 
205
  class WikipediaConverter(DocumentConverter):
206
  """Handle Wikipedia pages separately, focusing only on the main document content."""
207
 
208
+ def convert(
209
+ self, local_path: str, **kwargs: Any
210
+ ) -> Union[None, DocumentConverterResult]:
211
  # Bail if not Wikipedia
212
  extension = kwargs.get("file_extension", "")
213
  if extension.lower() not in [".html", ".htm"]:
 
239
  assert isinstance(main_title, str)
240
 
241
  # Convert the page
242
+ webpage_text = f"# {main_title}\n\n" + _CustomMarkdownify().convert_soup(
243
+ body_elm
244
+ )
245
  else:
246
  webpage_text = _CustomMarkdownify().convert_soup(soup)
247
 
 
254
  class YouTubeConverter(DocumentConverter):
255
  """Handle YouTube specially, focusing on the video title, description, and transcript."""
256
 
257
+ def convert(
258
+ self, local_path: str, **kwargs: Any
259
+ ) -> Union[None, DocumentConverterResult]:
260
  # Bail if not YouTube
261
  extension = kwargs.get("file_extension", "")
262
  if extension.lower() not in [".html", ".htm"]:
 
350
  text_content=webpage_text,
351
  )
352
 
353
+ def _get(
354
+ self,
355
+ metadata: Dict[str, str],
356
+ keys: List[str],
357
+ default: Union[str, None] = None,
358
+ ) -> Union[str, None]:
359
  for k in keys:
360
  if k in metadata:
361
  return metadata[k]
 
472
 
473
  # A placeholder name
474
  filename = re.sub(r"\W", "", shape.name) + ".jpg"
475
+ md_content += (
476
+ "\n!["
477
+ + (alt_text if alt_text else shape.name)
478
+ + "]("
479
+ + filename
480
+ + ")\n"
481
+ )
482
 
483
  # Tables
484
  if self._is_table(shape):
 
494
  html_table += "</tr>"
495
  first_row = False
496
  html_table += "</table></body></html>"
497
+ md_content += (
498
+ "\n" + self._convert(html_table).text_content.strip() + "\n"
499
+ )
500
 
501
  # Text areas
502
  elif shape.has_text_frame:
 
544
  return None
545
  else:
546
  try:
547
+ result = subprocess.run(
548
+ [exiftool, "-json", local_path], capture_output=True, text=True
549
+ ).stdout
550
  return json.loads(result)[0]
551
  except Exception:
552
  return None
 
586
  # Transcribe
587
  try:
588
  transcript = self._transcribe_audio(local_path)
589
+ md_content += "\n\n### Audio Transcript:\n" + (
590
+ "[No speech detected]" if transcript == "" else transcript
591
+ )
592
  except Exception:
593
+ md_content += (
594
+ "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
595
+ )
596
 
597
  return DocumentConverterResult(
598
  title=None,
 
654
  "[No speech detected]" if transcript == "" else transcript
655
  )
656
  except Exception:
657
+ md_content += (
658
+ "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
659
+ )
660
 
661
  finally:
662
  os.unlink(temp_path)
 
706
  md_content += (
707
  "\n# Description:\n"
708
  + self._get_mlm_description(
709
+ local_path,
710
+ extension,
711
+ mlm_client,
712
+ mlm_model,
713
+ prompt=kwargs.get("mlm_prompt"),
714
  ).strip()
715
  + "\n"
716
  )
 
807
 
808
  # Local path or url
809
  if isinstance(source, str):
810
+ if (
811
+ source.startswith("http://")
812
+ or source.startswith("https://")
813
+ or source.startswith("file://")
814
+ ):
815
  return self.convert_url(source, **kwargs)
816
  else:
817
  return self.convert_local(source, **kwargs)
 
819
  elif isinstance(source, requests.Response):
820
  return self.convert_response(source, **kwargs)
821
 
822
+ def convert_local(
823
+ self, path: str, **kwargs: Any
824
+ ) -> DocumentConverterResult: # TODO: deal with kwargs
825
  # Prepare a list of extensions to try (in order of priority)
826
  ext = kwargs.get("file_extension")
827
  extensions = [ext] if ext is not None else []
 
835
  return self._convert(path, extensions, **kwargs)
836
 
837
  # TODO what should stream's type be?
838
+ def convert_stream(
839
+ self, stream: Any, **kwargs: Any
840
+ ) -> DocumentConverterResult: # TODO: deal with kwargs
841
  # Prepare a list of extensions to try (in order of priority)
842
  ext = kwargs.get("file_extension")
843
  extensions = [ext] if ext is not None else []
 
870
 
871
  return result
872
 
873
+ def convert_url(
874
+ self, url: str, **kwargs: Any
875
+ ) -> DocumentConverterResult: # TODO: fix kwargs type
876
  # Send a HTTP request to the URL
877
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
878
+ response = self._requests_session.get(
879
+ url, stream=True, headers={"User-Agent": user_agent}
880
+ )
881
  response.raise_for_status()
882
  return self.convert_response(response, **kwargs)
883
 
 
931
 
932
  return result
933
 
934
+ def _convert(
935
+ self, local_path: str, extensions: List[Union[str, None]], **kwargs
936
+ ) -> DocumentConverterResult:
937
  error_trace = ""
938
  for ext in extensions + [None]: # Try last with no extension
939
  for converter in self._page_converters:
 
961
 
962
  if res is not None:
963
  # Normalize the content
964
+ res.text_content = "\n".join(
965
+ [line.rstrip() for line in re.split(r"\r?\n", res.text_content)]
966
+ )
967
  res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)
968
 
969
  # Todo
scripts/reformulator.py DELETED
@@ -1,86 +0,0 @@
1
- # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
2
- # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
3
- import copy
4
-
5
- from smolagents.models import MessageRole, Model
6
-
7
-
8
- def prepare_response(original_task: str, inner_messages, reformulation_model: Model) -> str:
9
- messages = [
10
- {
11
- "role": MessageRole.SYSTEM,
12
- "content": [
13
- {
14
- "type": "text",
15
- "text": f"""Earlier you were asked the following:
16
-
17
- {original_task}
18
-
19
- Your team then worked diligently to address that request. Read below a transcript of that conversation:""",
20
- }
21
- ],
22
- }
23
- ]
24
-
25
- # The first message just repeats the question, so remove it
26
- # if len(inner_messages) > 1:
27
- # del inner_messages[0]
28
-
29
- # copy them to this context
30
- try:
31
- for message in inner_messages:
32
- if not message.get("content"):
33
- continue
34
- message = copy.deepcopy(message)
35
- message["role"] = MessageRole.USER
36
- messages.append(message)
37
- except Exception:
38
- messages += [{"role": MessageRole.ASSISTANT, "content": str(inner_messages)}]
39
-
40
- # ask for the final answer
41
- messages.append(
42
- {
43
- "role": MessageRole.USER,
44
- "content": [
45
- {
46
- "type": "text",
47
- "text": f"""
48
- Read the above conversation and output a FINAL ANSWER to the question. The question is repeated here for convenience:
49
-
50
- {original_task}
51
-
52
- To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
53
- Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
54
- ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
55
- If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and DO NOT INCLUDE UNITS such as $ or USD or percent signs unless specified otherwise.
56
- If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
57
- If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
58
- If you are unable to determine the final answer, output 'FINAL ANSWER: Unable to determine'
59
- """,
60
- }
61
- ],
62
- }
63
- )
64
-
65
- response = reformulation_model(messages).content
66
-
67
- final_answer = response.split("FINAL ANSWER: ")[-1].strip()
68
- print("> Reformulated answer: ", final_answer)
69
-
70
- # if "unable to determine" in final_answer.lower():
71
- # messages.append({"role": MessageRole.ASSISTANT, "content": response })
72
- # messages.append({"role": MessageRole.USER, "content": [{"type": "text", "text": """
73
- # I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation.
74
-
75
- # To output the educated guess, use the following template: EDUCATED GUESS: [YOUR EDUCATED GUESS]
76
- # Your EDUCATED GUESS should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. DO NOT OUTPUT 'I don't know', 'Unable to determine', etc.
77
- # ADDITIONALLY, your EDUCATED GUESS MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
78
- # If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
79
- # If you are asked for a string, don't use articles or abbreviations (e.g. cit for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
80
- # If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
81
- # """.strip()}]})
82
-
83
- # response = model(messages).content
84
- # print("\n>>>Making an educated guess.\n", response)
85
- # final_answer = response.split("EDUCATED GUESS: ")[-1].strip()
86
- return final_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_agents.py DELETED
@@ -1,87 +0,0 @@
1
- import json
2
- import os
3
- import shutil
4
- import textwrap
5
- from pathlib import Path
6
-
7
- # import tqdm.asyncio
8
- from smolagents.utils import AgentError
9
-
10
-
11
- def serialize_agent_error(obj):
12
- if isinstance(obj, AgentError):
13
- return {"error_type": obj.__class__.__name__, "message": obj.message}
14
- else:
15
- return str(obj)
16
-
17
-
18
- def get_image_description(file_name: str, question: str, visual_inspection_tool) -> str:
19
- prompt = f"""Write a caption of 5 sentences for this image. Pay special attention to any details that might be useful for someone answering the following question:
20
- {question}. But do not try to answer the question directly!
21
- Do not add any information that is not present in the image."""
22
- return visual_inspection_tool(image_path=file_name, question=prompt)
23
-
24
-
25
- def get_document_description(file_path: str, question: str, document_inspection_tool) -> str:
26
- prompt = f"""Write a caption of 5 sentences for this document. Pay special attention to any details that might be useful for someone answering the following question:
27
- {question}. But do not try to answer the question directly!
28
- Do not add any information that is not present in the document."""
29
- return document_inspection_tool.forward_initial_exam_mode(file_path=file_path, question=prompt)
30
-
31
-
32
- def get_single_file_description(file_path: str, question: str, visual_inspection_tool, document_inspection_tool):
33
- file_extension = file_path.split(".")[-1]
34
- if file_extension in ["png", "jpg", "jpeg"]:
35
- file_description = f" - Attached image: {file_path}"
36
- file_description += (
37
- f"\n -> Image description: {get_image_description(file_path, question, visual_inspection_tool)}"
38
- )
39
- return file_description
40
- elif file_extension in ["pdf", "xls", "xlsx", "docx", "doc", "xml"]:
41
- file_description = f" - Attached document: {file_path}"
42
- image_path = file_path.split(".")[0] + ".png"
43
- if os.path.exists(image_path):
44
- description = get_image_description(image_path, question, visual_inspection_tool)
45
- else:
46
- description = get_document_description(file_path, question, document_inspection_tool)
47
- file_description += f"\n -> File description: {description}"
48
- return file_description
49
- elif file_extension in ["mp3", "m4a", "wav"]:
50
- return f" - Attached audio: {file_path}"
51
- else:
52
- return f" - Attached file: {file_path}"
53
-
54
-
55
- def get_zip_description(file_path: str, question: str, visual_inspection_tool, document_inspection_tool):
56
- folder_path = file_path.replace(".zip", "")
57
- os.makedirs(folder_path, exist_ok=True)
58
- shutil.unpack_archive(file_path, folder_path)
59
-
60
- prompt_use_files = ""
61
- for root, dirs, files in os.walk(folder_path):
62
- for file in files:
63
- file_path = os.path.join(root, file)
64
- prompt_use_files += "\n" + textwrap.indent(
65
- get_single_file_description(file_path, question, visual_inspection_tool, document_inspection_tool),
66
- prefix=" ",
67
- )
68
- return prompt_use_files
69
-
70
-
71
- def get_tasks_to_run(data, total: int, base_filename: Path, tasks_ids: list[int]):
72
- f = base_filename.parent / f"{base_filename.stem}_answers.jsonl"
73
- done = set()
74
- if f.exists():
75
- with open(f, encoding="utf-8") as fh:
76
- done = {json.loads(line)["task_id"] for line in fh if line.strip()}
77
-
78
- tasks = []
79
- for i in range(total):
80
- task_id = int(data[i]["task_id"])
81
- if task_id not in done:
82
- if tasks_ids is not None:
83
- if task_id in tasks_ids:
84
- tasks.append(data[i])
85
- else:
86
- tasks.append(data[i])
87
- return tasks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/text_cleaner_tool.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  """
2
  Text cleaning tool for smolagents.
3
 
@@ -7,32 +10,26 @@ text content with handling for various text transformation options.
7
 
8
  # Standard library imports
9
  import logging
10
- from typing import Dict, Any, Optional
11
 
12
  # Third-party imports
 
13
  from smolagents import Tool
14
 
15
- # Try to import cleantext - handle gracefully if not installed
16
- try:
17
- from cleantext import clean
18
-
19
- CLEANTEXT_AVAILABLE = True
20
- except ImportError:
21
- CLEANTEXT_AVAILABLE = False
22
-
23
  # Configure module logger
24
  logger = logging.getLogger(__name__)
25
 
26
 
27
  # pylint: disable=too-few-public-methods
28
  class TextCleanerTool(Tool):
29
- """A simplified text cleaner tool that avoids typing issues."""
30
 
31
  name = "clean_text"
32
- description = (
33
- "Cleans and normalizes text using the cleantext library. "
34
- "Transforms messy user-generated content into normalized text."
35
- )
 
36
  inputs = {
37
  "text": {"type": "string", "description": "The input text to clean"},
38
  "options": {
@@ -76,7 +73,7 @@ class TextCleanerTool(Tool):
76
  `clean-text` uses ftfy, unidecode and numerous hand-crafted rules,
77
  i.e., RegEx.
78
 
79
- Example API:
80
  clean("some input",
81
  fix_unicode=True, # fix various unicode errors
82
  to_ascii=True, # transliterate to closest ASCII
@@ -110,14 +107,6 @@ class TextCleanerTool(Tool):
110
  logger.error("Failed to convert input to string: %s", e)
111
  return f"Error: Could not process input of type {type(text)}"
112
 
113
- # Check if cleantext is available
114
- if not CLEANTEXT_AVAILABLE:
115
- logger.error(
116
- "cleantext package not installed. "
117
- "Install with: pip install clean-text"
118
- )
119
- return "Error: Required dependency 'clean-text' is not installed."
120
-
121
  # Default replacement tokens
122
  replacements = {
123
  "replace_with_url": "<URL>",
@@ -159,3 +148,6 @@ class TextCleanerTool(Tool):
159
  except (ValueError, TypeError, AttributeError) as e:
160
  logger.error("Error cleaning text: %s", e)
161
  return f"Error during text cleaning: {str(e)}"
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The Footscray Coding Collective. All rights reserved.
4
  """
5
  Text cleaning tool for smolagents.
6
 
 
10
 
11
  # Standard library imports
12
  import logging
13
+ from typing import Any, Dict, Optional
14
 
15
  # Third-party imports
16
+ from cleantext import clean
17
  from smolagents import Tool
18
 
 
 
 
 
 
 
 
 
19
  # Configure module logger
20
  logger = logging.getLogger(__name__)
21
 
22
 
23
  # pylint: disable=too-few-public-methods
24
  class TextCleanerTool(Tool):
25
+ """A simple text cleaner tool."""
26
 
27
  name = "clean_text"
28
+ description = """This tool can be used to process messy user-generated content into
29
+ normalized text. It handles a variety of text transformation options,
30
+ such as fixing unicode errors, transliterating to closest ASCII,
31
+ lowercasing text, normalizing line breaks, removing punctuation,
32
+ replacing numbers with a token, and more."""
33
  inputs = {
34
  "text": {"type": "string", "description": "The input text to clean"},
35
  "options": {
 
73
  `clean-text` uses ftfy, unidecode and numerous hand-crafted rules,
74
  i.e., RegEx.
75
 
76
+ Usage of the cleantext API:
77
  clean("some input",
78
  fix_unicode=True, # fix various unicode errors
79
  to_ascii=True, # transliterate to closest ASCII
 
107
  logger.error("Failed to convert input to string: %s", e)
108
  return f"Error: Could not process input of type {type(text)}"
109
 
 
 
 
 
 
 
 
 
110
  # Default replacement tokens
111
  replacements = {
112
  "replace_with_url": "<URL>",
 
148
  except (ValueError, TypeError, AttributeError) as e:
149
  logger.error("Error cleaning text: %s", e)
150
  return f"Error during text cleaning: {str(e)}"
151
+
152
+
153
+ __all__ = ["TextCleanerTool"]
scripts/text_inspector_tool.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from typing import Optional
2
 
3
  from smolagents import Tool
@@ -7,10 +9,24 @@ from .mdconvert import MarkdownConverter
7
 
8
 
9
  class TextInspectorTool(Tool):
10
- name = "inspect_file_as_text"
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  description = """
12
- You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
13
- This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""
 
14
 
15
  inputs = {
16
  "file_path": {
@@ -27,15 +43,23 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
27
  md_converter = MarkdownConverter()
28
 
29
  def __init__(self, model: Model, text_limit: int):
 
 
 
30
  super().__init__()
31
  self.model = model
32
  self.text_limit = text_limit
33
 
34
  def forward_initial_exam_mode(self, file_path, question):
 
 
 
35
  result = self.md_converter.convert(file_path)
36
 
37
- if file_path[-4:] in [".png", ".jpg"]:
38
- raise Exception("Cannot use inspect_file_as_text tool with images: use visualizer instead!")
 
 
39
 
40
  if ".zip" in file_path:
41
  return result.text_content
@@ -73,11 +97,28 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
73
  ]
74
  return self.model(messages).content
75
 
76
- def forward(self, file_path, question: Optional[str] = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  result = self.md_converter.convert(file_path)
78
 
79
  if file_path[-4:] in [".png", ".jpg"]:
80
- raise Exception("Cannot use inspect_file_as_text tool with images: use visualizer instead!")
 
 
81
 
82
  if ".zip" in file_path:
83
  return result.text_content
@@ -120,3 +161,6 @@ This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pp
120
  },
121
  ]
122
  return self.model(messages).content
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
  from typing import Optional
4
 
5
  from smolagents import Tool
 
9
 
10
 
11
  class TextInspectorTool(Tool):
12
+ """
13
+ Tool for converting various file types to text and answering questions about their contents.
14
+
15
+ Supported file types include:
16
+ - Text documents (.txt, .md)
17
+ - Web documents (.html, .htm)
18
+ - Office documents (.docx, .xlsx, .pptx)
19
+ - Audio files (.wav, .mp3, .flac)
20
+ - PDF documents (.pdf)
21
+
22
+ Images are not supported and should be processed with a visualizer tool instead.
23
+ """
24
+
25
+ name = "view_file"
26
  description = """
27
+ You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
28
+ This tool handles the following file extensions: [".html", ".htm", ".md", ".txt", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.
29
+ """
30
 
31
  inputs = {
32
  "file_path": {
 
43
  md_converter = MarkdownConverter()
44
 
45
  def __init__(self, model: Model, text_limit: int):
46
+ """
47
+ Initialize the TextInspectorTool with a model to use for generating text and a limit for the amount of text to generate.
48
+ """
49
  super().__init__()
50
  self.model = model
51
  self.text_limit = text_limit
52
 
53
  def forward_initial_exam_mode(self, file_path, question):
54
+ """
55
+ This is used for generating code for the initial exam, and is not used for the final exam.
56
+ """
57
  result = self.md_converter.convert(file_path)
58
 
59
+ if file_path[-4:] in [".png", ".jpg", ".webp"]:
60
+ raise Exception(
61
+ "Cannot use inspect_file_as_text tool with images: use visualizer instead!"
62
+ )
63
 
64
  if ".zip" in file_path:
65
  return result.text_content
 
97
  ]
98
  return self.model(messages).content
99
 
100
+ def forward(self, file_path: str, question: Optional[str] = None) -> str:
101
+ """
102
+ Process a file and optionally answer a question about its contents.
103
+
104
+ Args:
105
+ file_path: Path to the file to be processed. Must be a supported file type.
106
+ question: Optional question to answer about the file contents.
107
+ If None, returns the raw file content.
108
+
109
+ Returns:
110
+ Either the raw file content if no question is provided, or the model's
111
+ response to the question based on the file contents.
112
+
113
+ Raises:
114
+ Exception: If the file is an image file or has an unsupported format.
115
+ """
116
  result = self.md_converter.convert(file_path)
117
 
118
  if file_path[-4:] in [".png", ".jpg"]:
119
+ raise Exception(
120
+ "Cannot use inspect_file_as_text tool with images: use visualizer instead!"
121
+ )
122
 
123
  if ".zip" in file_path:
124
  return result.text_content
 
161
  },
162
  ]
163
  return self.model(messages).content
164
+
165
+
166
+ __all__ = ["TextInspectorTool"]
scripts/text_web_browser.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
2
  # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
3
  import mimetypes
@@ -12,11 +15,11 @@ from urllib.parse import unquote, urljoin, urlparse
12
  import pathvalidate
13
  import requests
14
  from serpapi import GoogleSearch
15
-
16
  from smolagents import Tool
17
 
18
  from .cookies import COOKIES
19
- from .mdconvert import FileConversionException, MarkdownConverter, UnsupportedFormatException
 
20
 
21
 
22
  class SimpleTextBrowser:
@@ -45,7 +48,9 @@ class SimpleTextBrowser:
45
  self._page_content: str = ""
46
 
47
  self._find_on_page_query: Union[str, None] = None
48
- self._find_on_page_last_result: Union[int, None] = None # Location of the last result
 
 
49
 
50
  @property
51
  def address(self) -> str:
@@ -60,7 +65,9 @@ class SimpleTextBrowser:
60
  if uri_or_path == "about:blank":
61
  self._set_page_content("")
62
  elif uri_or_path.startswith("google:"):
63
- self._serpapi_search(uri_or_path[len("google:") :].strip(), filter_year=filter_year)
 
 
64
  else:
65
  if (
66
  not uri_or_path.startswith("http:")
@@ -97,7 +104,9 @@ class SimpleTextBrowser:
97
  self.viewport_current_page = len(self.viewport_pages) - 1
98
 
99
  def page_down(self) -> None:
100
- self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
 
 
101
 
102
  def page_up(self) -> None:
103
  self.viewport_current_page = max(self.viewport_current_page - 1, 0)
@@ -107,7 +116,10 @@ class SimpleTextBrowser:
107
 
108
  # Did we get here via a previous find_on_page search with the same query?
109
  # If so, map to find_next
110
- if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
 
 
 
111
  return self.find_next()
112
 
113
  # Ok it's a new search start from the current viewport
@@ -135,7 +147,9 @@ class SimpleTextBrowser:
135
  if starting_viewport >= len(self.viewport_pages):
136
  starting_viewport = 0
137
 
138
- viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
 
 
139
  if viewport_match is None:
140
  self._find_on_page_last_result = None
141
  return None
@@ -144,7 +158,9 @@ class SimpleTextBrowser:
144
  self._find_on_page_last_result = viewport_match
145
  return self.viewport
146
 
147
- def _find_next_viewport(self, query: str, starting_viewport: int) -> Union[int, None]:
 
 
148
  """Search for matches between the starting viewport looping when reaching the end."""
149
 
150
  if query is None:
@@ -153,7 +169,9 @@ class SimpleTextBrowser:
153
  # Normalize the query, and convert to a regular expression
154
  nquery = re.sub(r"\*", "__STAR__", query)
155
  nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
156
- nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word
 
 
157
  nquery = nquery.replace("__STAR__", ".*").lower()
158
 
159
  if nquery.strip() == "":
@@ -196,7 +214,9 @@ class SimpleTextBrowser:
196
  while start_idx < len(self._page_content):
197
  end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
198
  # Adjust to end on a space
199
- while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
 
 
200
  end_idx += 1
201
  self.viewport_pages.append((start_idx, end_idx))
202
  start_idx = end_idx
@@ -211,15 +231,21 @@ class SimpleTextBrowser:
211
  "api_key": self.serpapi_key,
212
  }
213
  if filter_year is not None:
214
- params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
 
 
215
 
216
  search = GoogleSearch(params)
217
  results = search.get_dict()
218
  self.page_title = f"{query} - Search"
219
  if "organic_results" not in results.keys():
220
- raise Exception(f"No results found for query: '{query}'. Use a less specific query.")
 
 
221
  if len(results["organic_results"]) == 0:
222
- year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
 
 
223
  self._set_page_content(
224
  f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
225
  )
@@ -250,7 +276,9 @@ class SimpleTextBrowser:
250
 
251
  redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{_prev_visit(page['link'])}{snippet}"
252
 
253
- redacted_version = redacted_version.replace("Your browser can't play this video.", "")
 
 
254
  web_snippets.append(redacted_version)
255
 
256
  content = (
@@ -270,7 +298,11 @@ class SimpleTextBrowser:
270
  self._set_page_content(res.text_content)
271
  else:
272
  # Prepare the request parameters
273
- request_kwargs = self.request_kwargs.copy() if self.request_kwargs is not None else {}
 
 
 
 
274
  request_kwargs["stream"] = True
275
 
276
  # Send a HTTP request to the URL
@@ -291,15 +323,21 @@ class SimpleTextBrowser:
291
  fname = None
292
  download_path = None
293
  try:
294
- fname = pathvalidate.sanitize_filename(os.path.basename(urlparse(url).path)).strip()
295
- download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))
 
 
 
 
296
 
297
  suffix = 0
298
  while os.path.exists(download_path) and suffix < 1000:
299
  suffix += 1
300
  base, ext = os.path.splitext(fname)
301
  new_fname = f"{base}__{suffix}{ext}"
302
- download_path = os.path.abspath(os.path.join(self.downloads_folder, new_fname))
 
 
303
 
304
  except NameError:
305
  pass
@@ -310,7 +348,9 @@ class SimpleTextBrowser:
310
  if extension is None:
311
  extension = ".download"
312
  fname = str(uuid.uuid4()) + extension
313
- download_path = os.path.abspath(os.path.join(self.downloads_folder, fname))
 
 
314
 
315
  # Open a file for writing
316
  with open(download_path, "wb") as fh:
@@ -324,11 +364,15 @@ class SimpleTextBrowser:
324
  except UnsupportedFormatException as e:
325
  print(e)
326
  self.page_title = ("Download complete.",)
327
- self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
 
 
328
  except FileConversionException as e:
329
  print(e)
330
  self.page_title = ("Download complete.",)
331
- self._set_page_content(f"# Download complete\n\nSaved file to '{download_path}'")
 
 
332
  except FileNotFoundError:
333
  self.page_title = "Error 404"
334
  self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")
@@ -341,10 +385,14 @@ class SimpleTextBrowser:
341
  if content_type is not None and "text/html" in content_type.lower():
342
  res = self._mdconvert.convert(response)
343
  self.page_title = f"Error {response.status_code}"
344
- self._set_page_content(f"## Error {response.status_code}\n\n{res.text_content}")
 
 
345
  else:
346
  text = ""
347
- for chunk in response.iter_content(chunk_size=512, decode_unicode=True):
 
 
348
  text += chunk
349
  self.page_title = f"Error {response.status_code}"
350
  self._set_page_content(f"## Error {response.status_code}\n\n{text}")
@@ -366,14 +414,18 @@ class SimpleTextBrowser:
366
  header += f"You previously visited this page {round(time.time() - self.history[i][1])} seconds ago.\n"
367
  break
368
 
369
- header += f"Viewport position: Showing page {current_page + 1} of {total_pages}.\n"
 
 
370
  return (header, self.viewport)
371
 
372
 
373
  class SearchInformationTool(Tool):
374
  name = "web_search"
375
  description = "Perform a web search query (think a google search) and returns the search results."
376
- inputs = {"query": {"type": "string", "description": "The web search query to perform."}}
 
 
377
  inputs["filter_year"] = {
378
  "type": "string",
379
  "description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!",
@@ -394,7 +446,12 @@ class SearchInformationTool(Tool):
394
  class VisitTool(Tool):
395
  name = "visit_page"
396
  description = "Visit a webpage at a given URL and return its text. Given a url to a YouTube video, this returns the transcript."
397
- inputs = {"url": {"type": "string", "description": "The relative or absolute url of the webapge to visit."}}
 
 
 
 
 
398
  output_type = "string"
399
 
400
  def __init__(self, browser):
@@ -413,7 +470,12 @@ class DownloadTool(Tool):
413
  Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
414
  After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
415
  DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
416
- inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}}
 
 
 
 
 
417
  output_type = "string"
418
 
419
  def __init__(self, browser):
@@ -435,7 +497,9 @@ DO NOT use this tool for .pdf or .txt or .htm files: for these types of files us
435
  f.write(response.content)
436
 
437
  if "pdf" in extension or "txt" in extension or "htm" in extension:
438
- raise Exception("Do not use this tool for pdf or txt or html files: use visit_page instead.")
 
 
439
 
440
  return f"File was downloaded and saved under path {new_path}."
441
 
@@ -461,15 +525,23 @@ class ArchiveSearchTool(Tool):
461
  archive_url = no_timestamp_url + f"&timestamp={date}"
462
  response = requests.get(archive_url).json()
463
  response_notimestamp = requests.get(no_timestamp_url).json()
464
- if "archived_snapshots" in response and "closest" in response["archived_snapshots"]:
 
 
 
465
  closest = response["archived_snapshots"]["closest"]
466
  print("Archive found!", closest)
467
 
468
- elif "archived_snapshots" in response_notimestamp and "closest" in response_notimestamp["archived_snapshots"]:
 
 
 
469
  closest = response_notimestamp["archived_snapshots"]["closest"]
470
  print("Archive found!", closest)
471
  else:
472
- raise Exception(f"Your {url=} was not archived on Wayback Machine, try a different url.")
 
 
473
  target_url = closest["url"]
474
  self.browser.visit_page(target_url)
475
  header, content = self.browser._state()
@@ -499,9 +571,7 @@ class PageUpTool(Tool):
499
 
500
  class PageDownTool(Tool):
501
  name = "page_down"
502
- description = (
503
- "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
504
- )
505
  inputs = {}
506
  output_type = "string"
507
 
@@ -558,6 +628,20 @@ class FindNextTool(Tool):
558
  header, content = self.browser._state()
559
 
560
  if find_result is None:
561
- return header.strip() + "\n=======================\nThe search string was not found on this page."
 
 
 
562
  else:
563
  return header.strip() + "\n=======================\n" + content
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # TODO: REMOVE REDUNDANT SERPAPI CODE AND IMPORT/EXTEND DEFAULT GoogleSearchTool FROM SMOLAGENTS
4
  # Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
5
  # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
6
  import mimetypes
 
15
  import pathvalidate
16
  import requests
17
  from serpapi import GoogleSearch
 
18
  from smolagents import Tool
19
 
20
  from .cookies import COOKIES
21
+ from .mdconvert import (FileConversionException, MarkdownConverter,
22
+ UnsupportedFormatException)
23
 
24
 
25
  class SimpleTextBrowser:
 
48
  self._page_content: str = ""
49
 
50
  self._find_on_page_query: Union[str, None] = None
51
+ self._find_on_page_last_result: Union[int, None] = (
52
+ None # Location of the last result
53
+ )
54
 
55
  @property
56
  def address(self) -> str:
 
65
  if uri_or_path == "about:blank":
66
  self._set_page_content("")
67
  elif uri_or_path.startswith("google:"):
68
+ self._serpapi_search(
69
+ uri_or_path[len("google:") :].strip(), filter_year=filter_year
70
+ )
71
  else:
72
  if (
73
  not uri_or_path.startswith("http:")
 
104
  self.viewport_current_page = len(self.viewport_pages) - 1
105
 
106
  def page_down(self) -> None:
107
+ self.viewport_current_page = min(
108
+ self.viewport_current_page + 1, len(self.viewport_pages) - 1
109
+ )
110
 
111
  def page_up(self) -> None:
112
  self.viewport_current_page = max(self.viewport_current_page - 1, 0)
 
116
 
117
  # Did we get here via a previous find_on_page search with the same query?
118
  # If so, map to find_next
119
+ if (
120
+ query == self._find_on_page_query
121
+ and self.viewport_current_page == self._find_on_page_last_result
122
+ ):
123
  return self.find_next()
124
 
125
  # Ok it's a new search start from the current viewport
 
147
  if starting_viewport >= len(self.viewport_pages):
148
  starting_viewport = 0
149
 
150
+ viewport_match = self._find_next_viewport(
151
+ self._find_on_page_query, starting_viewport
152
+ )
153
  if viewport_match is None:
154
  self._find_on_page_last_result = None
155
  return None
 
158
  self._find_on_page_last_result = viewport_match
159
  return self.viewport
160
 
161
+ def _find_next_viewport(
162
+ self, query: str, starting_viewport: int
163
+ ) -> Union[int, None]:
164
  """Search for matches between the starting viewport looping when reaching the end."""
165
 
166
  if query is None:
 
169
  # Normalize the query, and convert to a regular expression
170
  nquery = re.sub(r"\*", "__STAR__", query)
171
  nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
172
+ nquery = nquery.replace(
173
+ " __STAR__ ", "__STAR__ "
174
+ ) # Merge isolated stars with prior word
175
  nquery = nquery.replace("__STAR__", ".*").lower()
176
 
177
  if nquery.strip() == "":
 
214
  while start_idx < len(self._page_content):
215
  end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
216
  # Adjust to end on a space
217
+ while end_idx < len(self._page_content) and self._page_content[
218
+ end_idx - 1
219
+ ] not in [" ", "\t", "\r", "\n"]:
220
  end_idx += 1
221
  self.viewport_pages.append((start_idx, end_idx))
222
  start_idx = end_idx
 
231
  "api_key": self.serpapi_key,
232
  }
233
  if filter_year is not None:
234
+ params["tbs"] = (
235
+ f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
236
+ )
237
 
238
  search = GoogleSearch(params)
239
  results = search.get_dict()
240
  self.page_title = f"{query} - Search"
241
  if "organic_results" not in results.keys():
242
+ raise Exception(
243
+ f"No results found for query: '{query}'. Use a less specific query."
244
+ )
245
  if len(results["organic_results"]) == 0:
246
+ year_filter_message = (
247
+ f" with filter year={filter_year}" if filter_year is not None else ""
248
+ )
249
  self._set_page_content(
250
  f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
251
  )
 
276
 
277
  redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{_prev_visit(page['link'])}{snippet}"
278
 
279
+ redacted_version = redacted_version.replace(
280
+ "Your browser can't play this video.", ""
281
+ )
282
  web_snippets.append(redacted_version)
283
 
284
  content = (
 
298
  self._set_page_content(res.text_content)
299
  else:
300
  # Prepare the request parameters
301
+ request_kwargs = (
302
+ self.request_kwargs.copy()
303
+ if self.request_kwargs is not None
304
+ else {}
305
+ )
306
  request_kwargs["stream"] = True
307
 
308
  # Send a HTTP request to the URL
 
323
  fname = None
324
  download_path = None
325
  try:
326
+ fname = pathvalidate.sanitize_filename(
327
+ os.path.basename(urlparse(url).path)
328
+ ).strip()
329
+ download_path = os.path.abspath(
330
+ os.path.join(self.downloads_folder, fname)
331
+ )
332
 
333
  suffix = 0
334
  while os.path.exists(download_path) and suffix < 1000:
335
  suffix += 1
336
  base, ext = os.path.splitext(fname)
337
  new_fname = f"{base}__{suffix}{ext}"
338
+ download_path = os.path.abspath(
339
+ os.path.join(self.downloads_folder, new_fname)
340
+ )
341
 
342
  except NameError:
343
  pass
 
348
  if extension is None:
349
  extension = ".download"
350
  fname = str(uuid.uuid4()) + extension
351
+ download_path = os.path.abspath(
352
+ os.path.join(self.downloads_folder, fname)
353
+ )
354
 
355
  # Open a file for writing
356
  with open(download_path, "wb") as fh:
 
364
  except UnsupportedFormatException as e:
365
  print(e)
366
  self.page_title = ("Download complete.",)
367
+ self._set_page_content(
368
+ f"# Download complete\n\nSaved file to '{download_path}'"
369
+ )
370
  except FileConversionException as e:
371
  print(e)
372
  self.page_title = ("Download complete.",)
373
+ self._set_page_content(
374
+ f"# Download complete\n\nSaved file to '{download_path}'"
375
+ )
376
  except FileNotFoundError:
377
  self.page_title = "Error 404"
378
  self._set_page_content(f"## Error 404\n\nFile not found: {download_path}")
 
385
  if content_type is not None and "text/html" in content_type.lower():
386
  res = self._mdconvert.convert(response)
387
  self.page_title = f"Error {response.status_code}"
388
+ self._set_page_content(
389
+ f"## Error {response.status_code}\n\n{res.text_content}"
390
+ )
391
  else:
392
  text = ""
393
+ for chunk in response.iter_content(
394
+ chunk_size=512, decode_unicode=True
395
+ ):
396
  text += chunk
397
  self.page_title = f"Error {response.status_code}"
398
  self._set_page_content(f"## Error {response.status_code}\n\n{text}")
 
414
  header += f"You previously visited this page {round(time.time() - self.history[i][1])} seconds ago.\n"
415
  break
416
 
417
+ header += (
418
+ f"Viewport position: Showing page {current_page + 1} of {total_pages}.\n"
419
+ )
420
  return (header, self.viewport)
421
 
422
 
423
  class SearchInformationTool(Tool):
424
  name = "web_search"
425
  description = "Perform a web search query (think a google search) and returns the search results."
426
+ inputs = {
427
+ "query": {"type": "string", "description": "The web search query to perform."}
428
+ }
429
  inputs["filter_year"] = {
430
  "type": "string",
431
  "description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!",
 
446
  class VisitTool(Tool):
447
  name = "visit_page"
448
  description = "Visit a webpage at a given URL and return its text. Given a url to a YouTube video, this returns the transcript."
449
+ inputs = {
450
+ "url": {
451
+ "type": "string",
452
+ "description": "The relative or absolute url of the webapge to visit.",
453
+ }
454
+ }
455
  output_type = "string"
456
 
457
  def __init__(self, browser):
 
470
  Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
471
  After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
472
  DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
473
+ inputs = {
474
+ "url": {
475
+ "type": "string",
476
+ "description": "The relative or absolute url of the file to be downloaded.",
477
+ }
478
+ }
479
  output_type = "string"
480
 
481
  def __init__(self, browser):
 
497
  f.write(response.content)
498
 
499
  if "pdf" in extension or "txt" in extension or "htm" in extension:
500
+ raise Exception(
501
+ "Do not use this tool for pdf or txt or html files: use visit_page instead."
502
+ )
503
 
504
  return f"File was downloaded and saved under path {new_path}."
505
 
 
525
  archive_url = no_timestamp_url + f"&timestamp={date}"
526
  response = requests.get(archive_url).json()
527
  response_notimestamp = requests.get(no_timestamp_url).json()
528
+ if (
529
+ "archived_snapshots" in response
530
+ and "closest" in response["archived_snapshots"]
531
+ ):
532
  closest = response["archived_snapshots"]["closest"]
533
  print("Archive found!", closest)
534
 
535
+ elif (
536
+ "archived_snapshots" in response_notimestamp
537
+ and "closest" in response_notimestamp["archived_snapshots"]
538
+ ):
539
  closest = response_notimestamp["archived_snapshots"]["closest"]
540
  print("Archive found!", closest)
541
  else:
542
+ raise Exception(
543
+ f"Your {url=} was not archived on Wayback Machine, try a different url."
544
+ )
545
  target_url = closest["url"]
546
  self.browser.visit_page(target_url)
547
  header, content = self.browser._state()
 
571
 
572
  class PageDownTool(Tool):
573
  name = "page_down"
574
+ description = "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
 
 
575
  inputs = {}
576
  output_type = "string"
577
 
 
628
  header, content = self.browser._state()
629
 
630
  if find_result is None:
631
+ return (
632
+ header.strip()
633
+ + "\n=======================\nThe search string was not found on this page."
634
+ )
635
  else:
636
  return header.strip() + "\n=======================\n" + content
637
+
638
+
639
+ __all__ = [
640
+ "DownloadTool",
641
+ "VisitTool",
642
+ "PageUpTool",
643
+ "PageDownTool",
644
+ "FinderTool",
645
+ "FindNextTool",
646
+ "ArchiveSearchTool",
647
+ ]
scripts/time_tools.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The Footscray Coding Collective. All rights reserved.
4
+ from datetime import datetime
5
+ from typing import Optional
6
+
7
+ import pytz
8
+ from smolagents import tool
9
+
10
+
11
+ @tool
12
+ def get_temporal_context(
13
+ timezone_str: str = "US/Eastern", market: str = "US", date_str: Optional[str] = None
14
+ ) -> str:
15
+ """
16
+ Provides a concise overview of the current temporal context, including date, time, and market status.
17
+
18
+ Args:
19
+ timezone_str: The timezone to display time in (default: US/Eastern)
20
+ market: Market identifier (US, EU, ASIA) (default: US)
21
+ date_str: Date in YYYY-MM-DD format (optional, defaults to current date if not provided)
22
+
23
+ Returns:
24
+ A formatted string containing the current date, time, year, trading day status, and market hours status.
25
+ """
26
+
27
+ try:
28
+ # Get current time information using pytz
29
+ try:
30
+ tz = pytz.timezone(timezone_str)
31
+ except pytz.exceptions.UnknownTimeZoneError:
32
+ return f"Error: Unknown timezone '{timezone_str}'. Try using standard timezone names like 'US/Eastern'."
33
+
34
+ now = datetime.now(tz)
35
+ current_date = now.strftime("%Y-%m-%d")
36
+ current_time = now.strftime("%H:%M:%S")
37
+ current_year = now.year
38
+ weekday_name = now.strftime("%A")
39
+ time_info = f"""Current Time Information:
40
+ - Date: {current_date} ({weekday_name})
41
+ - Time: {current_time} ({timezone_str})
42
+ - Year: {current_year}
43
+ """
44
+
45
+ # Get Market hours Information
46
+ if market == "US":
47
+ # Convert time to US/Eastern for US market check
48
+ eastern_tz = pytz.timezone("US/Eastern")
49
+ eastern_now = now.astimezone(eastern_tz)
50
+
51
+ is_weekday_us = eastern_now.weekday() < 5
52
+ us_minutes = eastern_now.hour * 60 + eastern_now.minute
53
+ us_market_open = 9 * 60 + 30 # 9:30 AM ET
54
+ us_market_close = 16 * 60 # 4:00 PM ET
55
+
56
+ if is_weekday_us and us_market_open <= us_minutes < us_market_close:
57
+ market_status = "Open"
58
+ else:
59
+ market_status = "Closed"
60
+
61
+ market_hours_info = f"US Markets (NYSE, NASDAQ): {market_status}"
62
+
63
+ elif market == "EU":
64
+ # Convert time to London for EU market check
65
+ london_tz = pytz.timezone("Europe/London")
66
+ london_now = now.astimezone(london_tz)
67
+
68
+ is_weekday_eu = london_now.weekday() < 5
69
+ eu_minutes = london_now.hour * 60 + london_now.minute
70
+ eu_market_open = 8 * 60 # 8:00 AM London
71
+ eu_market_close = 16 * 60 + 30 # 4:30 PM London
72
+
73
+ if is_weekday_eu and eu_market_open <= eu_minutes < eu_market_close:
74
+ market_status = "Open"
75
+ else:
76
+ market_status = "Closed"
77
+
78
+ market_hours_info = f"European Markets (LSE, Euronext): {market_status}"
79
+
80
+ elif market == "ASIA":
81
+ # Convert time to Tokyo for Asian market check
82
+ tokyo_tz = pytz.timezone("Asia/Tokyo")
83
+ tokyo_now = now.astimezone(tokyo_tz)
84
+
85
+ is_weekday_tokyo = tokyo_now.weekday() < 5
86
+ tokyo_minutes = tokyo_now.hour * 60 + tokyo_now.minute
87
+ tokyo_morning_open = 9 * 60 # 9:00 AM Tokyo
88
+ tokyo_morning_close = 11 * 60 + 30 # 11:30 AM Tokyo
89
+ tokyo_afternoon_open = 12 * 60 + 30 # 12:30 PM Tokyo
90
+ tokyo_afternoon_close = 15 * 60 # 3:00 PM Tokyo
91
+
92
+ is_tokyo_session = (
93
+ tokyo_morning_open <= tokyo_minutes < tokyo_morning_close
94
+ ) or (tokyo_afternoon_open <= tokyo_minutes < tokyo_afternoon_close)
95
+
96
+ if is_weekday_tokyo and is_tokyo_session:
97
+ market_status = "Open"
98
+ else:
99
+ market_status = "Closed"
100
+
101
+ market_hours_info = (
102
+ "Asian Markets (Tokyo Stock Exchange, Shanghai Stock Exchange, "
103
+ f"Australian Securities Exchange): {market_status}"
104
+ )
105
+
106
+ else:
107
+ return f"Error: Invalid market '{market}'. Supported markets are 'US', 'EU', and 'ASIA'."
108
+
109
+ # Get Trading Day Information
110
+ if date_str:
111
+ try:
112
+ date_obj = datetime.strptime(date_str, "%Y-%m-%d")
113
+ # Apply timezone to date_obj
114
+ date_obj = tz.localize(date_obj)
115
+ except ValueError:
116
+ return (
117
+ f"Error: Invalid date format '{date_str}'. Use YYYY-MM-DD format."
118
+ )
119
+ else:
120
+ date_obj = now
121
+ date_str = now.strftime("%Y-%m-%d")
122
+
123
+ is_weekend = date_obj.weekday() > 4
124
+ trading_day = "No" if is_weekend else "Yes"
125
+ trading_info = f"Trading Day: {trading_day}"
126
+
127
+ # Combine all information
128
+ final_result = f"""{time_info}
129
+ {market_hours_info}
130
+ - {trading_info}
131
+ """
132
+
133
+ return final_result
134
+
135
+ except Exception as e:
136
+ return f"Error retrieving temporal context: {str(e)}"
137
+
138
+
139
+ __all__ = ["get_temporal_context"]
scripts/visual_qa.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import base64
2
  import json
3
  import mimetypes
@@ -10,10 +12,8 @@ import requests
10
  from dotenv import load_dotenv
11
  from huggingface_hub import InferenceClient
12
  from PIL import Image
13
- from transformers import AutoProcessor
14
-
15
  from smolagents import Tool, tool
16
-
17
 
18
  load_dotenv(override=True)
19
 
@@ -31,7 +31,9 @@ def process_images_and_text(image_path, query, client):
31
  },
32
  ]
33
 
34
- prompt_with_template = idefics_processor.apply_chat_template(messages, add_generation_prompt=True)
 
 
35
 
36
  # load images from local directory
37
 
@@ -42,7 +44,9 @@ def process_images_and_text(image_path, query, client):
42
 
43
  # Convert the image to a base64 string
44
  buffer = BytesIO()
45
- image.save(buffer, format="JPEG") # Use the appropriate format (e.g., JPEG, PNG)
 
 
46
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
47
 
48
  # add string formatting required by the endpoint
@@ -51,7 +55,9 @@ def process_images_and_text(image_path, query, client):
51
  return image_string
52
 
53
  image_string = encode_local_image(image_path)
54
- prompt_with_images = prompt_with_template.replace("<image>", "![]({}) ").format(image_string)
 
 
55
 
56
  payload = {
57
  "inputs": prompt_with_images,
@@ -95,7 +101,10 @@ def encode_image(image_path):
95
  return base64.b64encode(image_file.read()).decode("utf-8")
96
 
97
 
98
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"}
 
 
 
99
 
100
 
101
  def resize_image(image_path):
@@ -115,7 +124,11 @@ class VisualQATool(Tool):
115
  "description": "The path to the image on which to answer the question",
116
  "type": "string",
117
  },
118
- "question": {"description": "the question to answer", "type": "string", "nullable": True},
 
 
 
 
119
  }
120
  output_type = "string"
121
  # try use the same model with two different endpoints
@@ -136,9 +149,7 @@ class VisualQATool(Tool):
136
  output = process_images_and_text(new_image_path, question, self.client)
137
 
138
  if add_note:
139
- output = (
140
- f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
141
- )
142
 
143
  return output
144
 
@@ -156,7 +167,9 @@ def visualizer(image_path: str, question: Optional[str] = None) -> str:
156
  add_note = True
157
  question = "Please write a detailed caption for this image."
158
  if not isinstance(image_path, str):
159
- raise Exception("You should provide at least `image_path` string argument to this tool!")
 
 
160
 
161
  mime_type, _ = mimetypes.guess_type(image_path)
162
  base64_image = encode_image(image_path)
@@ -168,13 +181,18 @@ def visualizer(image_path: str, question: Optional[str] = None) -> str:
168
  "role": "user",
169
  "content": [
170
  {"type": "text", "text": "what is in this image" + question},
171
- {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}},
 
 
 
172
  ],
173
  }
174
  ],
175
  "max_tokens": 1000,
176
  }
177
- response = requests.post("https://openrouter.ai/api/v1", headers=headers, json=payload)
 
 
178
  try:
179
  output = response.json()["choices"][0]["message"]["content"]
180
  except Exception:
@@ -184,5 +202,5 @@ def visualizer(image_path: str, question: Optional[str] = None) -> str:
184
  output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
185
 
186
  # TO DO: write to yaml or chromadb -> HF Dataset in due course...
187
-
188
  return output
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
  import base64
4
  import json
5
  import mimetypes
 
12
  from dotenv import load_dotenv
13
  from huggingface_hub import InferenceClient
14
  from PIL import Image
 
 
15
  from smolagents import Tool, tool
16
+ from transformers import AutoProcessor
17
 
18
  load_dotenv(override=True)
19
 
 
31
  },
32
  ]
33
 
34
+ prompt_with_template = idefics_processor.apply_chat_template(
35
+ messages, add_generation_prompt=True
36
+ )
37
 
38
  # load images from local directory
39
 
 
44
 
45
  # Convert the image to a base64 string
46
  buffer = BytesIO()
47
+ image.save(
48
+ buffer, format="JPEG"
49
+ ) # Use the appropriate format (e.g., JPEG, PNG)
50
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
51
 
52
  # add string formatting required by the endpoint
 
55
  return image_string
56
 
57
  image_string = encode_local_image(image_path)
58
+ prompt_with_images = prompt_with_template.replace("<image>", "![]({}) ").format(
59
+ image_string
60
+ )
61
 
62
  payload = {
63
  "inputs": prompt_with_images,
 
101
  return base64.b64encode(image_file.read()).decode("utf-8")
102
 
103
 
104
# Module-level HTTP headers shared by the OpenAI-compatible requests below.
# NOTE(review): OPENAI_API_KEY is read once at import time — if the variable
# is unset, the Authorization header silently becomes "Bearer None". The
# module calls load_dotenv(override=True) earlier, but confirm the .env file
# is present in deployment before relying on this.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
}
108
 
109
 
110
  def resize_image(image_path):
 
124
  "description": "The path to the image on which to answer the question",
125
  "type": "string",
126
  },
127
+ "question": {
128
+ "description": "the question to answer",
129
+ "type": "string",
130
+ "nullable": True,
131
+ },
132
  }
133
  output_type = "string"
134
  # try use the same model with two different endpoints
 
149
  output = process_images_and_text(new_image_path, question, self.client)
150
 
151
  if add_note:
152
+ output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
 
 
153
 
154
  return output
155
 
 
167
  add_note = True
168
  question = "Please write a detailed caption for this image."
169
  if not isinstance(image_path, str):
170
+ raise Exception(
171
+ "You should provide at least `image_path` string argument to this tool!"
172
+ )
173
 
174
  mime_type, _ = mimetypes.guess_type(image_path)
175
  base64_image = encode_image(image_path)
 
181
  "role": "user",
182
  "content": [
183
  {"type": "text", "text": "what is in this image" + question},
184
+ {
185
+ "type": "image_url",
186
+ "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
187
+ },
188
  ],
189
  }
190
  ],
191
  "max_tokens": 1000,
192
  }
193
+ response = requests.post(
194
+ "https://openrouter.ai/api/v1", headers=headers, json=payload
195
+ )
196
  try:
197
  output = response.json()["choices"][0]["message"]["content"]
198
  except Exception:
 
202
  output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
203
 
204
  # TO DO: write to yaml or chromadb -> HF Dataset in due course...
205
+
206
  return output