Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from backend.models import ChatRequest | |
| from backend.llm_utils import sanitize_history, route_message, get_reply | |
| from backend.rag_utils import get_user_data | |
| from backend.models import ChatRequest, SummaryRequest | |
| from backend.llm_utils import sanitize_history, route_message, get_reply, generate_chat_summary | |
| from backend.voice.stt import transcribe_audio | |
| from backend.voice.tts import synthesize_speech | |
| from fastapi import UploadFile, File, Form | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| import json | |
| import io | |
| import base64 | |
| from backend.cache_utils import get_cached_user_data, cache_user_data, cleanup_expired_cache | |
| import json | |
| import os | |
| from backend.credentials import setup_google_credentials | |
# Load Google service-account credentials before any Google client libraries
# are used (module-level side effect; presumably sets env vars — confirm in
# backend.credentials).
setup_google_credentials()

app = FastAPI()

# NOTE(review): wildcard CORS (origins/methods/headers all "*") is wide open —
# acceptable for development, but restrict allowed origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
async def chat_endpoint(req: ChatRequest):
    """Handle a single text chat turn.

    Fetches the user's personalization data (best-effort), routes the
    message, and returns {"reply": ...}. Pipeline failures degrade to a
    friendly fallback reply instead of a 500.

    Args:
        req: ChatRequest with `message`, optional `history`, optional `uid`.

    Returns:
        dict with "reply", or {"error": ...} when `message` is missing.
    """
    user_message = req.message
    history = req.history or []
    user_id = req.uid

    if not user_message:
        return {"error": "message is required"}

    # Best-effort personalization: a failed user-data lookup must not block
    # the chat, but it should be visible in the logs (the original swallowed
    # the exception silently).
    user_data = {}
    if user_id:
        try:
            user_data = get_user_data(user_id)
        except Exception as e:
            print(f"[CHAT] user data lookup failed for {user_id}: {e}")
            user_data = {}

    try:
        route = await route_message(user_message)
        simple_history = sanitize_history(history)
        simple_history.append({"role": "user", "content": user_message})
        reply = await get_reply(route, simple_history, user_data, user_id)
        if not reply:
            reply = "I'm here to help with your wellness journey! What would you like to work on today?"
        return {"reply": reply}
    except Exception as e:
        # Log the real error before returning the graceful client message —
        # previously this failure left no trace at all.
        print(f"[CHAT] pipeline error: {e}")
        return {"reply": "Sorry, I'm having trouble right now. Could you try again in a moment?"}
| import time | |
| import asyncio | |
async def summarize_endpoint(req: SummaryRequest):
    """Generate a short title for a chat transcript.

    Returns {"summary": ...}; an empty transcript or any failure degrades to
    the default title "New Chat".
    """
    start_time = time.time()

    def _ms(since: float) -> float:
        # time.time() deltas are in seconds; convert to milliseconds so the
        # "[TIMING] ... ms" labels are accurate (the original logged raw
        # seconds with an "ms" suffix, and one line with "s").
        return (time.time() - since) * 1000.0

    try:
        messages = req.messages
        if not messages:
            print(f"[TIMING] Summary - No messages: {_ms(start_time):.2f}ms")
            return {"summary": "New Chat"}

        # Timed locally to expose import cost on the first request (the
        # module is usually already cached from the top-level import).
        import_start = time.time()
        from backend.llm_utils import generate_chat_summary
        print(f"[TIMING] Summary - Import: {_ms(import_start):.2f}ms")

        summary_start = time.time()
        summary = await generate_chat_summary(messages)
        print(f"[TIMING] Summary - Generation: {_ms(summary_start):.2f}ms")

        print(f"[TIMING] Summary - Total: {_ms(start_time):.2f}ms")
        return {"summary": summary}
    except Exception as e:
        print(f"[TIMING] Summary - Error after {_ms(start_time):.2f}ms:", e)
        return {"summary": "New Chat"}
async def voice_chat_endpoint(
    file: UploadFile = File(...),
    history: str = Form(None),
    uid: str = Form(None),
    voice: str = Form("alloy")
):
    """Full voice pipeline: audio upload -> STT -> routing -> LLM reply -> TTS.

    Transcription and the user-data lookup are started concurrently to hide
    latency. Returns the transcript, the text reply, and the reply audio as
    base64; on failure returns a 500 JSON error.

    Args:
        file: uploaded audio (treated as .m4a for transcription).
        history: optional JSON-encoded list of prior chat messages.
        uid: optional user id for personalization.
        voice: TTS voice name (default "alloy").
    """
    start_time = time.time()

    def _ms(seconds: float) -> float:
        # Convert a seconds delta to milliseconds; the original logged raw
        # seconds under "ms" labels.
        return seconds * 1000.0

    try:
        # Step 1: read the uploaded audio into memory.
        file_start = time.time()
        audio_bytes = await file.read()
        file_read_s = time.time() - file_start
        print(f"[TIMING] Voice - File read: {_ms(file_read_s):.2f}ms ({len(audio_bytes)} bytes)")

        # Step 2: kick off transcription immediately so it overlaps the
        # user-data fetch and history parsing below.
        transcription_start = time.time()
        transcription_task = asyncio.create_task(transcribe_audio(audio_bytes, ".m4a"))

        # Step 3: start the (cached) user-data lookup concurrently.
        user_data_task = None
        if uid:
            user_data_task = asyncio.create_task(get_user_data_async(uid))

        # Step 4: parse chat history while transcription runs.
        history_start = time.time()
        simple_history = json.loads(history) if history else []
        print(f"[TIMING] Voice - History parsing: {_ms(time.time() - history_start):.2f}ms ({len(simple_history)} messages)")

        # Step 5: wait for the transcript; capture the stage duration now so
        # the final breakdown reports the real number (the original
        # re-measured at the end, inflating every stage).
        transcription_wait_start = time.time()
        user_message = await transcription_task
        transcription_s = time.time() - transcription_start
        print(f"[TIMING] Voice - Transcription total: {_ms(transcription_s):.2f}ms")
        print(f"[TIMING] Voice - Transcription wait: {_ms(time.time() - transcription_wait_start):.2f}ms")
        print("WHISPER transcript:", repr(user_message))

        if not user_message.strip():
            print(f"[TIMING] Voice - Empty transcript, returning early: {_ms(time.time() - start_time):.2f}ms")
            return {"user_transcript": "", "reply": "I didn't catch that", "audio_base64": ""}

        # Step 6: collect the user data; errors fall back to an empty profile.
        user_data = {}
        if user_data_task:
            user_data_wait_start = time.time()
            try:
                user_data = await user_data_task
                print(f"[TIMING] Voice - User data retrieval: {_ms(time.time() - user_data_wait_start):.2f}ms")
            except Exception as e:
                print(f"[TIMING] Voice - User data error after {_ms(time.time() - user_data_wait_start):.2f}ms: {e}")
                user_data = {}

        # Step 7: append the new user turn to the conversation.
        simple_history.append({"role": "user", "content": user_message})

        # Step 8: route the message to the right handler.
        routing_start = time.time()
        route = await route_message(user_message)
        routing_s = time.time() - routing_start
        print(f"[TIMING] Voice - Message routing: {_ms(routing_s):.2f}ms (route: {route})")

        # Step 9: generate the text reply, with a friendly default.
        reply_start = time.time()
        reply = await get_reply(route, simple_history, user_data, uid)
        if not reply:
            reply = "I'm here to help with your wellness journey! What would you like to work on today?"
        reply_s = time.time() - reply_start
        print(f"[TIMING] Voice - Reply generation: {_ms(reply_s):.2f}ms")

        # Step 10: synthesize the reply to speech.
        tts_start = time.time()
        audio_data = await synthesize_speech(reply, voice)
        tts_s = time.time() - tts_start
        print(f"[TIMING] Voice - TTS generation: {_ms(tts_s):.2f}ms")

        # Step 11: base64-encode the audio for the JSON response.
        base64_audio = base64.b64encode(audio_data).decode()

        print(f"[TIMING] Voice - TOTAL PIPELINE: {_ms(time.time() - start_time):.2f}ms")
        # Breakdown uses the captured per-stage durations. The original
        # computed "File read" as (file_start - start_time) — the wrong
        # interval entirely — and re-measured the rest against the current
        # clock, so every number included all later stages.
        print(f"[TIMING] Voice - BREAKDOWN:")
        print(f"  • File read: {_ms(file_read_s):.2f}ms")
        print(f"  • Transcription: {_ms(transcription_s):.2f}ms")
        print(f"  • Routing: {_ms(routing_s):.2f}ms")
        print(f"  • Reply: {_ms(reply_s):.2f}ms")
        print(f"  • TTS: {_ms(tts_s):.2f}ms")

        return {
            "user_transcript": user_message,
            "reply": reply,
            "audio_base64": base64_audio
        }
    except Exception as e:
        print(f"[TIMING] Voice - ERROR after {_ms(time.time() - start_time):.2f}ms:", e)
        return JSONResponse({"error": str(e)}, status_code=500)
# Async wrapper for get_user_data with read-through caching.
async def get_user_data_async(uid: str):
    """Fetch a user's personalization data, preferring the in-process cache.

    On a cache miss the fresh result is written back to the cache so the next
    request for this user hits the fast path — the original imported
    cache_user_data but never populated the cache from this code path.
    """
    start_time = time.time()

    # Fast path: serve from cache.
    cached_data = get_cached_user_data(uid)
    if cached_data:
        print(f"[TIMING] User data (cached): {(time.time() - start_time) * 1000:.2f}ms")
        return cached_data

    # Cache miss - fetch fresh data.
    print("[CACHE] User data cache miss, fetching fresh data...")
    result = get_user_data(uid)
    try:
        # Best-effort write-back; never fail the request over a cache error.
        # assumes cache_user_data(uid, data) signature — TODO confirm against
        # backend.cache_utils.
        cache_user_data(uid, result)
    except Exception as e:
        print(f"[CACHE] Failed to cache user data for {uid}: {e}")
    print(f"[TIMING] User data fetch: {(time.time() - start_time) * 1000:.2f}ms")
    return result
async def cache_stats_endpoint():
    """Report cache performance statistics, pruning stale entries first."""
    from backend.cache_utils import get_cache_stats, cleanup_expired_cache

    # Opportunistic housekeeping so the stats reflect only live entries.
    cleanup_expired_cache()
    return get_cache_stats()
async def clear_cache_endpoint(user_id: str = None):
    """Clear the cache, either for one user or (when user_id is None) for all."""
    from backend.cache_utils import clear_user_cache

    clear_user_cache(user_id)
    target = f"user {user_id}" if user_id else "all users"
    return {"message": f"Cache cleared for {target}"}
if __name__ == "__main__":
    # Local/dev entry point: honor the platform-provided PORT, default 3000.
    import uvicorn

    serve_port = int(os.getenv("PORT", 3000))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)