Spaces:
Sleeping
Sleeping
Commit
·
f139d4e
1
Parent(s):
c3ade54
Add 89999999999999999999
Browse files
enhanced_websocket_handler.py
CHANGED
|
@@ -11,12 +11,15 @@ import uuid
|
|
| 11 |
import tempfile
|
| 12 |
import base64
|
| 13 |
from pathlib import Path
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from llm_service import create_graph, create_basic_graph
|
| 16 |
from lancedb_service import lancedb_service
|
| 17 |
from hybrid_llm_service import HybridLLMService
|
| 18 |
from voice_service import voice_service
|
| 19 |
from rag_service import search_government_docs
|
|
|
|
| 20 |
|
| 21 |
# Initialize hybrid LLM service
|
| 22 |
hybrid_llm_service = HybridLLMService()
|
|
@@ -257,7 +260,8 @@ async def handle_text_message(websocket: WebSocket, data: dict, session_data: di
|
|
| 257 |
"date": chunk.get("date", ""),
|
| 258 |
"url": chunk.get("url", ""),
|
| 259 |
"score": chunk.get("score", 1.0),
|
| 260 |
-
"scenario_analysis": chunk.get("scenario_analysis", None)
|
|
|
|
| 261 |
})
|
| 262 |
# Optionally, aggregate or select the best chunk for final response
|
| 263 |
# Here, just use the first chunk for context update and provider
|
|
@@ -451,6 +455,27 @@ async def get_hybrid_response(user_message: str, context: str, config: dict, kno
|
|
| 451 |
'inflation': 0.05
|
| 452 |
}
|
| 453 |
scenario_result = run_scenario_analysis(params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
else:
|
| 455 |
scenario_result = None
|
| 456 |
for doc in docs:
|
|
@@ -463,7 +488,8 @@ async def get_hybrid_response(user_message: str, context: str, config: dict, kno
|
|
| 463 |
"date": doc.get("date", ""),
|
| 464 |
"url": doc.get("url", ""),
|
| 465 |
"score": doc.get("score", 1.0),
|
| 466 |
-
"scenario_analysis": scenario_result
|
|
|
|
| 467 |
}
|
| 468 |
yield response_obj
|
| 469 |
else:
|
|
|
|
| 11 |
import tempfile
|
| 12 |
import base64
|
| 13 |
from pathlib import Path
|
| 14 |
+
import io
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
|
| 17 |
from llm_service import create_graph, create_basic_graph
|
| 18 |
from lancedb_service import lancedb_service
|
| 19 |
from hybrid_llm_service import HybridLLMService
|
| 20 |
from voice_service import voice_service
|
| 21 |
from rag_service import search_government_docs
|
| 22 |
+
from policy_chart_generator import PolicyChartGenerator
|
| 23 |
|
| 24 |
# Initialize hybrid LLM service
|
| 25 |
hybrid_llm_service = HybridLLMService()
|
|
|
|
| 260 |
"date": chunk.get("date", ""),
|
| 261 |
"url": chunk.get("url", ""),
|
| 262 |
"score": chunk.get("score", 1.0),
|
| 263 |
+
"scenario_analysis": chunk.get("scenario_analysis", None),
|
| 264 |
+
"charts": chunk.get("charts", [])
|
| 265 |
})
|
| 266 |
# Optionally, aggregate or select the best chunk for final response
|
| 267 |
# Here, just use the first chunk for context update and provider
|
|
|
|
| 455 |
'inflation': 0.05
|
| 456 |
}
|
| 457 |
scenario_result = run_scenario_analysis(params)
|
| 458 |
+
# Generate charts for scenario_result
|
| 459 |
+
chart_gen = PolicyChartGenerator()
|
| 460 |
+
charts = []
|
| 461 |
+
# Example: line chart for yearly results
|
| 462 |
+
if "yearly_results" in scenario_result:
|
| 463 |
+
years = [r['year'] for r in scenario_result['yearly_results']]
|
| 464 |
+
base_costs = [r['base_cost'] for r in scenario_result['yearly_results']]
|
| 465 |
+
scenario_costs = [r['scenario_cost'] for r in scenario_result['yearly_results']]
|
| 466 |
+
# Generate chart and append to charts list
|
| 467 |
+
fig, ax = plt.subplots()
|
| 468 |
+
ax.plot(years, base_costs, label='Base Cost')
|
| 469 |
+
ax.plot(years, scenario_costs, label='Scenario Cost')
|
| 470 |
+
ax.legend()
|
| 471 |
+
ax.set_title('Scenario Analysis: Cost Over Years')
|
| 472 |
+
buf = io.BytesIO()
|
| 473 |
+
fig.savefig(buf, format='png')
|
| 474 |
+
buf.seek(0)
|
| 475 |
+
chart_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
| 476 |
+
plt.close(fig)
|
| 477 |
+
charts.append({"type": "line_chart", "data": chart_base64})
|
| 478 |
+
scenario_result["charts"] = charts
|
| 479 |
else:
|
| 480 |
scenario_result = None
|
| 481 |
for doc in docs:
|
|
|
|
| 488 |
"date": doc.get("date", ""),
|
| 489 |
"url": doc.get("url", ""),
|
| 490 |
"score": doc.get("score", 1.0),
|
| 491 |
+
"scenario_analysis": scenario_result,
|
| 492 |
+
"charts": scenario_result.get("charts", []) if scenario_result else []
|
| 493 |
}
|
| 494 |
yield response_obj
|
| 495 |
else:
|