Pulastya B committed on
Commit
e29cf28
·
1 Parent(s): a8b0cca

Fixed slow Agent loading which caused the Agent to take too long to respond

Browse files
Files changed (2) hide show
  1. src/api/app.py +56 -51
  2. src/orchestrator.py +4 -3
src/api/app.py CHANGED
@@ -28,6 +28,7 @@ import numpy as np
28
  # Import from parent package
29
  from src.orchestrator import DataScienceCopilot
30
  from src.progress_manager import progress_manager
 
31
 
32
  # Configure logging
33
  logging.basicConfig(level=logging.INFO)
@@ -78,10 +79,6 @@ app.add_middleware(
78
  allow_headers=["*"],
79
  )
80
 
81
- # Initialize agent once (singleton pattern for stateless service)
82
- # Agent itself is stateless - no conversation memory between requests
83
- agent: Optional[DataScienceCopilot] = None
84
-
85
  # SSE event queues for real-time streaming
86
  class ProgressEventManager:
87
  """Manages SSE connections and progress events for real-time updates."""
@@ -151,69 +148,80 @@ class ProgressEventManager:
151
  if session_id in self.session_status:
152
  del self.session_status[session_id]
153
 
154
- # πŸ‘₯ MULTI-USER SUPPORT: Per-session agent instances
155
- # Instead of one global agent, create isolated instances per session
156
- # This prevents users from interfering with each other's workflows
157
- agent_cache: Dict[str, DataScienceCopilot] = {} # session_id -> agent instance
158
  agent_cache_lock = asyncio.Lock()
159
- MAX_CACHED_AGENTS = 10 # Limit memory usage
160
- logger.info("πŸ‘₯ Multi-user agent cache initialized")
161
 
162
- # Legacy global agent for backward compatibility (will be deprecated)
 
 
163
  agent = None
164
 
165
- # πŸ‘₯ MULTI-USER SUPPORT: Per-session agent instances
166
- # Instead of one global agent, create isolated instances per session
167
- # This prevents users from interfering with each other's workflows
168
- agent_cache: Dict[str, DataScienceCopilot] = {} # session_id -> agent instance
169
- agent_cache_lock = asyncio.Lock()
170
- MAX_CACHED_AGENTS = 10 # Limit memory usage
171
- logger.info("πŸ‘₯ Multi-user agent cache initialized")
172
-
173
- # Legacy global agent for backward compatibility (will be deprecated)
174
- agent = None
175
 
176
 
177
  async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
178
  """
179
- Get or create an isolated agent instance for a session.
180
 
181
- This ensures each user gets their own agent with isolated state,
182
- preventing session collisions and race conditions.
 
 
183
 
184
  Args:
185
  session_id: Unique session identifier
186
 
187
  Returns:
188
- DataScienceCopilot instance for this session
189
  """
 
 
190
  async with agent_cache_lock:
191
- # Return existing agent if cached
192
- if session_id in agent_cache:
193
- logger.info(f"[♻️] Reusing cached agent for session {session_id[:8]}...")
194
- return agent_cache[session_id]
 
 
 
 
 
195
 
196
- # Create new agent instance
197
- logger.info(f"[πŸ†•] Creating new agent for session {session_id[:8]}...")
198
- provider = os.getenv("LLM_PROVIDER", "mistral")
 
 
 
199
 
200
- new_agent = DataScienceCopilot(
201
- reasoning_effort="medium",
202
- provider=provider,
203
- use_compact_prompts=False, # Multi-agent architecture
204
- session_id=session_id # Pass session_id for isolation
205
- )
206
 
 
207
  # Cache management: Remove oldest if cache is full
208
- if len(agent_cache) >= MAX_CACHED_AGENTS:
209
- oldest_session = next(iter(agent_cache))
210
  logger.info(f"[πŸ—‘οΈ] Cache full, removing session {oldest_session[:8]}...")
211
- del agent_cache[oldest_session]
 
 
 
 
 
 
212
 
213
- agent_cache[session_id] = new_agent
214
- logger.info(f"βœ… Agent created for session {session_id[:8]} (cache: {len(agent_cache)}/{MAX_CACHED_AGENTS})")
215
 
216
- return new_agent
217
 
218
  # πŸ”’ REQUEST QUEUING: Global lock to prevent concurrent workflows
219
  # This ensures only one analysis runs at a time, preventing:
@@ -460,15 +468,12 @@ async def run_analysis_async(
460
 
461
  logger.info(f"[ASYNC] File saved: {file.filename}")
462
  else:
463
- # πŸ›‘οΈ VALIDATION: For follow-up queries, check if any cached agent has dataset
464
- # Note: In true multi-user setup, you'd need session_id from frontend to match exact session
465
  has_dataset = False
466
  async with agent_cache_lock:
467
- for cached_agent in agent_cache.values():
468
- if hasattr(cached_agent, 'session') and cached_agent.session and cached_agent.session.last_dataset:
469
- has_dataset = True
470
- logger.info(f"[ASYNC] Follow-up query using cached session data")
471
- break
472
 
473
  if not has_dataset:
474
  logger.warning("[ASYNC] No file uploaded and no session dataset available")
 
28
  # Import from parent package
29
  from src.orchestrator import DataScienceCopilot
30
  from src.progress_manager import progress_manager
31
+ from src.session_memory import SessionMemory
32
 
33
  # Configure logging
34
  logging.basicConfig(level=logging.INFO)
 
79
  allow_headers=["*"],
80
  )
81
 
 
 
 
 
82
  # SSE event queues for real-time streaming
83
  class ProgressEventManager:
84
  """Manages SSE connections and progress events for real-time updates."""
 
148
  if session_id in self.session_status:
149
  del self.session_status[session_id]
150
 
151
+ # πŸ‘₯ MULTI-USER SUPPORT: Session state isolation
152
+ # Heavy components (SBERT, tools, LLM client) are shared via global 'agent'
153
+ # Only session memory is isolated per user for fast initialization
154
+ session_states: Dict[str, Any] = {} # session_id -> SessionMemory
155
  agent_cache_lock = asyncio.Lock()
156
+ MAX_CACHED_AGENTS = 10 # Limit memory usage (session states are lightweight)
157
+ logger.info("πŸ‘₯ Multi-user session isolation initialized (fast mode)")
158
 
159
+ # Global agent - Heavy components loaded ONCE at startup
160
+ # SBERT model, tool functions, LLM client are shared across all users
161
+ agent: Optional[DataScienceCopilot] = None
162
  agent = None
163
 
164
+ # Session state isolation (lightweight - just session memory)
165
+ session_states: Dict[str, any] = {} # session_id -> session memory only
 
 
 
 
 
 
 
 
166
 
167
 
168
  async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
169
  """
170
+ Get agent with isolated session state.
171
 
172
+ OPTIMIZATION: Instead of creating a full new agent per session (slow!),
173
+ we reuse the global agent but swap session memory per request.
174
+ Heavy components (SBERT, tools, LLM client) are shared.
175
+ This reduces per-user initialization from 20s to <1s.
176
 
177
  Args:
178
  session_id: Unique session identifier
179
 
180
  Returns:
181
+ DataScienceCopilot instance with isolated session for this user
182
  """
183
+ global agent
184
+
185
  async with agent_cache_lock:
186
+ # Ensure base agent exists (heavy components loaded once at startup)
187
+ if agent is None:
188
+ logger.warning("Base agent not initialized - this shouldn't happen after startup")
189
+ provider = os.getenv("LLM_PROVIDER", "mistral")
190
+ agent = DataScienceCopilot(
191
+ reasoning_effort="medium",
192
+ provider=provider,
193
+ use_compact_prompts=False
194
+ )
195
 
196
+ # Check if we have cached session memory for this session
197
+ if session_id in session_states:
198
+ logger.info(f"[♻️] Reusing session state for {session_id[:8]}...")
199
+ agent.session = session_states[session_id]
200
+ agent.http_session_key = session_id
201
+ return agent
202
 
203
+ # πŸš€ FAST PATH: Create new session memory only (no SBERT reload!)
204
+ logger.info(f"[πŸ†•] Creating lightweight session for {session_id[:8]}...")
205
+
206
+ # Create isolated session memory for this user
207
+ new_session = SessionMemory(session_id=session_id)
 
208
 
209
+ # Cache session memory (lightweight)
210
  # Cache management: Remove oldest if cache is full
211
+ if len(session_states) >= MAX_CACHED_AGENTS:
212
+ oldest_session = next(iter(session_states))
213
  logger.info(f"[πŸ—‘οΈ] Cache full, removing session {oldest_session[:8]}...")
214
+ del session_states[oldest_session]
215
+
216
+ session_states[session_id] = new_session
217
+
218
+ # Set session on shared agent
219
+ agent.session = new_session
220
+ agent.http_session_key = session_id
221
 
222
+ logger.info(f"βœ… Session created for {session_id[:8]} (cache: {len(session_states)}/{MAX_CACHED_AGENTS}) - <1s init")
 
223
 
224
+ return agent
225
 
226
  # πŸ”’ REQUEST QUEUING: Global lock to prevent concurrent workflows
227
  # This ensures only one analysis runs at a time, preventing:
 
468
 
469
  logger.info(f"[ASYNC] File saved: {file.filename}")
470
  else:
471
+ # πŸ›‘οΈ VALIDATION: Check if agent's current session has dataset
 
472
  has_dataset = False
473
  async with agent_cache_lock:
474
+ if agent and hasattr(agent, 'session') and agent.session and hasattr(agent.session, 'last_dataset') and agent.session.last_dataset:
475
+ has_dataset = True
476
+ logger.info(f"[ASYNC] Follow-up query using session data")
 
 
477
 
478
  if not has_dataset:
479
  logger.warning("[ASYNC] No file uploaded and no session dataset available")
src/orchestrator.py CHANGED
@@ -269,8 +269,9 @@ class DataScienceCopilot:
269
  max_context = provider_max_tokens.get(self.provider, 128000)
270
  self.token_manager = get_token_manager(model=self.model, max_tokens=max_context)
271
 
272
- # ⚑ Initialize parallel executor
273
- self.parallel_executor = get_parallel_executor()
 
274
 
275
  # 🧠 Initialize session memory
276
  self.use_session_memory = use_session_memory
@@ -3438,7 +3439,7 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
3438
  print(f" These will run SEQUENTIALLY to prevent resource exhaustion")
3439
  print(f" Heavy tools: {', '.join(heavy_tools)}")
3440
  # Fall through to sequential execution
3441
- elif len(tool_executions) > 1 and len(heavy_tools) <= 1:
3442
  try:
3443
  results = asyncio.run(self.parallel_executor.execute_all(
3444
  tool_executions=tool_executions,
 
269
  max_context = provider_max_tokens.get(self.provider, 128000)
270
  self.token_manager = get_token_manager(model=self.model, max_tokens=max_context)
271
 
272
+ # ⚑ Parallel executor DISABLED - running tools sequentially for stability
273
+ # self.parallel_executor = get_parallel_executor()
274
+ self.parallel_executor = None # Disabled for scale optimization
275
 
276
  # 🧠 Initialize session memory
277
  self.use_session_memory = use_session_memory
 
3439
  print(f" These will run SEQUENTIALLY to prevent resource exhaustion")
3440
  print(f" Heavy tools: {', '.join(heavy_tools)}")
3441
  # Fall through to sequential execution
3442
+ elif len(tool_executions) > 1 and len(heavy_tools) <= 1 and self.parallel_executor is not None:
3443
  try:
3444
  results = asyncio.run(self.parallel_executor.execute_all(
3445
  tool_executions=tool_executions,