Leonardo committed on
Commit
0cd91f2
·
verified ·
1 Parent(s): d2dfe27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -64
app.py CHANGED
@@ -2,7 +2,8 @@ import mimetypes
2
  import os
3
  import re
4
  import shutil
5
- from typing import Optional
 
6
 
7
  from dotenv import load_dotenv
8
  from huggingface_hub import login
@@ -19,7 +20,7 @@ from scripts.text_web_browser import (
19
  VisitTool,
20
  )
21
  from scripts.visual_qa import visualizer
22
-
23
  from smolagents import (
24
  CodeAgent,
25
  HfApiModel,
@@ -58,6 +59,8 @@ AUTHORIZED_IMPORTS = [
58
  "fractions",
59
  "csv",
60
  "clean-text",
 
 
61
  ]
62
 
63
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
@@ -101,31 +104,52 @@ def setup_environment():
101
 
102
  # ------------------------ Model and Tool Management ------------------------
103
  class ModelManager:
104
- """Manages model loading and initialization."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- @staticmethod
107
- def load_model(chosen_inference: str, model_id: str, key_manager=None):
108
- """Load the specified model with appropriate configuration."""
109
  try:
110
  if chosen_inference == "hf_api":
111
- return HfApiModel(model_id=model_id)
112
 
113
  elif chosen_inference == "hf_api_provider":
114
- return HfApiModel(provider="together")
115
 
116
  elif chosen_inference == "litellm":
117
- return LiteLLMModel(model_id=model_id)
118
 
119
  elif chosen_inference == "openai":
120
  if not key_manager:
121
  raise ValueError("Key manager required for OpenAI model")
122
 
123
- return OpenAIServerModel(
124
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
125
  )
126
 
127
  elif chosen_inference == "transformers":
128
- return TransformersModel(
129
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
130
  device_map="auto",
131
  max_new_tokens=1000,
@@ -134,13 +158,22 @@ class ModelManager:
134
  else:
135
  raise ValueError(f"Invalid inference type: {chosen_inference}")
136
 
 
 
 
 
137
  except Exception as e:
138
  print(f"✗ Couldn't load model: {e}")
139
  raise
140
 
141
 
142
  class ToolRegistry:
143
- """Manages tool initialization and organization."""
 
 
 
 
 
144
 
145
  @staticmethod
146
  def load_web_tools(model, browser, text_limit=20000):
@@ -167,51 +200,93 @@ class ToolRegistry:
167
  )
168
  except Exception as e:
169
  print(f"✗ Couldn't initialize image generation tool: {e}")
170
- raise
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  # ------------------------ Agent Creation and Execution ------------------------
174
- def create_agent():
175
- """Creates a fresh agent instance with properly configured tools."""
176
- # Initialize model
177
- model = LiteLLMModel(
178
- custom_role_conversions=custom_role_conversions,
179
- model_id="openrouter/google/gemini-2.0-flash-001", # currently serving:
180
- ) # DEEPSEEK = openrouter/perplexity/r1-1776 <--- boss model
181
-
182
- # Initialize tools
183
- text_limit = 30000
184
- browser = SimpleTextBrowser(**BROWSER_CONFIG)
185
-
186
- # Collect all tools in a single list
187
- web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
188
- image_generator = ToolRegistry.load_image_generation_tools()
189
-
190
- # Combine all tools into a single list (not a tuple)
191
- all_tools = [visualizer] + web_tools + [image_generator]
192
-
193
- # Validate tools before creating agent
194
- for tool in all_tools:
195
- if not isinstance(tool, Tool):
196
- raise ValueError(
197
- f"Invalid tool type: {type(tool)}. All tools must be instances of Tool class."
198
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- return CodeAgent(
201
- model=model,
202
- tools=all_tools, # Pass a single list containing all tools
203
- max_steps=10,
204
- verbosity_level=1,
205
- additional_authorized_imports=AUTHORIZED_IMPORTS,
206
- planning_interval=4,
207
- )
208
 
209
 
210
  def stream_to_gradio(
211
  agent,
212
  task: str,
213
  reset_agent_memory: bool = False,
214
- additional_args: Optional[dict] = None,
215
  ):
216
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
217
  for step_log in agent.run(
@@ -247,44 +322,50 @@ def stream_to_gradio(
247
 
248
  # ------------------------ Gradio UI Components ------------------------
249
  class GradioUI:
250
- """A one-line interface to launch your agent in Gradio."""
251
 
252
- def __init__(self, file_upload_folder: str | None = None):
253
  """Initialize the Gradio UI with optional file upload functionality."""
254
  self.file_upload_folder = file_upload_folder
 
255
 
256
  if self.file_upload_folder is not None:
257
- if not os.path.exists(file_upload_folder):
258
- os.mkdir(file_upload_folder)
259
 
260
  def interact_with_agent(self, prompt, messages, session_state):
261
  """Main interaction handler with the agent."""
262
 
 
 
 
 
263
  # Get or create session-specific agent
264
- if "agent" not in session_state:
265
- session_state["agent"] = create_agent()
266
 
267
  # Adding monitoring
268
  try:
269
  # Log the existence of agent memory
270
- has_memory = hasattr(session_state["agent"], "memory")
271
  print(f"Agent has memory: {has_memory}")
272
  if has_memory:
273
- print(f"Memory type: {type(session_state['agent'].memory)}")
274
 
275
  messages.append(gr.ChatMessage(role="user", content=prompt))
276
  yield messages
277
 
278
- for msg in stream_to_gradio(
279
- session_state["agent"], task=prompt, reset_agent_memory=False
280
- ):
281
  messages.append(msg)
282
  yield messages # Yield messages after each step
283
  yield messages # Yield messages one last time
284
 
285
  except Exception as e:
286
  print(f"Error in interaction: {str(e)}")
287
- raise
 
 
 
 
 
288
 
289
  def upload_file(
290
  self,
@@ -297,6 +378,8 @@ class GradioUI:
297
 
298
  try:
299
  mime_type, _ = mimetypes.guess_type(file.name)
 
 
300
  except Exception as e:
301
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
302
 
@@ -344,7 +427,7 @@ class GradioUI:
344
  """Process user message and handle file references."""
345
  message = text_input
346
 
347
- if len(file_uploads_log) > 0:
348
  message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
349
 
350
  return (
@@ -360,7 +443,7 @@ class GradioUI:
360
  def detect_device(self, request: gr.Request):
361
  """Detect whether the user is on mobile or desktop device."""
362
  if not request:
363
- return "Unknown device" # Handle case where request is none.
364
 
365
  # Method 1: Check sec-ch-ua-mobile header
366
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
@@ -409,7 +492,7 @@ class GradioUI:
409
  with gr.Sidebar():
410
  gr.Markdown(
411
  """#OpenDeepResearch - 3theSmolagents!
412
- Model_id: google/gemini-2.0-flash-001"""
413
  )
414
  with gr.Group():
415
  gr.Markdown("**What's on your mind mate?**", container=True)
@@ -451,8 +534,9 @@ class GradioUI:
451
  # Add session state to store session-specific data
452
  session_state = gr.State({}) # Initialize empty state for each session
453
  stored_messages = gr.State([])
454
- if "file_uploads_log" not in locals():
455
- file_uploads_log = gr.State([])
 
456
 
457
  chatbot = gr.Chatbot(
458
  label="open-Deep-Research",
 
2
  import os
3
  import re
4
  import shutil
5
+ import threading
6
+ from typing import Optional, List, Dict, Any
7
 
8
  from dotenv import load_dotenv
9
  from huggingface_hub import login
 
20
  VisitTool,
21
  )
22
  from scripts.visual_qa import visualizer
23
+ from scripts.legal_document_tool import LegalDocumentTool
24
  from smolagents import (
25
  CodeAgent,
26
  HfApiModel,
 
59
  "fractions",
60
  "csv",
61
  "clean-text",
62
+ "langchain",
63
+ "llama_index", # Fixed trailing comma
64
  ]
65
 
66
  user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
 
104
 
105
  # ------------------------ Model and Tool Management ------------------------
106
  class ModelManager:
107
+ """Manages model loading and initialization with Zhou Protocol patterns."""
108
+
109
+ _instance = None
110
+ _lock = threading.Lock()
111
+
112
+ @classmethod
113
+ def get_instance(cls):
114
+ """Thread-safe singleton access to model manager."""
115
+ if cls._instance is None:
116
+ with cls._lock:
117
+ if cls._instance is None:
118
+ cls._instance = cls()
119
+ return cls._instance
120
+
121
+ def __init__(self):
122
+ """Initialize with model cache."""
123
+ self.model_cache = {}
124
+
125
+ def load_model(self, chosen_inference: str, model_id: str, key_manager=None):
126
+ """Load the specified model with appropriate configuration and caching."""
127
+ cache_key = f"{chosen_inference}:{model_id}"
128
+
129
+ # Return cached model if available
130
+ if cache_key in self.model_cache:
131
+ return self.model_cache[cache_key]
132
 
 
 
 
133
  try:
134
  if chosen_inference == "hf_api":
135
+ model = HfApiModel(model_id=model_id)
136
 
137
  elif chosen_inference == "hf_api_provider":
138
+ model = HfApiModel(provider="together")
139
 
140
  elif chosen_inference == "litellm":
141
+ model = LiteLLMModel(model_id=model_id)
142
 
143
  elif chosen_inference == "openai":
144
  if not key_manager:
145
  raise ValueError("Key manager required for OpenAI model")
146
 
147
+ model = OpenAIServerModel(
148
  model_id=model_id, api_key=key_manager.get_key("openai_api_key")
149
  )
150
 
151
  elif chosen_inference == "transformers":
152
+ model = TransformersModel(
153
  model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
154
  device_map="auto",
155
  max_new_tokens=1000,
 
158
  else:
159
  raise ValueError(f"Invalid inference type: {chosen_inference}")
160
 
161
+ # Cache the model for future use
162
+ self.model_cache[cache_key] = model
163
+ return model
164
+
165
  except Exception as e:
166
  print(f"✗ Couldn't load model: {e}")
167
  raise
168
 
169
 
170
  class ToolRegistry:
171
+ """Manages tool initialization and organization with validation."""
172
+
173
@staticmethod
def validate_tools(tools: List[Tool]) -> List[Tool]:
    """Return the entries of ``tools`` that are ``Tool`` instances.

    Order is preserved. Every entry that is not a ``Tool`` is dropped,
    ``None`` included. Dropped entries other than ``None`` are reported on
    stdout so that a misconfigured tool does not disappear silently.

    Args:
        tools: Candidate tools; may contain ``None`` or other objects.

    Returns:
        A new list holding only the ``Tool`` instances from ``tools``.
    """
    valid_tools = []
    for tool in tools:
        if isinstance(tool, Tool):
            valid_tools.append(tool)
        elif tool is not None:
            # None is the "optional tool unavailable" marker and is skipped
            # quietly; anything else is unexpected and worth surfacing.
            print(f"✗ Skipping invalid tool of type {type(tool).__name__}")
    return valid_tools
177
 
178
  @staticmethod
179
  def load_web_tools(model, browser, text_limit=20000):
 
200
  )
201
  except Exception as e:
202
  print(f"✗ Couldn't initialize image generation tool: {e}")
203
+ return None
204
+
205
@staticmethod
def load_legal_document_tool():
    """Build the legal document tool, or return None if it cannot be built.

    The tool is constructed with its default parameters. Any exception
    raised during construction is printed and swallowed, which makes this
    tool optional for the caller.
    """
    try:
        legal_tool = LegalDocumentTool()
    except Exception as e:
        print(f"✗ Couldn't initialize legal document tool: {e}")
        # Signal "unavailable" rather than propagating the failure.
        legal_tool = None
    return legal_tool
215
 
216
 
217
  # ------------------------ Agent Creation and Execution ------------------------
218
class AgentFactory:
    """Singleton factory that builds one CodeAgent per session and reuses it.

    Agents are kept in ``agent_cache`` keyed by session id, so repeated
    requests from the same session get the same agent object back.
    """

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared factory, constructing it on first use.

        ``cls._instance`` is checked once without the lock and again while
        holding ``cls._lock`` before the instance is created.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Start with an empty session-id -> agent cache."""
        self.agent_cache = {}

    def create_agent(self, session_id: str = "default") -> CodeAgent:
        """Return the agent for ``session_id``, building it on first request.

        Args:
            session_id: Key identifying the session; callers that pass the
                same id receive the same agent.

        Returns:
            The cached agent for the session, or a newly built one that has
            just been stored in the cache.
        """
        cached_agent = self.agent_cache.get(session_id)
        if cached_agent is not None:
            return cached_agent

        # Initialize model
        model = LiteLLMModel(
            custom_role_conversions=custom_role_conversions,
            model_id="openrouter/perplexity/r1-1776",  # currently serving
        )

        # Initialize tools
        text_limit = 30000
        browser = SimpleTextBrowser(**BROWSER_CONFIG)

        web_tools = ToolRegistry.load_web_tools(model, browser, text_limit)
        image_generator = ToolRegistry.load_image_generation_tools()
        legal_tool = ToolRegistry.load_legal_document_tool()

        all_tools = [visualizer, *web_tools]

        # The optional loaders return None when their tool is unavailable;
        # compare against None explicitly instead of relying on truthiness.
        for optional_tool in (image_generator, legal_tool):
            if optional_tool is not None:
                all_tools.append(optional_tool)

        # Final validation to ensure all tools are valid
        all_tools = ToolRegistry.validate_tools(all_tools)

        agent = CodeAgent(
            model=model,
            tools=all_tools,
            max_steps=10,
            verbosity_level=1,
            additional_authorized_imports=AUTHORIZED_IMPORTS,
            planning_interval=4,
        )

        # setdefault keeps whichever agent was stored first for this session,
        # so if another request cached one while this one was being built,
        # both callers end up with the same object.
        return self.agent_cache.setdefault(session_id, agent)
 
 
 
 
 
283
 
284
 
285
  def stream_to_gradio(
286
  agent,
287
  task: str,
288
  reset_agent_memory: bool = False,
289
+ additional_args: Optional[Dict[str, Any]] = None,
290
  ):
291
  """Runs an agent with the given task and streams messages as Gradio ChatMessages."""
292
  for step_log in agent.run(
 
322
 
323
  # ------------------------ Gradio UI Components ------------------------
324
  class GradioUI:
325
+ """A streamlined interface to launch your agent in Gradio with Zhou Protocol patterns."""
326
 
327
def __init__(self, file_upload_folder: Optional[str] = None):
    """Set up the UI, optionally preparing a folder for uploaded files.

    Args:
        file_upload_folder: Directory for uploads. When given, it is
            created if it does not exist yet; when None, nothing is created.
    """
    self.file_upload_folder = file_upload_folder
    self.agent_factory = AgentFactory.get_instance()

    upload_dir = self.file_upload_folder
    if upload_dir is None:
        return
    os.makedirs(upload_dir, exist_ok=True)
 
334
 
335
  def interact_with_agent(self, prompt, messages, session_state):
336
  """Main interaction handler with the agent."""
337
 
338
+ # Generate unique session ID if not present
339
+ if "session_id" not in session_state:
340
+ session_state["session_id"] = f"session_{id(session_state)}"
341
+
342
  # Get or create session-specific agent
343
+ agent = self.agent_factory.create_agent(session_state["session_id"])
 
344
 
345
  # Adding monitoring
346
  try:
347
  # Log the existence of agent memory
348
+ has_memory = hasattr(agent, "memory")
349
  print(f"Agent has memory: {has_memory}")
350
  if has_memory:
351
+ print(f"Memory type: {type(agent.memory)}")
352
 
353
  messages.append(gr.ChatMessage(role="user", content=prompt))
354
  yield messages
355
 
356
+ for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
 
 
357
  messages.append(msg)
358
  yield messages # Yield messages after each step
359
  yield messages # Yield messages one last time
360
 
361
  except Exception as e:
362
  print(f"Error in interaction: {str(e)}")
363
+ messages.append(
364
+ gr.ChatMessage(
365
+ role="assistant", content=f"Error processing request: {str(e)}"
366
+ )
367
+ )
368
+ yield messages
369
 
370
  def upload_file(
371
  self,
 
378
 
379
  try:
380
  mime_type, _ = mimetypes.guess_type(file.name)
381
+ if not mime_type:
382
+ return gr.Textbox("Unknown file type", visible=True), file_uploads_log
383
  except Exception as e:
384
  return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
385
 
 
427
  """Process user message and handle file references."""
428
  message = text_input
429
 
430
+ if file_uploads_log and len(file_uploads_log) > 0:
431
  message += f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" # Added file list
432
 
433
  return (
 
443
  def detect_device(self, request: gr.Request):
444
  """Detect whether the user is on mobile or desktop device."""
445
  if not request:
446
+ return "Desktop" # Default to desktop for safety
447
 
448
  # Method 1: Check sec-ch-ua-mobile header
449
  is_mobile_header = request.headers.get("sec-ch-ua-mobile")
 
492
  with gr.Sidebar():
493
  gr.Markdown(
494
  """#OpenDeepResearch - 3theSmolagents!
495
+ Model_id: R1-1776"""
496
  )
497
  with gr.Group():
498
  gr.Markdown("**What's on your mind mate?**", container=True)
 
534
  # Add session state to store session-specific data
535
  session_state = gr.State({}) # Initialize empty state for each session
536
  stored_messages = gr.State([])
537
+
538
+ # Ensure file_uploads_log is always defined
539
+ file_uploads_log = gr.State([])
540
 
541
  chatbot = gr.Chatbot(
542
  label="open-Deep-Research",