Spaces:
Runtime error
Runtime error
Commit ·
4b2359f
1
Parent(s): 89fd424
Add chart generation to groq handler: Generate visual charts for impact analysis queries like 'show impact' and 'analyze through graphs'
Browse files- groq_websocket_handler.py +57 -1
groq_websocket_handler.py
CHANGED
|
@@ -18,6 +18,9 @@ from fastapi import WebSocket, WebSocketDisconnect
|
|
| 18 |
from groq_voice_service import groq_voice_service
|
| 19 |
from rag_service import search_documents_async
|
| 20 |
from hybrid_llm_service import HybridLLMService
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
logger = logging.getLogger("voicebot")
|
| 23 |
|
|
@@ -359,6 +362,9 @@ class GroqWebSocketHandler:
|
|
| 359 |
|
| 360 |
# Send different response formats based on client type
|
| 361 |
if client_type == "text":
|
|
|
|
|
|
|
|
|
|
| 362 |
# For text clients, send structured response
|
| 363 |
await self.send_message(session_id, {
|
| 364 |
"type": "streaming_response",
|
|
@@ -371,7 +377,7 @@ class GroqWebSocketHandler:
|
|
| 371 |
"url": "",
|
| 372 |
"score": 1.0,
|
| 373 |
"scenario_analysis": None,
|
| 374 |
-
"charts":
|
| 375 |
})
|
| 376 |
else:
|
| 377 |
# For voice clients, send friend's format
|
|
@@ -597,5 +603,55 @@ class GroqWebSocketHandler:
|
|
| 597 |
"message": f"Audio generation failed: {str(e)}"
|
| 598 |
})
|
| 599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
# Global instance
|
| 601 |
groq_websocket_handler = GroqWebSocketHandler()
|
|
|
|
| 18 |
from groq_voice_service import groq_voice_service
|
| 19 |
from rag_service import search_documents_async
|
| 20 |
from hybrid_llm_service import HybridLLMService
|
| 21 |
+
from policy_chart_generator import PolicyChartGenerator
|
| 22 |
+
import base64
|
| 23 |
+
import io
|
| 24 |
|
| 25 |
logger = logging.getLogger("voicebot")
|
| 26 |
|
|
|
|
| 362 |
|
| 363 |
# Send different response formats based on client type
|
| 364 |
if client_type == "text":
|
| 365 |
+
# Generate charts for impact analysis queries
|
| 366 |
+
charts = await self._generate_charts_if_needed(query, response_text)
|
| 367 |
+
|
| 368 |
# For text clients, send structured response
|
| 369 |
await self.send_message(session_id, {
|
| 370 |
"type": "streaming_response",
|
|
|
|
| 377 |
"url": "",
|
| 378 |
"score": 1.0,
|
| 379 |
"scenario_analysis": None,
|
| 380 |
+
"charts": charts
|
| 381 |
})
|
| 382 |
else:
|
| 383 |
# For voice clients, send friend's format
|
|
|
|
| 603 |
"message": f"Audio generation failed: {str(e)}"
|
| 604 |
})
|
| 605 |
|
| 606 |
+
async def _generate_charts_if_needed(self, query: str, response_text: str) -> list:
|
| 607 |
+
"""Generate charts for impact analysis and scenario questions"""
|
| 608 |
+
try:
|
| 609 |
+
query_lower = query.lower()
|
| 610 |
+
charts = []
|
| 611 |
+
|
| 612 |
+
# Keywords that indicate need for charts
|
| 613 |
+
chart_keywords = [
|
| 614 |
+
'impact', 'effect', 'scenario', 'analyze', 'compare',
|
| 615 |
+
'chart', 'graph', 'visual', 'breakdown', 'yearly',
|
| 616 |
+
'projection', 'forecast', 'increment'
|
| 617 |
+
]
|
| 618 |
+
|
| 619 |
+
# Check if query needs charts
|
| 620 |
+
needs_charts = any(keyword in query_lower for keyword in chart_keywords)
|
| 621 |
+
|
| 622 |
+
if not needs_charts:
|
| 623 |
+
return []
|
| 624 |
+
|
| 625 |
+
logger.info(f"📊 Generating charts for query: {query}")
|
| 626 |
+
|
| 627 |
+
# Initialize chart generator
|
| 628 |
+
chart_gen = PolicyChartGenerator()
|
| 629 |
+
|
| 630 |
+
# Create sample pension increment data for demonstration
|
| 631 |
+
sample_data = [
|
| 632 |
+
{'year': 1, 'amount': 3000}, {'year': 5, 'amount': 15000},
|
| 633 |
+
{'year': 10, 'amount': 30000}, {'year': 15, 'amount': 45000},
|
| 634 |
+
{'year': 20, 'amount': 60000}, {'year': 25, 'amount': 75000}
|
| 635 |
+
]
|
| 636 |
+
|
| 637 |
+
# Generate yearly breakdown chart
|
| 638 |
+
chart_base64 = chart_gen.generate_yearly_breakdown_chart(
|
| 639 |
+
sample_data,
|
| 640 |
+
title="Pension Impact Analysis"
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
charts.append({
|
| 644 |
+
"type": "line_chart",
|
| 645 |
+
"title": "Pension Impact Analysis",
|
| 646 |
+
"data": chart_base64
|
| 647 |
+
})
|
| 648 |
+
|
| 649 |
+
logger.info(f"✅ Generated {len(charts)} charts for analysis")
|
| 650 |
+
return charts
|
| 651 |
+
|
| 652 |
+
except Exception as e:
|
| 653 |
+
logger.error(f"❌ Chart generation error: {e}")
|
| 654 |
+
return []
|
| 655 |
+
|
| 656 |
# Global instance
|
| 657 |
groq_websocket_handler = GroqWebSocketHandler()
|