Spaces:
Running
Running
Pulastya B commited on
Commit Β·
e29cf28
1
Parent(s): a8b0cca
Fixed slow Agent loading which caused the Agent took too long to respond
Browse files- src/api/app.py +56 -51
- src/orchestrator.py +4 -3
src/api/app.py
CHANGED
|
@@ -28,6 +28,7 @@ import numpy as np
|
|
| 28 |
# Import from parent package
|
| 29 |
from src.orchestrator import DataScienceCopilot
|
| 30 |
from src.progress_manager import progress_manager
|
|
|
|
| 31 |
|
| 32 |
# Configure logging
|
| 33 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -78,10 +79,6 @@ app.add_middleware(
|
|
| 78 |
allow_headers=["*"],
|
| 79 |
)
|
| 80 |
|
| 81 |
-
# Initialize agent once (singleton pattern for stateless service)
|
| 82 |
-
# Agent itself is stateless - no conversation memory between requests
|
| 83 |
-
agent: Optional[DataScienceCopilot] = None
|
| 84 |
-
|
| 85 |
# SSE event queues for real-time streaming
|
| 86 |
class ProgressEventManager:
|
| 87 |
"""Manages SSE connections and progress events for real-time updates."""
|
|
@@ -151,69 +148,80 @@ class ProgressEventManager:
|
|
| 151 |
if session_id in self.session_status:
|
| 152 |
del self.session_status[session_id]
|
| 153 |
|
| 154 |
-
# π₯ MULTI-USER SUPPORT:
|
| 155 |
-
#
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
agent_cache_lock = asyncio.Lock()
|
| 159 |
-
MAX_CACHED_AGENTS = 10 # Limit memory usage
|
| 160 |
-
logger.info("π₯ Multi-user
|
| 161 |
|
| 162 |
-
#
|
|
|
|
|
|
|
| 163 |
agent = None
|
| 164 |
|
| 165 |
-
#
|
| 166 |
-
|
| 167 |
-
# This prevents users from interfering with each other's workflows
|
| 168 |
-
agent_cache: Dict[str, DataScienceCopilot] = {} # session_id -> agent instance
|
| 169 |
-
agent_cache_lock = asyncio.Lock()
|
| 170 |
-
MAX_CACHED_AGENTS = 10 # Limit memory usage
|
| 171 |
-
logger.info("π₯ Multi-user agent cache initialized")
|
| 172 |
-
|
| 173 |
-
# Legacy global agent for backward compatibility (will be deprecated)
|
| 174 |
-
agent = None
|
| 175 |
|
| 176 |
|
| 177 |
async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
|
| 178 |
"""
|
| 179 |
-
Get
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
|
| 184 |
Args:
|
| 185 |
session_id: Unique session identifier
|
| 186 |
|
| 187 |
Returns:
|
| 188 |
-
DataScienceCopilot instance for this
|
| 189 |
"""
|
|
|
|
|
|
|
| 190 |
async with agent_cache_lock:
|
| 191 |
-
#
|
| 192 |
-
if
|
| 193 |
-
logger.
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
#
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
)
|
| 206 |
|
|
|
|
| 207 |
# Cache management: Remove oldest if cache is full
|
| 208 |
-
if len(
|
| 209 |
-
oldest_session = next(iter(
|
| 210 |
logger.info(f"[ποΈ] Cache full, removing session {oldest_session[:8]}...")
|
| 211 |
-
del
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
logger.info(f"β
Agent created for session {session_id[:8]} (cache: {len(agent_cache)}/{MAX_CACHED_AGENTS})")
|
| 215 |
|
| 216 |
-
return
|
| 217 |
|
| 218 |
# π REQUEST QUEUING: Global lock to prevent concurrent workflows
|
| 219 |
# This ensures only one analysis runs at a time, preventing:
|
|
@@ -460,15 +468,12 @@ async def run_analysis_async(
|
|
| 460 |
|
| 461 |
logger.info(f"[ASYNC] File saved: {file.filename}")
|
| 462 |
else:
|
| 463 |
-
# π‘οΈ VALIDATION:
|
| 464 |
-
# Note: In true multi-user setup, you'd need session_id from frontend to match exact session
|
| 465 |
has_dataset = False
|
| 466 |
async with agent_cache_lock:
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
logger.info(f"[ASYNC] Follow-up query using cached session data")
|
| 471 |
-
break
|
| 472 |
|
| 473 |
if not has_dataset:
|
| 474 |
logger.warning("[ASYNC] No file uploaded and no session dataset available")
|
|
|
|
| 28 |
# Import from parent package
|
| 29 |
from src.orchestrator import DataScienceCopilot
|
| 30 |
from src.progress_manager import progress_manager
|
| 31 |
+
from src.session_memory import SessionMemory
|
| 32 |
|
| 33 |
# Configure logging
|
| 34 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 79 |
allow_headers=["*"],
|
| 80 |
)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
# SSE event queues for real-time streaming
|
| 83 |
class ProgressEventManager:
|
| 84 |
"""Manages SSE connections and progress events for real-time updates."""
|
|
|
|
| 148 |
if session_id in self.session_status:
|
| 149 |
del self.session_status[session_id]
|
| 150 |
|
| 151 |
+
# π₯ MULTI-USER SUPPORT: Session state isolation
|
| 152 |
+
# Heavy components (SBERT, tools, LLM client) are shared via global 'agent'
|
| 153 |
+
# Only session memory is isolated per user for fast initialization
|
| 154 |
+
session_states: Dict[str, Any] = {} # session_id -> SessionMemory
|
| 155 |
agent_cache_lock = asyncio.Lock()
|
| 156 |
+
MAX_CACHED_AGENTS = 10 # Limit memory usage (session states are lightweight)
|
| 157 |
+
logger.info("π₯ Multi-user session isolation initialized (fast mode)")
|
| 158 |
|
| 159 |
+
# Global agent - Heavy components loaded ONCE at startup
|
| 160 |
+
# SBERT model, tool functions, LLM client are shared across all users
|
| 161 |
+
agent: Optional[DataScienceCopilot] = None
|
| 162 |
agent = None
|
| 163 |
|
| 164 |
+
# Session state isolation (lightweight - just session memory)
|
| 165 |
+
session_states: Dict[str, any] = {} # session_id -> session memory only
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
|
| 168 |
async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
|
| 169 |
"""
|
| 170 |
+
Get agent with isolated session state.
|
| 171 |
|
| 172 |
+
OPTIMIZATION: Instead of creating a full new agent per session (slow!),
|
| 173 |
+
we reuse the global agent but swap session memory per request.
|
| 174 |
+
Heavy components (SBERT, tools, LLM client) are shared.
|
| 175 |
+
This reduces per-user initialization from 20s to <1s.
|
| 176 |
|
| 177 |
Args:
|
| 178 |
session_id: Unique session identifier
|
| 179 |
|
| 180 |
Returns:
|
| 181 |
+
DataScienceCopilot instance with isolated session for this user
|
| 182 |
"""
|
| 183 |
+
global agent
|
| 184 |
+
|
| 185 |
async with agent_cache_lock:
|
| 186 |
+
# Ensure base agent exists (heavy components loaded once at startup)
|
| 187 |
+
if agent is None:
|
| 188 |
+
logger.warning("Base agent not initialized - this shouldn't happen after startup")
|
| 189 |
+
provider = os.getenv("LLM_PROVIDER", "mistral")
|
| 190 |
+
agent = DataScienceCopilot(
|
| 191 |
+
reasoning_effort="medium",
|
| 192 |
+
provider=provider,
|
| 193 |
+
use_compact_prompts=False
|
| 194 |
+
)
|
| 195 |
|
| 196 |
+
# Check if we have cached session memory for this session
|
| 197 |
+
if session_id in session_states:
|
| 198 |
+
logger.info(f"[β»οΈ] Reusing session state for {session_id[:8]}...")
|
| 199 |
+
agent.session = session_states[session_id]
|
| 200 |
+
agent.http_session_key = session_id
|
| 201 |
+
return agent
|
| 202 |
|
| 203 |
+
# π FAST PATH: Create new session memory only (no SBERT reload!)
|
| 204 |
+
logger.info(f"[π] Creating lightweight session for {session_id[:8]}...")
|
| 205 |
+
|
| 206 |
+
# Create isolated session memory for this user
|
| 207 |
+
new_session = SessionMemory(session_id=session_id)
|
|
|
|
| 208 |
|
| 209 |
+
# Cache session memory (lightweight)
|
| 210 |
# Cache management: Remove oldest if cache is full
|
| 211 |
+
if len(session_states) >= MAX_CACHED_AGENTS:
|
| 212 |
+
oldest_session = next(iter(session_states))
|
| 213 |
logger.info(f"[ποΈ] Cache full, removing session {oldest_session[:8]}...")
|
| 214 |
+
del session_states[oldest_session]
|
| 215 |
+
|
| 216 |
+
session_states[session_id] = new_session
|
| 217 |
+
|
| 218 |
+
# Set session on shared agent
|
| 219 |
+
agent.session = new_session
|
| 220 |
+
agent.http_session_key = session_id
|
| 221 |
|
| 222 |
+
logger.info(f"β
Session created for {session_id[:8]} (cache: {len(session_states)}/{MAX_CACHED_AGENTS}) - <1s init")
|
|
|
|
| 223 |
|
| 224 |
+
return agent
|
| 225 |
|
| 226 |
# π REQUEST QUEUING: Global lock to prevent concurrent workflows
|
| 227 |
# This ensures only one analysis runs at a time, preventing:
|
|
|
|
| 468 |
|
| 469 |
logger.info(f"[ASYNC] File saved: {file.filename}")
|
| 470 |
else:
|
| 471 |
+
# π‘οΈ VALIDATION: Check if agent's current session has dataset
|
|
|
|
| 472 |
has_dataset = False
|
| 473 |
async with agent_cache_lock:
|
| 474 |
+
if agent and hasattr(agent, 'session') and agent.session and hasattr(agent.session, 'last_dataset') and agent.session.last_dataset:
|
| 475 |
+
has_dataset = True
|
| 476 |
+
logger.info(f"[ASYNC] Follow-up query using session data")
|
|
|
|
|
|
|
| 477 |
|
| 478 |
if not has_dataset:
|
| 479 |
logger.warning("[ASYNC] No file uploaded and no session dataset available")
|
src/orchestrator.py
CHANGED
|
@@ -269,8 +269,9 @@ class DataScienceCopilot:
|
|
| 269 |
max_context = provider_max_tokens.get(self.provider, 128000)
|
| 270 |
self.token_manager = get_token_manager(model=self.model, max_tokens=max_context)
|
| 271 |
|
| 272 |
-
# β‘
|
| 273 |
-
self.parallel_executor = get_parallel_executor()
|
|
|
|
| 274 |
|
| 275 |
# π§ Initialize session memory
|
| 276 |
self.use_session_memory = use_session_memory
|
|
@@ -3438,7 +3439,7 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
|
|
| 3438 |
print(f" These will run SEQUENTIALLY to prevent resource exhaustion")
|
| 3439 |
print(f" Heavy tools: {', '.join(heavy_tools)}")
|
| 3440 |
# Fall through to sequential execution
|
| 3441 |
-
elif len(tool_executions) > 1 and len(heavy_tools) <= 1:
|
| 3442 |
try:
|
| 3443 |
results = asyncio.run(self.parallel_executor.execute_all(
|
| 3444 |
tool_executions=tool_executions,
|
|
|
|
| 269 |
max_context = provider_max_tokens.get(self.provider, 128000)
|
| 270 |
self.token_manager = get_token_manager(model=self.model, max_tokens=max_context)
|
| 271 |
|
| 272 |
+
# β‘ Parallel executor DISABLED - running tools sequentially for stability
|
| 273 |
+
# self.parallel_executor = get_parallel_executor()
|
| 274 |
+
self.parallel_executor = None # Disabled for scale optimization
|
| 275 |
|
| 276 |
# π§ Initialize session memory
|
| 277 |
self.use_session_memory = use_session_memory
|
|
|
|
| 3439 |
print(f" These will run SEQUENTIALLY to prevent resource exhaustion")
|
| 3440 |
print(f" Heavy tools: {', '.join(heavy_tools)}")
|
| 3441 |
# Fall through to sequential execution
|
| 3442 |
+
elif len(tool_executions) > 1 and len(heavy_tools) <= 1 and self.parallel_executor is not None:
|
| 3443 |
try:
|
| 3444 |
results = asyncio.run(self.parallel_executor.execute_all(
|
| 3445 |
tool_executions=tool_executions,
|