Spaces:
Running
Running
Pulastya B commited on
Commit Β·
08646ab
1
Parent(s): 293e0b4
Fixed Scalability issues
Browse files- src/api/app.py +172 -40
- src/orchestrator.py +28 -12
- src/tools/model_training.py +3 -4
src/api/app.py
CHANGED
|
@@ -151,8 +151,77 @@ class ProgressEventManager:
|
|
| 151 |
if session_id in self.session_status:
|
| 152 |
del self.session_status[session_id]
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# Mount static files for React frontend
|
| 158 |
frontend_path = Path(__file__).parent.parent.parent / "FRRONTEEEND" / "dist"
|
|
@@ -166,18 +235,19 @@ async def startup_event():
|
|
| 166 |
"""Initialize DataScienceCopilot on service startup."""
|
| 167 |
global agent
|
| 168 |
try:
|
| 169 |
-
logger.info("Initializing
|
| 170 |
provider = os.getenv("LLM_PROVIDER", "mistral")
|
| 171 |
-
# Disable compact prompts to enable multi-agent architecture
|
| 172 |
-
# Multi-agent system has focused prompts per specialist (~3K tokens each)
|
| 173 |
use_compact = False # Always use multi-agent routing
|
| 174 |
|
|
|
|
|
|
|
| 175 |
agent = DataScienceCopilot(
|
| 176 |
reasoning_effort="medium",
|
| 177 |
provider=provider,
|
| 178 |
use_compact_prompts=use_compact
|
| 179 |
)
|
| 180 |
-
logger.info(f"β
|
|
|
|
| 181 |
logger.info("π€ Multi-agent architecture enabled with 5 specialists")
|
| 182 |
except Exception as e:
|
| 183 |
logger.error(f"β Failed to initialize agent: {e}")
|
|
@@ -311,34 +381,50 @@ class AnalysisRequest(BaseModel):
|
|
| 311 |
def run_analysis_background(file_path: str, task_description: str, target_col: Optional[str],
|
| 312 |
use_cache: bool, max_iterations: int, session_id: str):
|
| 313 |
"""Background task to run analysis and emit events."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
try:
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
use_cache=use_cache,
|
| 322 |
-
max_iterations=max_iterations
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
logger.info(f"[BACKGROUND] Analysis completed for session {session_id}")
|
| 326 |
-
|
| 327 |
-
# Send completion event
|
| 328 |
-
progress_manager.emit(session_id, {
|
| 329 |
-
"type": "analysis_complete",
|
| 330 |
-
"status": result.get("status"),
|
| 331 |
-
"message": "β
Analysis completed successfully!",
|
| 332 |
-
"result": result
|
| 333 |
-
})
|
| 334 |
-
|
| 335 |
-
except Exception as e:
|
| 336 |
-
logger.error(f"[BACKGROUND] Analysis failed for session {session_id}: {e}")
|
| 337 |
-
progress_manager.emit(session_id, {
|
| 338 |
-
"type": "analysis_failed",
|
| 339 |
-
"error": str(e),
|
| 340 |
-
"message": f"β Analysis failed: {str(e)}"
|
| 341 |
-
})
|
| 342 |
|
| 343 |
|
| 344 |
@app.post("/run-async")
|
|
@@ -357,9 +443,10 @@ async def run_analysis_async(
|
|
| 357 |
if agent is None:
|
| 358 |
raise HTTPException(status_code=503, detail="Agent not initialized")
|
| 359 |
|
| 360 |
-
#
|
| 361 |
-
|
| 362 |
-
|
|
|
|
| 363 |
|
| 364 |
# Handle file upload
|
| 365 |
temp_file_path = None
|
|
@@ -372,6 +459,28 @@ async def run_analysis_async(
|
|
| 372 |
shutil.copyfileobj(file.file, buffer)
|
| 373 |
|
| 374 |
logger.info(f"[ASYNC] File saved: {file.filename}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
# Start background analysis
|
| 377 |
background_tasks.add_task(
|
|
@@ -427,20 +536,43 @@ async def run_analysis(
|
|
| 427 |
if agent is None:
|
| 428 |
raise HTTPException(status_code=503, detail="Agent not initialized")
|
| 429 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
# Handle follow-up requests (no file, using session memory)
|
| 431 |
if file is None:
|
| 432 |
logger.info(f"Follow-up request without file, using session memory")
|
| 433 |
logger.info(f"Task: {task_description}")
|
| 434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
# Get the agent's actual session UUID for SSE routing
|
| 436 |
-
actual_session_id =
|
| 437 |
print(f"[SSE] Follow-up using agent session UUID: {actual_session_id}")
|
| 438 |
|
| 439 |
# NO progress_callback - orchestrator emits directly to UUID
|
| 440 |
|
| 441 |
try:
|
| 442 |
# Agent's session memory should resolve file_path from context
|
| 443 |
-
result =
|
| 444 |
file_path="", # Empty - will be resolved by session memory
|
| 445 |
task_description=task_description,
|
| 446 |
target_col=target_col,
|
|
@@ -526,14 +658,14 @@ async def run_analysis(
|
|
| 526 |
logger.info(f"File saved successfully: {file.filename} ({os.path.getsize(temp_file_path)} bytes)")
|
| 527 |
|
| 528 |
# Get the agent's actual session UUID for SSE routing (BEFORE analyze())
|
| 529 |
-
actual_session_id =
|
| 530 |
print(f"[SSE] File upload using agent session UUID: {actual_session_id}")
|
| 531 |
|
| 532 |
# NO progress_callback - orchestrator emits directly to UUID
|
| 533 |
|
| 534 |
# Call existing agent logic
|
| 535 |
logger.info(f"Starting analysis with task: {task_description}")
|
| 536 |
-
result =
|
| 537 |
file_path=str(temp_file_path),
|
| 538 |
task_description=task_description,
|
| 539 |
target_col=target_col,
|
|
|
|
| 151 |
if session_id in self.session_status:
|
| 152 |
del self.session_status[session_id]
|
| 153 |
|
| 154 |
+
# π₯ MULTI-USER SUPPORT: Per-session agent instances
|
| 155 |
+
# Instead of one global agent, create isolated instances per session
|
| 156 |
+
# This prevents users from interfering with each other's workflows
|
| 157 |
+
agent_cache: Dict[str, DataScienceCopilot] = {} # session_id -> agent instance
|
| 158 |
+
agent_cache_lock = asyncio.Lock()
|
| 159 |
+
MAX_CACHED_AGENTS = 10 # Limit memory usage
|
| 160 |
+
logger.info("π₯ Multi-user agent cache initialized")
|
| 161 |
+
|
| 162 |
+
# Legacy global agent for backward compatibility (will be deprecated)
|
| 163 |
+
agent = None
|
| 164 |
+
|
| 165 |
+
# π₯ MULTI-USER SUPPORT: Per-session agent instances
|
| 166 |
+
# Instead of one global agent, create isolated instances per session
|
| 167 |
+
# This prevents users from interfering with each other's workflows
|
| 168 |
+
agent_cache: Dict[str, DataScienceCopilot] = {} # session_id -> agent instance
|
| 169 |
+
agent_cache_lock = asyncio.Lock()
|
| 170 |
+
MAX_CACHED_AGENTS = 10 # Limit memory usage
|
| 171 |
+
logger.info("π₯ Multi-user agent cache initialized")
|
| 172 |
+
|
| 173 |
+
# Legacy global agent for backward compatibility (will be deprecated)
|
| 174 |
+
agent = None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
|
| 178 |
+
"""
|
| 179 |
+
Get or create an isolated agent instance for a session.
|
| 180 |
+
|
| 181 |
+
This ensures each user gets their own agent with isolated state,
|
| 182 |
+
preventing session collisions and race conditions.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
session_id: Unique session identifier
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
DataScienceCopilot instance for this session
|
| 189 |
+
"""
|
| 190 |
+
async with agent_cache_lock:
|
| 191 |
+
# Return existing agent if cached
|
| 192 |
+
if session_id in agent_cache:
|
| 193 |
+
logger.info(f"[β»οΈ] Reusing cached agent for session {session_id[:8]}...")
|
| 194 |
+
return agent_cache[session_id]
|
| 195 |
+
|
| 196 |
+
# Create new agent instance
|
| 197 |
+
logger.info(f"[π] Creating new agent for session {session_id[:8]}...")
|
| 198 |
+
provider = os.getenv("LLM_PROVIDER", "mistral")
|
| 199 |
+
|
| 200 |
+
new_agent = DataScienceCopilot(
|
| 201 |
+
reasoning_effort="medium",
|
| 202 |
+
provider=provider,
|
| 203 |
+
use_compact_prompts=False, # Multi-agent architecture
|
| 204 |
+
session_id=session_id # Pass session_id for isolation
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Cache management: Remove oldest if cache is full
|
| 208 |
+
if len(agent_cache) >= MAX_CACHED_AGENTS:
|
| 209 |
+
oldest_session = next(iter(agent_cache))
|
| 210 |
+
logger.info(f"[ποΈ] Cache full, removing session {oldest_session[:8]}...")
|
| 211 |
+
del agent_cache[oldest_session]
|
| 212 |
+
|
| 213 |
+
agent_cache[session_id] = new_agent
|
| 214 |
+
logger.info(f"β
Agent created for session {session_id[:8]} (cache: {len(agent_cache)}/{MAX_CACHED_AGENTS})")
|
| 215 |
+
|
| 216 |
+
return new_agent
|
| 217 |
+
|
| 218 |
+
# π REQUEST QUEUING: Global lock to prevent concurrent workflows
|
| 219 |
+
# This ensures only one analysis runs at a time, preventing:
|
| 220 |
+
# - Race conditions on file writes
|
| 221 |
+
# - Memory exhaustion from parallel model training
|
| 222 |
+
# - Session state corruption
|
| 223 |
+
workflow_lock = asyncio.Lock()
|
| 224 |
+
logger.info("π Workflow lock initialized for request queuing")
|
| 225 |
|
| 226 |
# Mount static files for React frontend
|
| 227 |
frontend_path = Path(__file__).parent.parent.parent / "FRRONTEEEND" / "dist"
|
|
|
|
| 235 |
"""Initialize DataScienceCopilot on service startup."""
|
| 236 |
global agent
|
| 237 |
try:
|
| 238 |
+
logger.info("Initializing legacy global agent for health checks...")
|
| 239 |
provider = os.getenv("LLM_PROVIDER", "mistral")
|
|
|
|
|
|
|
| 240 |
use_compact = False # Always use multi-agent routing
|
| 241 |
|
| 242 |
+
# Create one agent for health checks only
|
| 243 |
+
# Real requests will use get_agent_for_session() for isolation
|
| 244 |
agent = DataScienceCopilot(
|
| 245 |
reasoning_effort="medium",
|
| 246 |
provider=provider,
|
| 247 |
use_compact_prompts=use_compact
|
| 248 |
)
|
| 249 |
+
logger.info(f"β
Health check agent initialized with provider: {agent.provider}")
|
| 250 |
+
logger.info("π₯ Per-session agents enabled - each user gets isolated instance")
|
| 251 |
logger.info("π€ Multi-agent architecture enabled with 5 specialists")
|
| 252 |
except Exception as e:
|
| 253 |
logger.error(f"β Failed to initialize agent: {e}")
|
|
|
|
| 381 |
def run_analysis_background(file_path: str, task_description: str, target_col: Optional[str],
|
| 382 |
use_cache: bool, max_iterations: int, session_id: str):
|
| 383 |
"""Background task to run analysis and emit events."""
|
| 384 |
+
async def _run_with_lock():
|
| 385 |
+
"""Wrap analysis in lock to ensure sequential execution."""
|
| 386 |
+
async with workflow_lock:
|
| 387 |
+
try:
|
| 388 |
+
logger.info(f"[BACKGROUND] Starting analysis for session {session_id[:8]}...")
|
| 389 |
+
|
| 390 |
+
# π₯ Get isolated agent for this session
|
| 391 |
+
session_agent = await get_agent_for_session(session_id)
|
| 392 |
+
|
| 393 |
+
result = session_agent.analyze(
|
| 394 |
+
file_path=file_path,
|
| 395 |
+
task_description=task_description,
|
| 396 |
+
target_col=target_col,
|
| 397 |
+
use_cache=use_cache,
|
| 398 |
+
max_iterations=max_iterations
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
logger.info(f"[BACKGROUND] Analysis completed for session {session_id[:8]}...")
|
| 402 |
+
|
| 403 |
+
# Send completion event
|
| 404 |
+
progress_manager.emit(session_id, {
|
| 405 |
+
"type": "analysis_complete",
|
| 406 |
+
"status": result.get("status"),
|
| 407 |
+
"message": "β
Analysis completed successfully!",
|
| 408 |
+
"result": result
|
| 409 |
+
})
|
| 410 |
+
|
| 411 |
+
except Exception as e:
|
| 412 |
+
logger.error(f"[BACKGROUND] Analysis failed for session {session_id[:8]}...: {e}")
|
| 413 |
+
progress_manager.emit(session_id, {
|
| 414 |
+
"type": "analysis_failed",
|
| 415 |
+
"error": str(e),
|
| 416 |
+
"message": f"β Analysis failed: {str(e)}"
|
| 417 |
+
})
|
| 418 |
+
|
| 419 |
+
# Run async function in event loop
|
| 420 |
+
import asyncio
|
| 421 |
try:
|
| 422 |
+
loop = asyncio.get_event_loop()
|
| 423 |
+
except RuntimeError:
|
| 424 |
+
loop = asyncio.new_event_loop()
|
| 425 |
+
asyncio.set_event_loop(loop)
|
| 426 |
+
|
| 427 |
+
loop.run_until_complete(_run_with_lock())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
|
| 430 |
@app.post("/run-async")
|
|
|
|
| 443 |
if agent is None:
|
| 444 |
raise HTTPException(status_code=503, detail="Agent not initialized")
|
| 445 |
|
| 446 |
+
# π Generate unique session ID for this request
|
| 447 |
+
import uuid
|
| 448 |
+
session_id = str(uuid.uuid4())
|
| 449 |
+
logger.info(f"[ASYNC] Created session: {session_id[:8]}...")
|
| 450 |
|
| 451 |
# Handle file upload
|
| 452 |
temp_file_path = None
|
|
|
|
| 459 |
shutil.copyfileobj(file.file, buffer)
|
| 460 |
|
| 461 |
logger.info(f"[ASYNC] File saved: {file.filename}")
|
| 462 |
+
else:
|
| 463 |
+
# π‘οΈ VALIDATION: For follow-up queries, check if any cached agent has dataset
|
| 464 |
+
# Note: In true multi-user setup, you'd need session_id from frontend to match exact session
|
| 465 |
+
has_dataset = False
|
| 466 |
+
async with agent_cache_lock:
|
| 467 |
+
for cached_agent in agent_cache.values():
|
| 468 |
+
if hasattr(cached_agent, 'session') and cached_agent.session and cached_agent.session.last_dataset:
|
| 469 |
+
has_dataset = True
|
| 470 |
+
logger.info(f"[ASYNC] Follow-up query using cached session data")
|
| 471 |
+
break
|
| 472 |
+
|
| 473 |
+
if not has_dataset:
|
| 474 |
+
logger.warning("[ASYNC] No file uploaded and no session dataset available")
|
| 475 |
+
return JSONResponse(
|
| 476 |
+
content={
|
| 477 |
+
"success": False,
|
| 478 |
+
"error": "No dataset available",
|
| 479 |
+
"message": "Please upload a CSV, Excel, or Parquet file first.",
|
| 480 |
+
"session_id": session_id
|
| 481 |
+
},
|
| 482 |
+
status_code=400
|
| 483 |
+
)
|
| 484 |
|
| 485 |
# Start background analysis
|
| 486 |
background_tasks.add_task(
|
|
|
|
| 536 |
if agent is None:
|
| 537 |
raise HTTPException(status_code=503, detail="Agent not initialized")
|
| 538 |
|
| 539 |
+
# π Generate or use provided session ID
|
| 540 |
+
if not session_id:
|
| 541 |
+
import uuid
|
| 542 |
+
session_id = str(uuid.uuid4())
|
| 543 |
+
logger.info(f"[SYNC] Created new session: {session_id[:8]}...")
|
| 544 |
+
else:
|
| 545 |
+
logger.info(f"[SYNC] Using provided session: {session_id[:8]}...")
|
| 546 |
+
|
| 547 |
+
# π₯ Get isolated agent for this session
|
| 548 |
+
session_agent = await get_agent_for_session(session_id)
|
| 549 |
+
|
| 550 |
# Handle follow-up requests (no file, using session memory)
|
| 551 |
if file is None:
|
| 552 |
logger.info(f"Follow-up request without file, using session memory")
|
| 553 |
logger.info(f"Task: {task_description}")
|
| 554 |
|
| 555 |
+
# π‘οΈ VALIDATION: Check if session has a dataset
|
| 556 |
+
if not (hasattr(session_agent, 'session') and session_agent.session and session_agent.session.last_dataset):
|
| 557 |
+
logger.warning("No file uploaded and no session dataset available")
|
| 558 |
+
return JSONResponse(
|
| 559 |
+
content={
|
| 560 |
+
"success": False,
|
| 561 |
+
"error": "No dataset available",
|
| 562 |
+
"message": "Please upload a CSV, Excel, or Parquet file first before asking questions."
|
| 563 |
+
},
|
| 564 |
+
status_code=400
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
# Get the agent's actual session UUID for SSE routing
|
| 568 |
+
actual_session_id = session_agent.session.session_id if hasattr(session_agent, 'session') and session_agent.session else session_id
|
| 569 |
print(f"[SSE] Follow-up using agent session UUID: {actual_session_id}")
|
| 570 |
|
| 571 |
# NO progress_callback - orchestrator emits directly to UUID
|
| 572 |
|
| 573 |
try:
|
| 574 |
# Agent's session memory should resolve file_path from context
|
| 575 |
+
result = session_agent.analyze(
|
| 576 |
file_path="", # Empty - will be resolved by session memory
|
| 577 |
task_description=task_description,
|
| 578 |
target_col=target_col,
|
|
|
|
| 658 |
logger.info(f"File saved successfully: {file.filename} ({os.path.getsize(temp_file_path)} bytes)")
|
| 659 |
|
| 660 |
# Get the agent's actual session UUID for SSE routing (BEFORE analyze())
|
| 661 |
+
actual_session_id = session_agent.session.session_id if hasattr(session_agent, 'session') and session_agent.session else session_id
|
| 662 |
print(f"[SSE] File upload using agent session UUID: {actual_session_id}")
|
| 663 |
|
| 664 |
# NO progress_callback - orchestrator emits directly to UUID
|
| 665 |
|
| 666 |
# Call existing agent logic
|
| 667 |
logger.info(f"Starting analysis with task: {task_description}")
|
| 668 |
+
result = session_agent.analyze(
|
| 669 |
file_path=str(temp_file_path),
|
| 670 |
task_description=task_description,
|
| 671 |
target_col=target_col,
|
src/orchestrator.py
CHANGED
|
@@ -402,7 +402,7 @@ class DataScienceCopilot:
|
|
| 402 |
"split_data_strategically": split_data_strategically,
|
| 403 |
# Advanced Training (3)
|
| 404 |
"hyperparameter_tuning": hyperparameter_tuning,
|
| 405 |
-
"train_ensemble_models": train_ensemble_models,
|
| 406 |
"perform_cross_validation": perform_cross_validation,
|
| 407 |
# Business Intelligence (4)
|
| 408 |
"perform_cohort_analysis": perform_cohort_analysis,
|
|
@@ -554,7 +554,8 @@ When you need to use a tool, respond with a JSON block like this:
|
|
| 554 |
- Keywords: "train model", "predict", "classify", "build model", "forecast"
|
| 555 |
- User wants: cleaning + feature engineering + model training
|
| 556 |
- **ACTION**: Run full ML workflow (steps 1-15 below)
|
| 557 |
-
- **
|
|
|
|
| 558 |
|
| 559 |
**E. UNCLEAR/AMBIGUOUS REQUESTS** - Intent is not obvious:
|
| 560 |
- User says: "analyze", "look at", "check", "review" (without specifics)
|
|
@@ -657,16 +658,16 @@ structure, variable relationships, and expected insights - not hardcoded domain
|
|
| 657 |
8. encode_categorical(latest, method="auto", output="./outputs/data/encoded.csv")
|
| 658 |
9. generate_eda_plots(encoded, target_col, output_dir="./outputs/plots/eda") - Generate EDA visualizations
|
| 659 |
10. **ONLY IF USER EXPLICITLY REQUESTED ML**: train_baseline_models(encoded, target_col, task_type="auto")
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
- **
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
- **How**: hyperparameter_tuning(file_path=encoded, target_col=target_col, model_type="xgboost", n_trials=50)
|
| 667 |
- **Large datasets (>100K rows)**: n_trials automatically reduced to 20 to prevent timeout
|
| 668 |
- **Only tune the WINNING model** (don't waste time on others)
|
| 669 |
-
- **Map model names**: XGBoostβ"xgboost",
|
| 670 |
- **Note**: Time features should already be extracted in step 7 (create_time_features)
|
| 671 |
12. **CROSS-VALIDATION (OPTIONAL - Production Models)**:
|
| 672 |
- IF user says "validate", "production", "robust", "deploy" β ALWAYS cross-validate
|
|
@@ -836,7 +837,7 @@ Use specialized tools FIRST. Only use execute_python_code for:
|
|
| 836 |
- train_baseline_models: Trains multiple models automatically
|
| 837 |
- **β execute_python_code**: Write and run custom Python code for ANY task not covered by tools (TRUE AI AGENT capability)
|
| 838 |
- **execute_code_from_file**: Run existing Python scripts
|
| 839 |
-
- Advanced: hyperparameter_tuning,
|
| 840 |
- NEW Advanced Insights: analyze_root_cause, detect_trends_and_seasonality, detect_anomalies_advanced, perform_hypothesis_testing, analyze_distribution, perform_segment_analysis
|
| 841 |
- NEW Automation: auto_ml_pipeline (zero-config full pipeline), auto_feature_selection
|
| 842 |
- NEW Visualization: generate_all_plots, generate_data_quality_plots, generate_eda_plots, generate_model_performance_plots, generate_feature_importance_plot
|
|
@@ -1020,7 +1021,7 @@ BEFORE calling any training tools, you MUST:
|
|
| 1020 |
|
| 1021 |
**Your Tools (6 modeling-focused):**
|
| 1022 |
- train_baseline_models, hyperparameter_tuning
|
| 1023 |
-
-
|
| 1024 |
- generate_model_report, detect_model_issues
|
| 1025 |
|
| 1026 |
**Your Approach:**
|
|
@@ -2746,6 +2747,19 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
|
|
| 2746 |
|
| 2747 |
# π LOCAL SCHEMA EXTRACTION (NO LLM) - Extract metadata before any LLM calls
|
| 2748 |
# Now that file_path is resolved from session if needed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2749 |
print("π Extracting dataset schema locally (no LLM)...")
|
| 2750 |
schema_info = extract_schema_local(file_path, sample_rows=3)
|
| 2751 |
|
|
@@ -3366,7 +3380,9 @@ You receive quality reports from EDA agent and deliver clean data to modeling ag
|
|
| 3366 |
messages.append(response_message)
|
| 3367 |
|
| 3368 |
# π PARALLEL EXECUTION: Detect multiple independent tool calls
|
| 3369 |
-
|
|
|
|
|
|
|
| 3370 |
print(f"π Detected {len(tool_calls)} tool calls - attempting parallel execution")
|
| 3371 |
|
| 3372 |
# Extract tool executions with proper weight classification
|
|
|
|
| 402 |
"split_data_strategically": split_data_strategically,
|
| 403 |
# Advanced Training (3)
|
| 404 |
"hyperparameter_tuning": hyperparameter_tuning,
|
| 405 |
+
# "train_ensemble_models": train_ensemble_models, # DISABLED - Too resource intensive for scale
|
| 406 |
"perform_cross_validation": perform_cross_validation,
|
| 407 |
# Business Intelligence (4)
|
| 408 |
"perform_cohort_analysis": perform_cohort_analysis,
|
|
|
|
| 554 |
- Keywords: "train model", "predict", "classify", "build model", "forecast"
|
| 555 |
- User wants: cleaning + feature engineering + model training
|
| 556 |
- **ACTION**: Run full ML workflow (steps 1-15 below)
|
| 557 |
+
- **π― IMPORTANT**: ALWAYS generate ydata_profiling_report at the END of workflow for comprehensive final analysis
|
| 558 |
+
- **Example**: "Train a model to predict earthquake magnitude" β Full pipeline + ydata_profiling_report at end
|
| 559 |
|
| 560 |
**E. UNCLEAR/AMBIGUOUS REQUESTS** - Intent is not obvious:
|
| 561 |
- User says: "analyze", "look at", "check", "review" (without specifics)
|
|
|
|
| 658 |
8. encode_categorical(latest, method="auto", output="./outputs/data/encoded.csv")
|
| 659 |
9. generate_eda_plots(encoded, target_col, output_dir="./outputs/plots/eda") - Generate EDA visualizations
|
| 660 |
10. **ONLY IF USER EXPLICITLY REQUESTED ML**: train_baseline_models(encoded, target_col, task_type="auto")
|
| 661 |
+
10b. **ALWAYS AFTER MODEL TRAINING**: generate_ydata_profiling_report(encoded, output_path="./outputs/reports/ydata_profile.html") - Comprehensive data analysis report
|
| 662 |
+
11. **HYPERPARAMETER TUNING (β οΈ ONLY WHEN EXPLICITLY REQUESTED)**:
|
| 663 |
+
- β οΈ **CRITICAL WARNING**: This is EXTREMELY expensive (5-10 minutes) and resource-intensive!
|
| 664 |
+
- β οΈ **DO NOT USE UNLESS USER EXPLICITLY ASKS FOR IT**
|
| 665 |
+
- **ONLY use when user says**: "tune", "optimize", "hyperparameter", "improve model", "best parameters"
|
| 666 |
+
- **NEVER auto-trigger** based on scores - user must explicitly request it
|
| 667 |
- **How**: hyperparameter_tuning(file_path=encoded, target_col=target_col, model_type="xgboost", n_trials=50)
|
| 668 |
- **Large datasets (>100K rows)**: n_trials automatically reduced to 20 to prevent timeout
|
| 669 |
- **Only tune the WINNING model** (don't waste time on others)
|
| 670 |
+
- **Map model names**: XGBoostβ"xgboost", Ridgeβ"ridge", Lassoβuse Ridge
|
| 671 |
- **Note**: Time features should already be extracted in step 7 (create_time_features)
|
| 672 |
12. **CROSS-VALIDATION (OPTIONAL - Production Models)**:
|
| 673 |
- IF user says "validate", "production", "robust", "deploy" β ALWAYS cross-validate
|
|
|
|
| 837 |
- train_baseline_models: Trains multiple models automatically
|
| 838 |
- **β execute_python_code**: Write and run custom Python code for ANY task not covered by tools (TRUE AI AGENT capability)
|
| 839 |
- **execute_code_from_file**: Run existing Python scripts
|
| 840 |
+
- Advanced: hyperparameter_tuning, perform_eda_analysis, handle_imbalanced_data, perform_feature_scaling, detect_anomalies, detect_and_handle_multicollinearity, auto_feature_engineering, forecast_time_series, explain_predictions, generate_business_insights, perform_topic_modeling, extract_image_features, monitor_model_drift
|
| 841 |
- NEW Advanced Insights: analyze_root_cause, detect_trends_and_seasonality, detect_anomalies_advanced, perform_hypothesis_testing, analyze_distribution, perform_segment_analysis
|
| 842 |
- NEW Automation: auto_ml_pipeline (zero-config full pipeline), auto_feature_selection
|
| 843 |
- NEW Visualization: generate_all_plots, generate_data_quality_plots, generate_eda_plots, generate_model_performance_plots, generate_feature_importance_plot
|
|
|
|
| 1021 |
|
| 1022 |
**Your Tools (6 modeling-focused):**
|
| 1023 |
- train_baseline_models, hyperparameter_tuning
|
| 1024 |
+
- perform_cross_validation
|
| 1025 |
- generate_model_report, detect_model_issues
|
| 1026 |
|
| 1027 |
**Your Approach:**
|
|
|
|
| 2747 |
|
| 2748 |
# π LOCAL SCHEMA EXTRACTION (NO LLM) - Extract metadata before any LLM calls
|
| 2749 |
# Now that file_path is resolved from session if needed
|
| 2750 |
+
|
| 2751 |
+
# π‘οΈ VALIDATION: Ensure we have a valid file path
|
| 2752 |
+
if not file_path or file_path == "":
|
| 2753 |
+
error_msg = "No dataset file provided. Please upload a CSV, Excel, or Parquet file."
|
| 2754 |
+
print(f"β {error_msg}")
|
| 2755 |
+
return {
|
| 2756 |
+
"status": "error",
|
| 2757 |
+
"error": error_msg,
|
| 2758 |
+
"summary": "Cannot proceed without a dataset file.",
|
| 2759 |
+
"workflow_history": [],
|
| 2760 |
+
"execution_time": 0.0
|
| 2761 |
+
}
|
| 2762 |
+
|
| 2763 |
print("π Extracting dataset schema locally (no LLM)...")
|
| 2764 |
schema_info = extract_schema_local(file_path, sample_rows=3)
|
| 2765 |
|
|
|
|
| 3380 |
messages.append(response_message)
|
| 3381 |
|
| 3382 |
# π PARALLEL EXECUTION: Detect multiple independent tool calls
|
| 3383 |
+
# β οΈ DISABLED FOR STABILITY - Parallel execution causes race conditions and OOM errors
|
| 3384 |
+
# Re-enable only after implementing proper request isolation per user
|
| 3385 |
+
if len(tool_calls) > 1 and False: # Disabled with "and False"
|
| 3386 |
print(f"π Detected {len(tool_calls)} tool calls - attempting parallel execution")
|
| 3387 |
|
| 3388 |
# Extract tool executions with proper weight classification
|
src/tools/model_training.py
CHANGED
|
@@ -129,15 +129,15 @@ def train_baseline_models(file_path: str, target_col: str,
|
|
| 129 |
|
| 130 |
# Train models based on task type
|
| 131 |
import sys
|
| 132 |
-
print(f"\nπ Training {
|
| 133 |
print(f" π Training set: {len(X_train):,} samples Γ {X_train.shape[1]} features", flush=True)
|
| 134 |
print(f" π Test set: {len(X_test):,} samples", flush=True)
|
|
|
|
| 135 |
sys.stdout.flush()
|
| 136 |
|
| 137 |
if task_type == "classification":
|
| 138 |
models = {
|
| 139 |
"logistic_regression": LogisticRegression(max_iter=1000, random_state=random_state),
|
| 140 |
-
"random_forest": RandomForestClassifier(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 141 |
"xgboost": XGBClassifier(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 142 |
"lightgbm": LGBMClassifier(n_estimators=100, random_state=random_state, n_jobs=-1, verbose=-1),
|
| 143 |
"catboost": CatBoostClassifier(iterations=100, random_state=random_state, verbose=0, allow_writing_files=False)
|
|
@@ -213,7 +213,6 @@ def train_baseline_models(file_path: str, target_col: str,
|
|
| 213 |
models = {
|
| 214 |
"ridge": Ridge(random_state=random_state),
|
| 215 |
"lasso": Lasso(random_state=random_state),
|
| 216 |
-
"random_forest": RandomForestRegressor(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 217 |
"xgboost": XGBRegressor(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 218 |
"lightgbm": LGBMRegressor(n_estimators=100, random_state=random_state, n_jobs=-1, verbose=-1),
|
| 219 |
"catboost": CatBoostRegressor(iterations=100, random_state=random_state, verbose=0, allow_writing_files=False)
|
|
@@ -316,7 +315,7 @@ def train_baseline_models(file_path: str, target_col: str,
|
|
| 316 |
"suggested_model": best_model_name,
|
| 317 |
"reason": f"{best_model_name} is optimal for large datasets - fast training and good performance"
|
| 318 |
}
|
| 319 |
-
elif best_model_name == "
|
| 320 |
# Find next best fast model
|
| 321 |
fast_model_scores = {name: results["models"][name]["test_metrics"].get("r2" if task_type == "regression" else "f1", 0)
|
| 322 |
for name in fast_models if name in results["models"]}
|
|
|
|
| 129 |
|
| 130 |
# Train models based on task type
|
| 131 |
import sys
|
| 132 |
+
print(f"\nπ Training {5 if task_type == 'classification' else 5} baseline models...", flush=True)
|
| 133 |
print(f" π Training set: {len(X_train):,} samples Γ {X_train.shape[1]} features", flush=True)
|
| 134 |
print(f" π Test set: {len(X_test):,} samples", flush=True)
|
| 135 |
+
print(f" β‘ Note: Random Forest excluded to optimize compute resources", flush=True)
|
| 136 |
sys.stdout.flush()
|
| 137 |
|
| 138 |
if task_type == "classification":
|
| 139 |
models = {
|
| 140 |
"logistic_regression": LogisticRegression(max_iter=1000, random_state=random_state),
|
|
|
|
| 141 |
"xgboost": XGBClassifier(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 142 |
"lightgbm": LGBMClassifier(n_estimators=100, random_state=random_state, n_jobs=-1, verbose=-1),
|
| 143 |
"catboost": CatBoostClassifier(iterations=100, random_state=random_state, verbose=0, allow_writing_files=False)
|
|
|
|
| 213 |
models = {
|
| 214 |
"ridge": Ridge(random_state=random_state),
|
| 215 |
"lasso": Lasso(random_state=random_state),
|
|
|
|
| 216 |
"xgboost": XGBRegressor(n_estimators=100, random_state=random_state, n_jobs=-1),
|
| 217 |
"lightgbm": LGBMRegressor(n_estimators=100, random_state=random_state, n_jobs=-1, verbose=-1),
|
| 218 |
"catboost": CatBoostRegressor(iterations=100, random_state=random_state, verbose=0, allow_writing_files=False)
|
|
|
|
| 315 |
"suggested_model": best_model_name,
|
| 316 |
"reason": f"{best_model_name} is optimal for large datasets - fast training and good performance"
|
| 317 |
}
|
| 318 |
+
elif best_model_name == "random_forest_legacy": # Disabled for compute optimization
|
| 319 |
# Find next best fast model
|
| 320 |
fast_model_scores = {name: results["models"][name]["test_metrics"].get("r2" if task_type == "regression" else "f1", 0)
|
| 321 |
for name in fast_models if name in results["models"]}
|