| import base64 |
| from typing import Any, Dict, List |
|
|
| from fastapi import APIRouter, HTTPException |
| from pydantic import BaseModel |
|
|
| from app.models.mirai_engine import MiraiEngine |
| from app.services.rag_service import RAGService |
| from app.services.tts_service import TTSService |
|
|
|
|
class ChatRequest(BaseModel):
    """Incoming JSON body for ``POST /chat``.

    Attributes:
        message: The user's chat text. The endpoint strips it and answers
            HTTP 400 when nothing is left.
    """

    message: str
|
|
|
|
| def _extract_rag_text(payload: Dict[str, Any]) -> str: |
| text = (payload.get("result") or payload.get("answer") or "").strip() |
| return text |
|
|
|
|
| def _extract_rag_sources(payload: Dict[str, Any]) -> List[str]: |
| sources: List[str] = [] |
| for doc in payload.get("source_documents", []): |
| meta = getattr(doc, "metadata", {}) or {} |
| source = meta.get("source") |
| if source: |
| sources.append(str(source)) |
| return sources |
|
|
|
|
def create_chat_router(engine: MiraiEngine, tts_service: TTSService, rag_service: RAGService) -> APIRouter:
    """Build the router that exposes ``POST /chat``.

    The endpoint first asks ``rag_service`` to answer the message. If that
    raises or yields no text, it falls back to ``engine``. The final reply is
    then voiced by ``tts_service``.

    Args:
        engine: Fallback reply generator. Used only when ``engine.is_ready``.
        tts_service: Turns reply text into audio bytes plus lip-sync data.
        rag_service: Retrieval-augmented answerer, consulted first.

    Returns:
        An ``APIRouter`` tagged ``"chat"``.
    """
    import logging

    logger = logging.getLogger(__name__)
    router = APIRouter(tags=["chat"])

    @router.post("/chat")
    async def chat_endpoint(req: ChatRequest):
        """Answer a chat message with text, source paths, base64 audio and lip-sync data.

        Raises:
            HTTPException: 400 for an empty message or when the engine raises
                ``ValueError``; 500 when the engine is not ready, when
                generation fails, or when TTS fails.
        """
        raw_message = (req.message or "").strip()
        if not raw_message:
            raise HTTPException(status_code=400, detail="Message is required")

        reply_text = ""
        source_paths: List[str] = []

        # NOTE(review): rag_service.query and engine.generate_reply are called
        # synchronously inside this async handler, so they block the event
        # loop while they run. Consider fastapi.concurrency.run_in_threadpool
        # once both are confirmed safe to call from worker threads.
        try:
            rag_payload = rag_service.query(raw_message)
            reply_text = _extract_rag_text(rag_payload)
            source_paths = _extract_rag_sources(rag_payload)
        except Exception:
            # RAG is best-effort: record why it failed, then fall back.
            logger.exception("RAG query failed; falling back to engine reply")
            reply_text = ""

        if not reply_text:
            # The engine reply is not grounded in the retrieved documents, so
            # it must not be attributed to them.
            source_paths = []
            if not engine.is_ready:
                raise HTTPException(status_code=500, detail=engine.unavailable_reason)
            try:
                reply_text = engine.generate_reply(raw_message)
            except ValueError as exc:
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            except Exception as exc:
                raise HTTPException(status_code=500, detail=str(exc)) from exc

        try:
            audio_bytes, lip_sync_data = await tts_service.generate_audio(reply_text)
            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
        except Exception as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return {
            "text": reply_text,
            "sources": source_paths,
            "audio": audio_base64,
            "lipSync": lip_sync_data,
        }

    return router
|
|