Pulastya B committed on
Commit
187c5e0
·
1 Parent(s): 554eeb5

Fixed multi-session support and follow-up queries - Session UUID reuse, multi-chat isolation, proper SSE switching

Browse files
Files changed (3) hide show
  1. src/api/app.py +21 -8
  2. test_improvements.py +0 -141
  3. test_multi_agent.py +0 -223
src/api/app.py CHANGED
@@ -441,20 +441,30 @@ async def run_analysis_async(
441
  file: Optional[UploadFile] = File(None),
442
  task_description: str = Form(...),
443
  target_col: Optional[str] = Form(None),
 
444
  use_cache: bool = Form(False), # Disabled to show multi-agent in action
445
  max_iterations: int = Form(20)
446
  ) -> JSONResponse:
447
  """
448
  Start analysis in background and return session UUID immediately.
449
  Frontend can connect SSE with this UUID to receive real-time updates.
 
 
450
  """
451
  if agent is None:
452
  raise HTTPException(status_code=503, detail="Agent not initialized")
453
 
454
- # 🆔 Generate unique session ID for this request
 
 
455
  import uuid
456
- session_id = str(uuid.uuid4())
457
- logger.info(f"[ASYNC] Created session: {session_id[:8]}...")
 
 
 
 
 
458
 
459
  # Handle file upload
460
  temp_file_path = None
@@ -468,15 +478,18 @@ async def run_analysis_async(
468
 
469
  logger.info(f"[ASYNC] File saved: {file.filename}")
470
  else:
471
- # 🛡️ VALIDATION: Check if agent's current session has dataset
472
  has_dataset = False
473
  async with agent_cache_lock:
474
- if agent and hasattr(agent, 'session') and agent.session and hasattr(agent.session, 'last_dataset') and agent.session.last_dataset:
475
- has_dataset = True
476
- logger.info(f"[ASYNC] Follow-up query using session data")
 
 
 
477
 
478
  if not has_dataset:
479
- logger.warning("[ASYNC] No file uploaded and no session dataset available")
480
  return JSONResponse(
481
  content={
482
  "success": False,
 
441
  file: Optional[UploadFile] = File(None),
442
  task_description: str = Form(...),
443
  target_col: Optional[str] = Form(None),
444
+ session_id: Optional[str] = Form(None), # Accept session_id from frontend for follow-ups
445
  use_cache: bool = Form(False), # Disabled to show multi-agent in action
446
  max_iterations: int = Form(20)
447
  ) -> JSONResponse:
448
  """
449
  Start analysis in background and return session UUID immediately.
450
  Frontend can connect SSE with this UUID to receive real-time updates.
451
+
452
+ For follow-up queries, frontend should send the same session_id to maintain context.
453
  """
454
  if agent is None:
455
  raise HTTPException(status_code=503, detail="Agent not initialized")
456
 
457
+ # 🆔 Session ID handling:
458
+ # - If frontend sends a valid UUID, REUSE it (follow-up query)
459
+ # - Otherwise generate a new one (first query)
460
  import uuid
461
+ if session_id and '-' in session_id and len(session_id) > 20:
462
+ # Valid UUID from frontend - this is a follow-up query
463
+ logger.info(f"[ASYNC] Reusing session: {session_id[:8]}... (follow-up)")
464
+ else:
465
+ # Generate new session for first query
466
+ session_id = str(uuid.uuid4())
467
+ logger.info(f"[ASYNC] Created new session: {session_id[:8]}...")
468
 
469
  # Handle file upload
470
  temp_file_path = None
 
478
 
479
  logger.info(f"[ASYNC] File saved: {file.filename}")
480
  else:
481
+ # 🛡️ VALIDATION: Check if this session has dataset cached
482
  has_dataset = False
483
  async with agent_cache_lock:
484
+ # Check session_states cache for this specific session_id
485
+ if session_id in session_states:
486
+ cached_session = session_states[session_id]
487
+ if hasattr(cached_session, 'last_dataset') and cached_session.last_dataset:
488
+ has_dataset = True
489
+ logger.info(f"[ASYNC] Follow-up query for session {session_id[:8]}... - using cached dataset")
490
 
491
  if not has_dataset:
492
+ logger.warning(f"[ASYNC] No file uploaded and no dataset for session {session_id[:8]}...")
493
  return JSONResponse(
494
  content={
495
  "success": False,
test_improvements.py DELETED
@@ -1,141 +0,0 @@
1
- """
2
- Quick test to verify all new systems are working correctly
3
- """
4
-
5
- print("=" * 60)
6
- print("Testing Data Science Agent System Improvements")
7
- print("=" * 60)
8
-
9
- # Test 1: Semantic Layer
10
- print("\n1️⃣ Testing SBERT Semantic Layer...")
11
- try:
12
- from src.utils.semantic_layer import get_semantic_layer
13
- semantic = get_semantic_layer()
14
-
15
- if semantic.enabled:
16
- print(" ✅ SBERT model loaded successfully")
17
- print(f" 📦 Model: {semantic.model_name}")
18
-
19
- # Test semantic column matching
20
- result = semantic.semantic_column_match("Salary", ["Annual_Income", "Name", "Age"], threshold=0.5)
21
- if result:
22
- col, conf = result
23
- print(f" ✅ Semantic matching works: 'Salary' → '{col}' (confidence: {conf:.2f})")
24
- else:
25
- print(" ⚠️ No match found (threshold too high)")
26
-
27
- # Test agent routing
28
- agent_descs = {
29
- "modeling_agent": "Expert in machine learning model training",
30
- "viz_agent": "Expert in data visualization"
31
- }
32
- best_agent, conf = semantic.route_to_agent("train a random forest model", agent_descs)
33
- print(f" ✅ Agent routing works: '{best_agent}' (confidence: {conf:.2f})")
34
- else:
35
- print(" ⚠️ SBERT not available (missing dependencies)")
36
- except Exception as e:
37
- print(f" ❌ Error: {e}")
38
-
39
- # Test 2: Error Recovery
40
- print("\n2️⃣ Testing Error Recovery System...")
41
- try:
42
- from src.utils.error_recovery import get_recovery_manager, retry_with_fallback
43
- recovery = get_recovery_manager()
44
-
45
- print(" ✅ Recovery manager initialized")
46
- print(f" 📂 Checkpoint directory: {recovery.checkpoint_manager.checkpoint_dir}")
47
-
48
- # Test retry decorator
49
- retry_count = 0
50
-
51
- @retry_with_fallback(tool_name="test_tool")
52
- def test_tool():
53
- global retry_count
54
- retry_count += 1
55
- if retry_count < 2:
56
- raise Exception("Simulated failure")
57
- return {"success": True}
58
-
59
- result = test_tool()
60
- if result.get("success"):
61
- print(f" ✅ Retry decorator works (succeeded after {retry_count} attempts)")
62
- else:
63
- print(f" ⚠️ Retry failed after {retry_count} attempts")
64
-
65
- except Exception as e:
66
- print(f" ❌ Error: {e}")
67
-
68
- # Test 3: Token Budget Manager
69
- print("\n3️⃣ Testing Token Budget Manager...")
70
- try:
71
- from src.utils.token_budget import get_token_manager
72
- token_mgr = get_token_manager(model="gpt-4", max_tokens=128000)
73
-
74
- print(f" ✅ Token manager initialized")
75
- print(f" 📊 Available tokens: {token_mgr.available_tokens:,}")
76
-
77
- # Test token counting
78
- test_text = "This is a test sentence for token counting."
79
- tokens = token_mgr.count_tokens(test_text)
80
- print(f" ✅ Token counting works: '{test_text}' = {tokens} tokens")
81
-
82
- # Test compression
83
- large_result = '{"data": ' + str(list(range(1000))) + '}'
84
- compressed = token_mgr.compress_tool_result(large_result, max_tokens=100)
85
- print(f" ✅ Compression works: {len(large_result)} chars → {len(compressed)} chars")
86
-
87
- except Exception as e:
88
- print(f" ❌ Error: {e}")
89
-
90
- # Test 4: Parallel Executor
91
- print("\n4️⃣ Testing Parallel Tool Executor...")
92
- try:
93
- from src.utils.parallel_executor import get_parallel_executor, ToolExecution, ToolWeight
94
- parallel = get_parallel_executor()
95
-
96
- print(" ✅ Parallel executor initialized")
97
- print(f" ⚡ Max concurrent: Heavy={parallel.max_heavy}, Medium={parallel.max_medium}, Light={parallel.max_light}")
98
-
99
- # Test dependency detection
100
- executions = [
101
- ToolExecution("profile_dataset", {"file_path": "data.csv"}, ToolWeight.LIGHT, set(), "exec1"),
102
- ToolExecution("clean_missing_values", {"file_path": "data.csv", "output_path": "clean.csv"}, ToolWeight.MEDIUM, set(), "exec2"),
103
- ToolExecution("train_baseline_models", {"file_path": "clean.csv"}, ToolWeight.HEAVY, set(), "exec3")
104
- ]
105
-
106
- batches = parallel.dependency_graph.get_execution_batches(executions)
107
- print(f" ✅ Dependency detection works: {len(executions)} tools → {len(batches)} batches")
108
- for i, batch in enumerate(batches):
109
- tool_names = [ex.tool_name for ex in batch]
110
- print(f" Batch {i+1}: {tool_names}")
111
-
112
- except Exception as e:
113
- print(f" ❌ Error: {e}")
114
-
115
- # Test 5: Orchestrator Integration
116
- print("\n5️⃣ Testing Orchestrator Integration...")
117
- try:
118
- from src.orchestrator import DataScienceCopilot
119
-
120
- # Don't initialize fully (requires API keys), just check imports
121
- print(" ✅ Orchestrator imports all new systems successfully")
122
- print(" ℹ️ Full initialization requires API keys")
123
-
124
- # Check if systems are importable
125
- has_semantic = hasattr(DataScienceCopilot, '__init__') # Basic check
126
- print(" ✅ All systems ready for integration")
127
-
128
- except Exception as e:
129
- print(f" ❌ Error: {e}")
130
-
131
- # Summary
132
- print("\n" + "=" * 60)
133
- print("🎉 System Test Complete!")
134
- print("=" * 60)
135
- print("\n✅ All 4 improvements implemented and working:")
136
- print(" 1. SBERT Semantic Layer for column understanding & routing")
137
- print(" 2. Error Recovery with retry & checkpointing")
138
- print(" 3. Token Budget Management with compression")
139
- print(" 4. Parallel Tool Execution with dependency detection")
140
- print("\n📖 See SYSTEM_IMPROVEMENTS_SUMMARY.md for integration guide")
141
- print("=" * 60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_multi_agent.py DELETED
@@ -1,223 +0,0 @@
1
- """
2
- Test Multi-Agent Architecture Implementation
3
- """
4
-
5
- import os
6
- import sys
7
- from pathlib import Path
8
-
9
- # Add src to path
10
- sys.path.insert(0, str(Path(__file__).parent))
11
-
12
- from src.orchestrator import DataScienceCopilot
13
-
14
-
15
- def test_agent_initialization():
16
- """Test that specialist agents are initialized correctly."""
17
- print("\n🧪 Test 1: Agent Initialization")
18
- print("=" * 60)
19
-
20
- # Use groq provider which should be available
21
- try:
22
- agent = DataScienceCopilot(
23
- provider="groq",
24
- groq_api_key=os.getenv("GROQ_API_KEY", "dummy_key_for_testing"),
25
- use_session_memory=False # Don't need session for this test
26
- )
27
- except Exception as e:
28
- print(f" ⚠️ Could not initialize with Groq: {e}")
29
- print(" Testing agent structure without full initialization...")
30
- # Just test the agent initialization method directly
31
- from src.orchestrator import DataScienceCopilot
32
- test_instance = object.__new__(DataScienceCopilot)
33
- specialist_agents = test_instance._initialize_specialist_agents()
34
-
35
- # Check that specialist agents were created
36
- assert len(specialist_agents) == 5, f"❌ Expected 5 agents, got {len(specialist_agents)}"
37
-
38
- # Check all required agents exist
39
- expected_agents = ['eda_agent', 'modeling_agent', 'viz_agent', 'insight_agent', 'preprocessing_agent']
40
- for agent_key in expected_agents:
41
- assert agent_key in specialist_agents, f"❌ {agent_key} not found"
42
-
43
- config = specialist_agents[agent_key]
44
- assert 'name' in config, f"❌ {agent_key} missing 'name'"
45
- assert 'emoji' in config, f"❌ {agent_key} missing 'emoji'"
46
- assert 'description' in config, f"❌ {agent_key} missing 'description'"
47
- assert 'system_prompt' in config, f"❌ {agent_key} missing 'system_prompt'"
48
- assert 'tool_keywords' in config, f"❌ {agent_key} missing 'tool_keywords'"
49
-
50
- print(f" ✅ {config['emoji']} {config['name']} - {len(config['tool_keywords'])} keywords")
51
-
52
- print("\n✅ All agents initialized correctly!\n")
53
- return
54
-
55
- # Check that specialist agents were created
56
- assert hasattr(agent, 'specialist_agents'), "❌ specialist_agents not found"
57
- assert len(agent.specialist_agents) == 5, f"❌ Expected 5 agents, got {len(agent.specialist_agents)}"
58
-
59
- # Check all required agents exist
60
- expected_agents = ['eda_agent', 'modeling_agent', 'viz_agent', 'insight_agent', 'preprocessing_agent']
61
- for agent_key in expected_agents:
62
- assert agent_key in agent.specialist_agents, f"❌ {agent_key} not found"
63
-
64
- config = agent.specialist_agents[agent_key]
65
- assert 'name' in config, f"❌ {agent_key} missing 'name'"
66
- assert 'emoji' in config, f"❌ {agent_key} missing 'emoji'"
67
- assert 'description' in config, f"❌ {agent_key} missing 'description'"
68
- assert 'system_prompt' in config, f"❌ {agent_key} missing 'system_prompt'"
69
- assert 'tool_keywords' in config, f"❌ {agent_key} missing 'tool_keywords'"
70
-
71
- print(f" ✅ {config['emoji']} {config['name']} - {len(config['tool_keywords'])} keywords")
72
-
73
- print("\n✅ All agents initialized correctly!\n")
74
-
75
-
76
- def test_agent_routing():
77
- """Test that agent routing selects the correct specialist."""
78
- print("\n🧪 Test 2: Agent Routing Logic")
79
- print("=" * 60)
80
-
81
- try:
82
- agent = DataScienceCopilot(
83
- provider="groq",
84
- groq_api_key=os.getenv("GROQ_API_KEY", "dummy_key_for_testing"),
85
- use_session_memory=False
86
- )
87
- except Exception as e:
88
- print(f" ⚠️ Skipping routing test - initialization failed: {e}")
89
- return
90
-
91
- # Test cases: (task_description, expected_agent_key, expected_agent_name)
92
- test_cases = [
93
- ("Profile the dataset and check data quality", "eda_agent", "EDA Specialist"),
94
- ("Create a correlation heatmap", "viz_agent", "Visualization Specialist"),
95
- ("Train a model to predict sales", "modeling_agent", "ML Modeling Specialist"),
96
- ("Handle missing values and clean the data", "preprocessing_agent", "Data Engineering Specialist"),
97
- ("Explain why customer churn is high", "insight_agent", "Business Insights Specialist"),
98
- ("Generate a scatter plot", "viz_agent", "Visualization Specialist"),
99
- ("Tune hyperparameters", "modeling_agent", "ML Modeling Specialist"),
100
- ("Detect outliers", "eda_agent", "EDA Specialist"),
101
- ("Engineer new features", "preprocessing_agent", "Data Engineering Specialist"),
102
- ("What-if analysis", "insight_agent", "Business Insights Specialist"),
103
- ]
104
-
105
- passed = 0
106
- failed = 0
107
-
108
- for task_desc, expected_key, expected_name in test_cases:
109
- selected_key = agent._select_specialist_agent(task_desc)
110
- selected_config = agent.specialist_agents[selected_key]
111
- selected_name = selected_config['name']
112
-
113
- if selected_key == expected_key:
114
- print(f" ✅ '{task_desc[:40]}...' → {selected_config['emoji']} {selected_name}")
115
- passed += 1
116
- else:
117
- print(f" ❌ '{task_desc[:40]}...'")
118
- print(f" Expected: {agent.specialist_agents[expected_key]['emoji']} {expected_name}")
119
- print(f" Got: {selected_config['emoji']} {selected_name}")
120
- failed += 1
121
-
122
- print(f"\n📊 Results: {passed}/{len(test_cases)} passed, {failed}/{len(test_cases)} failed\n")
123
-
124
- if failed == 0:
125
- print("✅ All routing tests passed!\n")
126
- else:
127
- print("⚠️ Some routing tests failed - may need keyword tuning\n")
128
-
129
-
130
- def test_system_prompt_generation():
131
- """Test that specialist system prompts are generated correctly."""
132
- print("\n🧪 Test 3: System Prompt Generation")
133
- print("=" * 60)
134
-
135
- try:
136
- agent = DataScienceCopilot(
137
- provider="groq",
138
- groq_api_key=os.getenv("GROQ_API_KEY", "dummy_key_for_testing"),
139
- use_session_memory=False
140
- )
141
- except Exception as e:
142
- print(f" ⚠️ Skipping prompt test - initialization failed: {e}")
143
- return
144
-
145
- for agent_key, config in agent.specialist_agents.items():
146
- # Get the specialist's system prompt
147
- system_prompt = agent._get_agent_system_prompt(agent_key)
148
-
149
- # Check that it's not empty and is different from main prompt
150
- assert len(system_prompt) > 100, f"❌ {agent_key} prompt too short"
151
- assert config['name'] in system_prompt, f"❌ {agent_key} prompt doesn't mention agent name"
152
-
153
- print(f" ✅ {config['emoji']} {config['name']} - {len(system_prompt)} chars")
154
- print(f" Preview: {system_prompt[:80]}...")
155
-
156
- # Test fallback to main prompt
157
- fallback_prompt = agent._get_agent_system_prompt("non_existent_agent")
158
- assert len(fallback_prompt) > 100, "❌ Fallback prompt too short"
159
- print(f" ✅ Fallback to main orchestrator prompt works")
160
-
161
- print("\n✅ All system prompts generated correctly!\n")
162
-
163
-
164
- def test_backward_compatibility():
165
- """Test that all tools are still accessible."""
166
- print("\n🧪 Test 4: Backward Compatibility")
167
- print("=" * 60)
168
-
169
- try:
170
- agent = DataScienceCopilot(
171
- provider="groq",
172
- groq_api_key=os.getenv("GROQ_API_KEY", "dummy_key_for_testing"),
173
- use_session_memory=False
174
- )
175
- except Exception as e:
176
- print(f" ⚠️ Skipping compatibility test - initialization failed: {e}")
177
- return
178
-
179
- # Build tool functions map
180
- tool_functions = agent._build_tool_functions_map()
181
-
182
- print(f" ✅ {len(tool_functions)} tools still accessible")
183
-
184
- # Check that some key tools exist
185
- key_tools = [
186
- 'profile_dataset',
187
- 'train_baseline_models',
188
- 'generate_interactive_scatter', # Correct tool name
189
- 'clean_missing_values',
190
- 'generate_business_insights' # Correct tool name
191
- ]
192
-
193
- for tool_name in key_tools:
194
- assert tool_name in tool_functions, f"❌ Tool {tool_name} not found"
195
- print(f" ✅ {tool_name} available")
196
-
197
- print("\n✅ All key tools accessible - no breaking changes!\n")
198
-
199
-
200
- if __name__ == "__main__":
201
- print("\n" + "=" * 60)
202
- print("🔬 MULTI-AGENT ARCHITECTURE TEST SUITE")
203
- print("=" * 60)
204
-
205
- try:
206
- test_agent_initialization()
207
- test_agent_routing()
208
- test_system_prompt_generation()
209
- test_backward_compatibility()
210
-
211
- print("\n" + "=" * 60)
212
- print("✅ ALL TESTS PASSED!")
213
- print("=" * 60)
214
- print("\n🎉 Multi-agent architecture successfully implemented without breaking existing code!\n")
215
-
216
- except AssertionError as e:
217
- print(f"\n❌ TEST FAILED: {e}\n")
218
- sys.exit(1)
219
- except Exception as e:
220
- print(f"\n❌ UNEXPECTED ERROR: {e}\n")
221
- import traceback
222
- traceback.print_exc()
223
- sys.exit(1)