# ============================================================================
# backend/app/services/streaming_service.py - NEW FILE
# ============================================================================
"""
Streaming Service - Server-Sent Events (SSE)

Handles real-time streaming of AI responses.
Integrates with chat_service.py RAG pipeline.
"""
import asyncio
import json
import time
import traceback
from typing import AsyncGenerator, Dict, Any, List, Optional
from datetime import datetime

from app.config import settings
from app.ml.policy_network import predict_policy_action
from app.ml.retriever import retrieve_documents, format_context
from app.core.llm_manager import llm_manager

# ============================================================================
# STREAMING SERVICE
# ============================================================================
class StreamingService:
    """
    Handles SSE streaming for real-time chat responses.

    Events sent:
    - status:   progress updates (retrieval, generation stages)
    - content:  response chunks (word by word)
    - metadata: final stats (policy action, docs retrieved, etc.)
    - done:     stream completion signal
    - error:    an error occurred
    """

    def __init__(self):
        print("🌊 StreamingService initialized")
    async def stream_chat_response(
        self,
        query: str,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        user_id: Optional[str] = None
    ) -> AsyncGenerator[str, None]:
        """
        Stream a chat response with progress updates.

        Yields SSE-formatted events:
        - event: status, content, metadata, done, error
        - data: JSON payload

        Args:
            query: User query
            conversation_history: Previous messages in the conversation
            user_id: User ID (currently unused)

        Yields:
            str: SSE-formatted events
        """
        start_time = time.time()

        if conversation_history is None:
            conversation_history = []
        try:
            # ================================================================
            # STAGE 1: Policy Decision
            # ================================================================
            yield self._format_sse_event(
                event="status",
                data={"stage": "policy", "message": "Analyzing query..."}
            )
            await asyncio.sleep(0.1)  # Small delay for UX

            policy_result = predict_policy_action(
                query=query,
                history=conversation_history,
                return_probs=True
            )

            # ================================================================
            # STAGE 2: Retrieval (if needed)
            # ================================================================
            retrieved_docs = []
            context = ""
            retrieval_time = 0

            if policy_result['should_retrieve']:
                yield self._format_sse_event(
                    event="status",
                    data={"stage": "retrieval", "message": "Searching knowledge base..."}
                )

                retrieval_start = time.time()
                try:
                    retrieved_docs = retrieve_documents(
                        query=query,
                        top_k=settings.TOP_K,
                        min_similarity=settings.SIMILARITY_THRESHOLD
                    )
                    retrieval_time = (time.time() - retrieval_start) * 1000

                    if retrieved_docs:
                        context = format_context(
                            retrieved_docs,
                            max_context_length=settings.MAX_CONTEXT_LENGTH
                        )
                        yield self._format_sse_event(
                            event="status",
                            data={
                                "stage": "retrieval",
                                "message": f"Found {len(retrieved_docs)} relevant documents"
                            }
                        )
                except Exception as e:
                    print(f"⚠️ Retrieval error during streaming: {e}")
                    # Continue without retrieval context
            # ================================================================
            # STAGE 3: Stream Generation
            # ================================================================
            yield self._format_sse_event(
                event="status",
                data={"stage": "generation", "message": "Generating response..."}
            )

            generation_start = time.time()
            full_response = ""

            # Stream chunks from the LLM and forward each as a content event
            async for chunk in llm_manager.stream_chat_response(
                query=query,
                context=context,
                history=conversation_history
            ):
                full_response += chunk
                yield self._format_sse_event(
                    event="content",
                    data={"text": chunk}
                )

            generation_time = (time.time() - generation_start) * 1000
            total_time = (time.time() - start_time) * 1000
            # ================================================================
            # STAGE 4: Send Metadata
            # ================================================================
            metadata = {
                "policy_action": policy_result['action'],
                "policy_confidence": policy_result['confidence'],
                "documents_retrieved": len(retrieved_docs),
                "top_doc_score": retrieved_docs[0]['score'] if retrieved_docs else None,
                "retrieval_time_ms": round(retrieval_time, 2),
                "generation_time_ms": round(generation_time, 2),
                "total_time_ms": round(total_time, 2),
                "timestamp": datetime.now().isoformat()
            }

            # Add per-document retrieval metadata
            if retrieved_docs:
                metadata['retrieved_docs_metadata'] = [
                    {
                        'faq_id': doc['faq_id'],
                        'score': doc['score'],
                        'category': doc['category'],
                        'rank': doc['rank']
                    }
                    for doc in retrieved_docs
                ]

            yield self._format_sse_event(
                event="metadata",
                data=metadata
            )

            # ================================================================
            # STAGE 5: Done
            # ================================================================
            yield self._format_sse_event(
                event="done",
                data={"message": "Stream completed"}
            )
        except Exception as e:
            print(f"❌ Streaming error: {e}")
            traceback.print_exc()
            yield self._format_sse_event(
                event="error",
                data={"error": str(e), "message": "An error occurred during streaming"}
            )
    def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str:
        """
        Format data as an SSE event.

        SSE wire format:
            event: <event_name>
            data: <json_data>
            (blank line to separate events)
        """
        json_data = json.dumps(data, ensure_ascii=False)
        return f"event: {event}\ndata: {json_data}\n\n"

# ============================================================================
# GLOBAL INSTANCE
# ============================================================================
streaming_service = StreamingService()
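

# ============================================================================
# USAGE SKETCH (illustrative only - not part of this module)
# ============================================================================
# A minimal sketch of how this generator could be exposed over HTTP, assuming
# the backend is a FastAPI app. The route path, request model, and router
# wiring below are hypothetical and would normally live in a router module
# (e.g. a hypothetical backend/app/api/chat.py), not here.
#
#   from typing import Dict, List, Optional
#   from fastapi import APIRouter
#   from fastapi.responses import StreamingResponse
#   from pydantic import BaseModel
#   from app.services.streaming_service import streaming_service
#
#   router = APIRouter()
#
#   class ChatStreamRequest(BaseModel):
#       query: str
#       conversation_history: List[Dict[str, str]] = []
#       user_id: Optional[str] = None
#
#   @router.post("/chat/stream")
#   async def chat_stream(request: ChatStreamRequest):
#       # StreamingResponse consumes the async generator and sends each
#       # yielded SSE string to the client as soon as it is produced.
#       return StreamingResponse(
#           streaming_service.stream_chat_response(
#               query=request.query,
#               conversation_history=request.conversation_history,
#               user_id=request.user_id,
#           ),
#           media_type="text/event-stream",
#           headers={"Cache-Control": "no-cache"},
#       )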