ChAbhishek28 committed on
Commit
f139d4e
·
1 Parent(s): c3ade54

Add 89999999999999999999

Browse files
Files changed (1) hide show
  1. enhanced_websocket_handler.py +28 -2
enhanced_websocket_handler.py CHANGED
@@ -11,12 +11,15 @@ import uuid
11
  import tempfile
12
  import base64
13
  from pathlib import Path
 
 
14
 
15
  from llm_service import create_graph, create_basic_graph
16
  from lancedb_service import lancedb_service
17
  from hybrid_llm_service import HybridLLMService
18
  from voice_service import voice_service
19
  from rag_service import search_government_docs
 
20
 
21
  # Initialize hybrid LLM service
22
  hybrid_llm_service = HybridLLMService()
@@ -257,7 +260,8 @@ async def handle_text_message(websocket: WebSocket, data: dict, session_data: di
257
  "date": chunk.get("date", ""),
258
  "url": chunk.get("url", ""),
259
  "score": chunk.get("score", 1.0),
260
- "scenario_analysis": chunk.get("scenario_analysis", None)
 
261
  })
262
  # Optionally, aggregate or select the best chunk for final response
263
  # Here, just use the first chunk for context update and provider
@@ -451,6 +455,27 @@ async def get_hybrid_response(user_message: str, context: str, config: dict, kno
451
  'inflation': 0.05
452
  }
453
  scenario_result = run_scenario_analysis(params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  else:
455
  scenario_result = None
456
  for doc in docs:
@@ -463,7 +488,8 @@ async def get_hybrid_response(user_message: str, context: str, config: dict, kno
463
  "date": doc.get("date", ""),
464
  "url": doc.get("url", ""),
465
  "score": doc.get("score", 1.0),
466
- "scenario_analysis": scenario_result
 
467
  }
468
  yield response_obj
469
  else:
 
11
  import tempfile
12
  import base64
13
  from pathlib import Path
14
+ import io
15
+ import matplotlib.pyplot as plt
16
 
17
  from llm_service import create_graph, create_basic_graph
18
  from lancedb_service import lancedb_service
19
  from hybrid_llm_service import HybridLLMService
20
  from voice_service import voice_service
21
  from rag_service import search_government_docs
22
+ from policy_chart_generator import PolicyChartGenerator
23
 
24
  # Initialize hybrid LLM service
25
  hybrid_llm_service = HybridLLMService()
 
260
  "date": chunk.get("date", ""),
261
  "url": chunk.get("url", ""),
262
  "score": chunk.get("score", 1.0),
263
+ "scenario_analysis": chunk.get("scenario_analysis", None),
264
+ "charts": chunk.get("charts", [])
265
  })
266
  # Optionally, aggregate or select the best chunk for final response
267
  # Here, just use the first chunk for context update and provider
 
455
  'inflation': 0.05
456
  }
457
  scenario_result = run_scenario_analysis(params)
458
+ # Generate charts for scenario_result
459
+ chart_gen = PolicyChartGenerator()
460
+ charts = []
461
+ # Example: line chart for yearly results
462
+ if "yearly_results" in scenario_result:
463
+ years = [r['year'] for r in scenario_result['yearly_results']]
464
+ base_costs = [r['base_cost'] for r in scenario_result['yearly_results']]
465
+ scenario_costs = [r['scenario_cost'] for r in scenario_result['yearly_results']]
466
+ # Generate chart and append to charts list
467
+ fig, ax = plt.subplots()
468
+ ax.plot(years, base_costs, label='Base Cost')
469
+ ax.plot(years, scenario_costs, label='Scenario Cost')
470
+ ax.legend()
471
+ ax.set_title('Scenario Analysis: Cost Over Years')
472
+ buf = io.BytesIO()
473
+ fig.savefig(buf, format='png')
474
+ buf.seek(0)
475
+ chart_base64 = base64.b64encode(buf.read()).decode('utf-8')
476
+ plt.close(fig)
477
+ charts.append({"type": "line_chart", "data": chart_base64})
478
+ scenario_result["charts"] = charts
479
  else:
480
  scenario_result = None
481
  for doc in docs:
 
488
  "date": doc.get("date", ""),
489
  "url": doc.get("url", ""),
490
  "score": doc.get("score", 1.0),
491
+ "scenario_analysis": scenario_result,
492
+ "charts": scenario_result.get("charts", []) if scenario_result else []
493
  }
494
  yield response_obj
495
  else: