Leonardo committed on
Commit
1a8421c
·
verified ·
1 Parent(s): 6e3a63d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -148
app.py CHANGED
@@ -2,8 +2,7 @@ import mimetypes
2
  import os
3
  import re
4
  import shutil
5
- import threading
6
- from typing import Optional, List, Dict, Any
7
 
8
  from dotenv import load_dotenv
9
  from huggingface_hub import login
@@ -20,7 +19,7 @@ from scripts.text_web_browser import (
20
  VisitTool,
21
  )
22
  from scripts.visual_qa import visualizer
23
- from scripts.legal_document_tool import LegalDocumentTool
24
  from smolagents import (
25
  CodeAgent,
26
  HfApiModel,
@@ -59,8 +58,6 @@ AUTHORIZED_IMPORTS = [
59
  "fractions",
60
  "csv",
61
  "clean-text",
62
- "langchain",
63
- "llama_index", # Fixed trailing comma
64
  ]
65
 
66
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
@@ -104,52 +101,31 @@ def setup_environment():
104
 
105
  # ------------------------ Model and Tool Management ------------------------
106
  class ModelManager:
107
- """Manages model loading and initialization with Zhou Protocol patterns."""
108
-
109
- _instance = None
110
- _lock = threading.Lock()
111
-
112
- @classmethod
113
- def get_instance(cls):
114
- """Thread-safe singleton access to model manager."""
115
- if cls._instance is None:
116
- with cls._lock:
117
- if cls._instance is None:
118
- cls._instance = cls()
119
- return cls._instance
120
-
121
- def __init__(self):
122
- """Initialize with model cache."""
123
- self.model_cache = {}
124
-
125
- def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
126
- """Load the specified model with appropriate configuration and caching."""
127
- cache_key = f"{chosen_inference}:{model_id}"
128
-
129
- # Return cached model if available
130
- if cache_key in self.model_cache:
131
- return self.model_cache[cache_key]
132
 
 
 
 
133
  try:
134
  if chosen_inference == "hf_api":
135
- model = HfApiModel(model_id=model_id)
136
 
137
  elif chosen_inference == "hf_api_provider":
138
- model = HfApiModel(provider="together")
139
 
140
  elif chosen_inference == "litellm":
141
- model = LiteLLMModel(model_id=model_id)
142
 
143
  elif chosen_inference == "openai":
144
  if not key_manager:
145
  raise ValueError("Key manager required for OpenAI model")
146
 
147
- model = OpenAIServerModel(
148
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
149
  )
150
 
151
  elif chosen_inference == "transformers":
152
- model = TransformersModel(
153
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
154
  device_map="auto",
155
  max_new_tokens=1000,
@@ -158,22 +134,13 @@ class ModelManager:
158
  else:
159
  raise ValueError(f"Invalid inference type: {chosen_inference}")
160
 
161
- # Cache the model for future use
162
- self.model_cache[cache_key] = model
163
- return model
164
-
165
  except Exception as e:
166
  print(f"✗ Couldn't load model: {e}")
167
  raise
168
 
169
 
170
  class ToolRegistry:
171
- """Manages tool initialization and organization with validation."""
172
-
173
- @staticmethod
174
- def validate_tools(tools: List[Tool]) -> List[Tool]:
175
- """Validate tools and filter out any None values."""
176
- return [tool for tool in tools if isinstance(tool, Tool)]
177
 
178
  @staticmethod
179
  def load_web_tools(model, browser, text_limit=20000):
@@ -200,93 +167,51 @@ class ToolRegistry:
200
  )
201
  except Exception as e:
202
  print(f"✗ Couldn't initialize image generation tool: {e}")
203
- return None
204
-
205
- @staticmethod
206
- def load_legal_document_tool():
207
- """Initialize and return the legal document processing tool."""
208
- try:
209
- # Create a simple instance with default parameters
210
- return LegalDocumentTool()
211
- except Exception as e:
212
- print(f"✗ Couldn't initialize legal document tool: {e}")
213
- # Return None instead of raising to make this tool optional
214
- return None
215
 
216
 
217
  # ------------------------ Agent Creation and Execution ------------------------
218
- class AgentFactory:
219
- """Factory for creating and managing agent instances with Zhou Protocol patterns."""
220
-
221
- _instance = None
222
- _lock = threading.Lock()
223
-
224
- @classmethod
225
- def get_instance(cls):
226
- """Thread-safe singleton access."""
227
- if cls._instance is None:
228
- with cls._lock:
229
- if cls._instance is None:
230
- cls._instance = cls()
231
- return cls._instance
232
-
233
- def __init__(self):
234
- """Initialize with agent cache."""
235
- self.agent_cache = {}
236
-
237
- def create_agent(self, session_id: str = "default") -> CodeAgent:
238
- """Creates a fresh agent instance with properly configured tools."""
239
- # Return cached agent if available for this session
240
- if session_id in self.agent_cache:
241
- return self.agent_cache[session_id]
242
-
243
- # Initialize model
244
- model = LiteLLMModel(
245
- custom_role_conversions=custom_role_conversions,
246
- model_id="openrouter/perplexity/r1-1776", # currently serving:
247
- ) # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model
248
-
249
- # Initialize tools
250
- text_limit = 30000
251
- browser = SimpleTextBrowser(**BROWSER_CONFIG)
252
-
253
- # Collect all tools in a single list
254
- web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
255
- image_generator = ToolRegistry.load_image_generation_tools()
256
- legal_tool = ToolRegistry.load_legal_document_tool()
257
-
258
- # Combine and validate all tools
259
- all_tools = [visualizer] + web_tools
260
-
261
- # Only add tools that are properly initialized (not None)
262
- if image_generator:
263
- all_tools.append(image_generator)
264
-
265
- if legal_tool:
266
- all_tools.append(legal_tool)
267
-
268
- # Final validation to ensure all tools are valid
269
- all_tools = ToolRegistry.validate_tools(all_tools)
270
-
271
- agent = CodeAgent(
272
- model=model,
273
- tools=all_tools, # Pass a single list containing all tools
274
- max_steps=10,
275
- verbosity_level=1,
276
- additional_authorized_imports=AUTHORIZED_IMPORTS,
277
- planning_interval=4,
278
- )
279
 
280
- # Cache the agent for future use
281
- self.agent_cache[session_id] = agent
282
- return agent
 
 
 
 
 
283
 
284
 
285
  def stream_to_gradio(
286
  agent,
287
  task: str,
288
  reset_agent_memory: bool = False,
289
- additional_args: Optional[Dict[str, Any]] = None,
290
  ):
291
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
292
  for step_log in agent.run(
@@ -322,50 +247,44 @@ def stream_to_gradio(
322
 
323
  # ------------------------ Gradio UI Components ------------------------
324
  class GradioUI:
325
- """A streamlined interface to launch your agent in Gradio with Zhou Protocol patterns."""
326
 
327
- def __init__(self, file_upload_folder: Optional[str] = None):
328
  """Initialize the Gradio UI with optional file upload functionality."""
329
  self.file_upload_folder = file_upload_folder
330
- self.agent_factory = AgentFactory.get_instance()
331
 
332
  if self.file_upload_folder is not None:
333
- os.makedirs(self.file_upload_folder, exist_ok=True)
 
334
 
335
  def interact_with_agent(self, prompt, messages, session_state):
336
  """Main interaction handler with the agent."""
337
 
338
- # Generate unique session ID if not present
339
- if "session_id" not in session_state:
340
- session_state["session_id"] = f"session_{id(session_state)}"
341
-
342
  # Get or create session-specific agent
343
- agent = self.agent_factory.create_agent(session_state["session_id"])
 
344
 
345
  # Adding monitoring
346
  try:
347
  # Log the existence of agent memory
348
- has_memory = hasattr(agent, "memory")
349
  print(f"Agent has memory: {has_memory}")
350
  if has_memory:
351
- print(f"Memory type: {type(agent.memory)}")
352
 
353
  messages.append(gr.ChatMessage(role="user", content=prompt))
354
  yield messages
355
 
356
- for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
 
 
357
  messages.append(msg)
358
  yield messages # Yield messages after each step
359
  yield messages # Yield messages one last time
360
 
361
  except Exception as e:
362
  print(f"Error in interaction: {str(e)}")
363
- messages.append(
364
- gr.ChatMessage(
365
- role="assistant", content=f"Error processing request: {str(e)}"
366
- )
367
- )
368
- yield messages
369
 
370
  def upload_file(
371
  self,
@@ -378,8 +297,6 @@ class GradioUI:
378
 
379
  try:
380
  mime_type, _ = mimetypes.guess_type(file.name)
381
- if not mime_type:
382
- return gr.Textbox("Unknown file type", visible=True), file_uploads_log
383
  except Exception as e:
384
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
385
 
@@ -427,7 +344,7 @@ class GradioUI:
427
  """Process user message and handle file references."""
428
  message = text_input
429
 
430
- if file_uploads_log and len(file_uploads_log) > 0:
431
  message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
432
 
433
  return (
@@ -443,7 +360,7 @@ class GradioUI:
443
  def detect_device(self, request: gr.Request):
444
  """Detect whether the user is on mobile or desktop device."""
445
  if not request:
446
- return "Desktop" # Default to desktop for safety
447
 
448
  # Method 1: Check sec-ch-ua-mobile header
449
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
@@ -492,7 +409,7 @@ class GradioUI:
492
  with gr.Sidebar():
493
  gr.Markdown(
494
  """#OpenDeepResearch - 3theSmolagents!
495
- Model_id: R1-1776"""
496
  )
497
  with gr.Group():
498
  gr.Markdown("**What's on your mind mate?**", container=True)
@@ -534,9 +451,8 @@ class GradioUI:
534
  # Add session state to store session-specific data
535
  session_state = gr.State({}) # Initialize empty state for each session
536
  stored_messages = gr.State([])
537
-
538
- # Ensure file_uploads_log is always defined
539
- file_uploads_log = gr.State([])
540
 
541
  chatbot = gr.Chatbot(
542
  label="open-Deep-Research",
 
2
  import os
3
  import re
4
  import shutil
5
+ from typing import Optional
 
6
 
7
  from dotenv import load_dotenv
8
  from huggingface_hub import login
 
19
  VisitTool,
20
  )
21
  from scripts.visual_qa import visualizer
22
+
23
  from smolagents import (
24
  CodeAgent,
25
  HfApiModel,
 
58
  "fractions",
59
  "csv",
60
  "clean-text",
 
 
61
  ]
62
 
63
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
 
101
 
102
  # ------------------------ Model and Tool Management ------------------------
103
  class ModelManager:
104
+ """Manages model loading and initialization."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ @staticmethod
107
+ def load_model(chosen_inference: str, model_id: str, key_manager=None):
108
+ """Load the specified model with appropriate configuration."""
109
  try:
110
  if chosen_inference == "hf_api":
111
+ return HfApiModel(model_id=model_id)
112
 
113
  elif chosen_inference == "hf_api_provider":
114
+ return HfApiModel(provider="together")
115
 
116
  elif chosen_inference == "litellm":
117
+ return LiteLLMModel(model_id=model_id)
118
 
119
  elif chosen_inference == "openai":
120
  if not key_manager:
121
  raise ValueError("Key manager required for OpenAI model")
122
 
123
+ return OpenAIServerModel(
124
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
125
  )
126
 
127
  elif chosen_inference == "transformers":
128
+ return TransformersModel(
129
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
130
  device_map="auto",
131
  max_new_tokens=1000,
 
134
  else:
135
  raise ValueError(f"Invalid inference type: {chosen_inference}")
136
 
 
 
 
 
137
  except Exception as e:
138
  print(f"✗ Couldn't load model: {e}")
139
  raise
140
 
141
 
142
  class ToolRegistry:
143
+ """Manages tool initialization and organization."""
 
 
 
 
 
144
 
145
  @staticmethod
146
  def load_web_tools(model, browser, text_limit=20000):
 
167
  )
168
  except Exception as e:
169
  print(f"✗ Couldn't initialize image generation tool: {e}")
170
+ raise
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  # ------------------------ Agent Creation and Execution ------------------------
174
+ def create_agent():
175
+ """Creates a fresh agent instance with properly configured tools."""
176
+ # Initialize model
177
+ model = LiteLLMModel(
178
+ custom_role_conversions=custom_role_conversions,
179
+ model_id="openrouter/google/gemini-2.0-flash-001", # currently serving:
180
+ ) # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model
181
+
182
+ # Initialize tools
183
+ text_limit = 30000
184
+ browser = SimpleTextBrowser(**BROWSER_CONFIG)
185
+
186
+ # Collect all tools in a single list
187
+ web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
188
+ image_generator = ToolRegistry.load_image_generation_tools()
189
+
190
+ # Combine all tools into a single list (not a tuple)
191
+ all_tools = [visualizer] + web_tools + [image_generator]
192
+
193
+ # Validate tools before creating agent
194
+ for tool in all_tools:
195
+ if not isinstance(tool, Tool):
196
+ raise ValueError(
197
+ f"Invalid tool type: {type(tool)}. All tools must be instances of Tool class."
198
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ return CodeAgent(
201
+ model=model,
202
+ tools=all_tools, # Pass a single list containing all tools
203
+ max_steps=10,
204
+ verbosity_level=1,
205
+ additional_authorized_imports=AUTHORIZED_IMPORTS,
206
+ planning_interval=4,
207
+ )
208
 
209
 
210
  def stream_to_gradio(
211
  agent,
212
  task: str,
213
  reset_agent_memory: bool = False,
214
+ additional_args: Optional[dict] = None,
215
  ):
216
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
217
  for step_log in agent.run(
 
247
 
248
  # ------------------------ Gradio UI Components ------------------------
249
  class GradioUI:
250
+ """A one-line interface to launch your agent in Gradio."""
251
 
252
+ def __init__(self, file_upload_folder: str | None = None):
253
  """Initialize the Gradio UI with optional file upload functionality."""
254
  self.file_upload_folder = file_upload_folder
 
255
 
256
  if self.file_upload_folder is not None:
257
+ if not os.path.exists(file_upload_folder):
258
+ os.mkdir(file_upload_folder)
259
 
260
  def interact_with_agent(self, prompt, messages, session_state):
261
  """Main interaction handler with the agent."""
262
 
 
 
 
 
263
  # Get or create session-specific agent
264
+ if "agent" not in session_state:
265
+ session_state["agent"] = create_agent()
266
 
267
  # Adding monitoring
268
  try:
269
  # Log the existence of agent memory
270
+ has_memory = hasattr(session_state["agent"], "memory")
271
  print(f"Agent has memory: {has_memory}")
272
  if has_memory:
273
+ print(f"Memory type: {type(session_state['agent'].memory)}")
274
 
275
  messages.append(gr.ChatMessage(role="user", content=prompt))
276
  yield messages
277
 
278
+ for msg in stream_to_gradio(
279
+ session_state["agent"], task=prompt, reset_agent_memory=False
280
+ ):
281
  messages.append(msg)
282
  yield messages # Yield messages after each step
283
  yield messages # Yield messages one last time
284
 
285
  except Exception as e:
286
  print(f"Error in interaction: {str(e)}")
287
+ raise
 
 
 
 
 
288
 
289
  def upload_file(
290
  self,
 
297
 
298
  try:
299
  mime_type, _ = mimetypes.guess_type(file.name)
 
 
300
  except Exception as e:
301
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
302
 
 
344
  """Process user message and handle file references."""
345
  message = text_input
346
 
347
+ if len(file_uploads_log) > 0:
348
  message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
349
 
350
  return (
 
360
  def detect_device(self, request: gr.Request):
361
  """Detect whether the user is on mobile or desktop device."""
362
  if not request:
363
+ return "Unknown device" # Handle case where request is none.
364
 
365
  # Method 1: Check sec-ch-ua-mobile header
366
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
 
409
  with gr.Sidebar():
410
  gr.Markdown(
411
  """#OpenDeepResearch - 3theSmolagents!
412
+ Model_id: google/gemini-2.0-flash-001"""
413
  )
414
  with gr.Group():
415
  gr.Markdown("**What's on your mind mate?**", container=True)
 
451
  # Add session state to store session-specific data
452
  session_state = gr.State({}) # Initialize empty state for each session
453
  stored_messages = gr.State([])
454
+ if "file_uploads_log" not in locals():
455
+ file_uploads_log = gr.State([])
 
456
 
457
  chatbot = gr.Chatbot(
458
  label="open-Deep-Research",