import asyncio
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from sqlalchemy.orm import Session

from api.auth import get_current_user_ws
from models import db_models
from core.database import get_db
from services.flashcard_service import flashcard_service
from services.quiz_service import quiz_service
from services.report_service import report_service
from services.mindmap_service import mindmap_service
from services.podcast_service import podcast_service
from services.s3_service import s3_service
from services.video_generator_service import video_generator_service
from services.slides_video_service import slides_video_service
from models.schemas import VideoSummaryGenerateRequest, ReportGenerateRequest, MindMapGenerateRequest

router = APIRouter(prefix="/ws", tags=["websockets"])
logger = logging.getLogger(__name__)

# Strong references to fire-and-forget progress tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task may be garbage
# collected before it has finished running.
_background_tasks: Set[asyncio.Task] = set()


class ConnectionManager:
    """Manages WebSocket connections for parallel execution.

    Connections are stored in ``active_connections`` keyed by an opaque
    connection id. Every ``send_*`` method is best-effort: it does nothing
    when the id is not registered, and it logs (rather than raises) when the
    underlying ``send_json`` call fails.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept ``websocket`` and register it under ``connection_id``."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logger.info(f"WebSocket connected: {connection_id}")

    def disconnect(self, connection_id: str):
        """Remove ``connection_id`` from the registry if it is present."""
        if self.active_connections.pop(connection_id, None) is not None:
            logger.info(f"WebSocket disconnected: {connection_id}")

    async def _send(self, connection_id: str, payload: Dict[str, Any], what: str) -> None:
        """Send ``payload`` as JSON to ``connection_id``; log instead of raising."""
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(payload)
        except Exception as e:
            logger.error(f"Error sending {what} to {connection_id}: {e}")

    async def send_progress(self, connection_id: str, progress: int, status: str, message: str = ""):
        """Send a ``"progress"`` frame with the given percentage, status and message."""
        await self._send(
            connection_id,
            {
                "type": "progress",
                "progress": progress,
                "status": status,
                "message": message,
            },
            "progress",
        )

    async def send_result(self, connection_id: str, data: Any):
        """Send the final ``"result"`` frame (progress 100) carrying ``data``."""
        await self._send(
            connection_id,
            {
                "type": "result",
                "status": "complete",
                "progress": 100,
                "data": data,
            },
            "result",
        )

    async def send_error(self, connection_id: str, error: str):
        """Send an ``"error"`` frame carrying the ``error`` text."""
        await self._send(
            connection_id,
            {
                "type": "error",
                "status": "error",
                "message": error,
            },
            "error",
        )


manager = ConnectionManager()


def _timestamp() -> str:
    """Return the current UTC time formatted for default titles."""
    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')


def _file_base(file_key: Optional[str]) -> Optional[str]:
    """Return the last path segment of ``file_key`` without its extension, or None when no key is given."""
    if not file_key:
        return None
    return file_key.split('/')[-1].rsplit('.', 1)[0]


async def _authorize_source(connection_id: str, file_key: str, current_user: db_models.User, db: Session):
    """Look up the Source with ``file_key`` owned by ``current_user``.

    Sends the "Not authorized" error frame and returns ``None`` when no
    matching row exists.
    """
    source = db.query(db_models.Source).filter(
        db_models.Source.s3_key == file_key,
        db_models.Source.user_id == current_user.id,
    ).first()
    if not source:
        await manager.send_error(connection_id, "Not authorized to access this file")
    return source


def _make_progress_callback(connection_id: str, base: int = 10, scale: float = 0.7) -> Callable[[Any, str], asyncio.Task]:
    """Build a ``(percent, message)`` callback that forwards scaled progress to the client.

    The returned callback schedules the send as a task (it must be called
    with a running event loop, as the original inline lambda required) and
    keeps a reference to it until it completes.
    """
    def _callback(percent, message):
        task = asyncio.create_task(
            manager.send_progress(connection_id, base + int(percent * scale), "processing", message)
        )
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return task

    return _callback


async def _fail_task(connection_id: str, db: Session, record: Any, error: Exception, label: str) -> None:
    """Log a task failure, persist it on ``record`` (if any) and notify the client.

    Must be called from inside the ``except`` block handling ``error`` so the
    traceback is attached to the log entry.
    """
    logger.error(f"{label} task failed: {error}", exc_info=True)
    try:
        # If the failure came from the database the session is in a failed
        # transaction; committing without rolling back first would raise
        # again and the client would never be told about the error.
        db.rollback()
        if record is not None:
            record.status = "failed"
            record.error_message = str(error)
            db.commit()
    except Exception:
        logger.exception(f"Could not persist failure state for {label} task")
        try:
            db.rollback()
        except Exception:
            logger.exception(f"Rollback failed while handling {label} task failure")
    await manager.send_error(connection_id, str(error))


@router.websocket("/generate")
async def unified_generate_ws(
    websocket: WebSocket,
    token: str,
    db: Session = Depends(get_db)
):
    """
    Unified WebSocket gateway for all generation tasks.
    Client sends JSON:
    {
        "type": "podcast|flashcards|quiz|mindmap|report|video",
        "data": { ... }
    }

    The socket is accepted first, then the token is resolved to a user. One
    task specification is read, routed to the matching ``handle_*_task``
    coroutine, and the connection is deregistered when the handler returns.
    """
    await websocket.accept()
    connection_id: Optional[str] = None
    try:
        current_user = await get_current_user_ws(token, db)
        # The random suffix keeps two simultaneous sockets of the same user
        # from overwriting each other's registry entry.
        connection_id = f"user_{current_user.id}_{uuid.uuid4().hex}"
        manager.active_connections[connection_id] = websocket

        # Receive the task specification
        message = await websocket.receive_json()
        if not isinstance(message, dict):
            await manager.send_error(connection_id, "Request must be a JSON object")
            return

        task_type = message.get("type")
        # "data": null is treated the same as an absent "data" key.
        data = message.get("data") or {}

        if not task_type:
            await manager.send_error(connection_id, "Missing 'type' in request")
            return
        if not isinstance(data, dict):
            await manager.send_error(connection_id, "'data' must be a JSON object")
            return

        await manager.send_progress(connection_id, 2, "processing", f"Initializing {task_type} task...")

        # --- ROUTING LOGIC ---
        handler = _TASK_HANDLERS.get(task_type) if isinstance(task_type, str) else None
        if handler is None:
            await manager.send_error(connection_id, f"Unsupported task type: {task_type}")
            return
        await handler(connection_id, data, current_user, db)

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"Unified WebSocket error: {e}", exc_info=True)
        if connection_id is not None:
            await manager.send_error(connection_id, str(e))
        else:
            # Failure happened before registration (e.g. while resolving the
            # token), so the manager cannot reach this socket; reply directly.
            try:
                await websocket.send_json({
                    "type": "error",
                    "status": "error",
                    "message": str(e),
                })
            except Exception as send_exc:
                logger.debug(f"Could not deliver error to unregistered socket: {send_exc}")
    finally:
        if connection_id is not None:
            manager.disconnect(connection_id)


async def handle_podcast_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for podcast generation.

    Creates a Podcast row, optionally analyzes the source file, generates the
    script and audio, uploads the audio to S3, then reports the public URL.
    On any failure the row is marked "failed" and an error frame is sent.
    """
    db_podcast = None
    audio_path = None
    try:
        if "user_prompt" not in data:
            # Fail fast instead of surfacing a bare KeyError after the
            # source analysis has already been paid for.
            await manager.send_error(connection_id, "Missing 'user_prompt' in request data")
            return

        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = await _authorize_source(connection_id, file_key, current_user, db)
            if not source:
                return
            source_id = source.id

        file_base = _file_base(file_key)
        title = f"Podcast-{file_base}" if file_base else f"Podcast {_timestamp()}"

        db_podcast = db_models.Podcast(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_podcast)
        db.commit()
        db.refresh(db_podcast)

        duration_minutes = data.get("duration_minutes", 10)

        analysis_report = ""
        if file_key:
            await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")
            analysis_report = await podcast_service.analyze_pdf(
                file_key=file_key,
                duration_minutes=duration_minutes
            )

        await manager.send_progress(connection_id, 15, "processing", "Generating podcast script...")
        script = await podcast_service.generate_script(
            user_prompt=data["user_prompt"],
            model=data.get("model", "gpt-4o"),
            duration_minutes=duration_minutes,
            podcast_format=data.get("podcast_format", "conversational"),
            pdf_suggestions=analysis_report,
            file_key=file_key
        )
        if not script:
            raise Exception("Failed to generate script")

        await manager.send_progress(connection_id, 45, "processing", "Generating audio...")
        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=data.get("tts_model", "gemini-2.0-flash-exp"),
            spk1_voice=data.get("spk1_voice", "Puck"),
            spk2_voice=data.get("spk2_voice", "Charon"),
            temperature=data.get("temperature", 1.0),
            bgm_choice=data.get("bgm_choice", "No BGM")
        )
        if not audio_path:
            raise Exception("Failed to generate audio")

        await manager.send_progress(connection_id, 90, "processing", "Uploading to S3...")
        filename = os.path.basename(audio_path)
        s3_key = f"users/{current_user.id}/outputs/podcasts/{filename}"

        def upload_audio_sync():
            # Blocking file read + S3 upload; run via asyncio.to_thread so
            # the event loop is not stalled.
            with open(audio_path, "rb") as f:
                content = f.read()
            import boto3
            from core.config import settings
            boto3.client(
                's3',
                aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
                aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
                region_name=settings.AWS_REGION
            ).put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=content)

        await asyncio.to_thread(upload_audio_sync)
        public_url = s3_service.get_public_url(s3_key)

        db_podcast.s3_key = s3_key
        db_podcast.s3_url = public_url
        db_podcast.script = script
        db_podcast.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_podcast.id,
            "status": "completed",
            "message": "Podcast generated successfully",
            "public_url": public_url
        })
    except Exception as e:
        await _fail_task(connection_id, db, db_podcast, e, "Podcast")
    finally:
        # Remove the local audio file on success and on failure alike; a
        # cleanup problem must not turn a completed podcast into a failed one.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary audio file {audio_path}: {cleanup_error}")


async def handle_flashcards_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for flashcard generation.

    Creates a FlashcardSet row, asks the flashcard service for cards, stores
    one Flashcard row per returned item and reports the set id and count.
    """
    db_set = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = await _authorize_source(connection_id, file_key, current_user, db)
            if not source:
                return
            source_id = source.id

        # Create initial processing record
        file_base = _file_base(file_key)
        topic = data.get("topic")
        if file_base:
            title = f"Flashcard-{file_base}"
        elif topic and topic != "string":
            # "string" is skipped as a title (looks like a schema placeholder value — TODO confirm).
            title = topic
        else:
            title = f"Flashcards {_timestamp()}"

        db_set = db_models.FlashcardSet(
            title=title,
            difficulty=data.get("difficulty", "medium"),
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")

        cards_data = await flashcard_service.generate_flashcards(
            file_key=file_key,
            text_input=data.get("text_input"),
            difficulty=data.get("difficulty", "medium"),
            quantity=data.get("quantity", "standard"),
            topic=topic,
            language=data.get("language", "English"),
            progress_callback=_make_progress_callback(connection_id)
        )
        if not cards_data:
            raise Exception("AI returned empty flashcard data")

        await manager.send_progress(connection_id, 85, "processing", "Saving to database...")
        for item in cards_data:
            db.add(db_models.Flashcard(
                flashcard_set_id=db_set.id,
                question=item.get("question", ""),
                answer=item.get("answer", "")
            ))

        db_set.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_set.id,
            "title": db_set.title,
            "flashcards_count": len(db_set.flashcards),
            "status": "completed"
        })
    except Exception as e:
        await _fail_task(connection_id, db, db_set, e, "Flashcard")


async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for quiz generation.

    Creates a QuizSet row, asks the quiz service for questions, stores one
    QuizQuestion row per returned item and reports the set id and title.
    """
    db_set = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = await _authorize_source(connection_id, file_key, current_user, db)
            if not source:
                return
            source_id = source.id

        # Create initial processing record
        file_base = _file_base(file_key)
        topic = data.get("topic")
        if file_base:
            title = f"Quiz-{file_base}"
        elif topic and topic != "string":
            # "string" is skipped as a title (looks like a schema placeholder value — TODO confirm).
            title = topic
        else:
            title = f"Quiz {_timestamp()}"

        db_set = db_models.QuizSet(
            title=title,
            difficulty=data.get("difficulty", "medium"),
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")

        quizzes_data = await quiz_service.generate_quiz(
            file_key=file_key,
            text_input=data.get("text_input"),
            difficulty=data.get("difficulty", "medium"),
            topic=topic,
            language=data.get("language", "English"),
            count_mode=data.get("count", "STANDARD"),
            progress_callback=_make_progress_callback(connection_id)
        )
        if not quizzes_data:
            raise Exception("AI failed to generate quiz data")

        for item in quizzes_data:
            db.add(db_models.QuizQuestion(
                quiz_set_id=db_set.id,
                question=item.get("question", ""),
                choices=item.get("choices", {}),
                answer=str(item.get("answer", "1")),
                explanation=item.get("explanation", "")
            ))

        db_set.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title, "status": "completed"})
    except Exception as e:
        await _fail_task(connection_id, db, db_set, e, "Quiz")


async def handle_video_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for video summary generation.

    Unlike the other handlers a ``file_key`` is mandatory here: both video
    services are called with it. The slides-based service is used unless
    ``use_slides_transformation`` is falsy.
    """
    db_summary = None
    try:
        file_key = data.get("file_key")
        if not file_key:
            # Without this guard the lookup below would compare s3_key with
            # NULL and the service call would then fail on the missing key.
            await manager.send_error(connection_id, "Missing 'file_key' in request data")
            return

        source = await _authorize_source(connection_id, file_key, current_user, db)
        if not source:
            return

        file_base = _file_base(file_key)
        title = f"Video Summary {file_base}" if file_base else f"Video Summary {_timestamp()}"

        db_summary = db_models.VideoSummary(
            title=title,
            user_id=current_user.id,
            source_id=source.id,
            status="processing"
        )
        db.add(db_summary)
        db.commit()
        db.refresh(db_summary)

        await manager.send_progress(connection_id, 10, "processing", "Starting video generation...")

        language = data.get("language", "Japanese")
        voice_name = data.get("voice_name", "Kore")
        if data.get("use_slides_transformation", True):
            result = await slides_video_service.generate_transformed_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name,
                custom_prompt=data.get("custom_prompt", "")
            )
        else:
            result = await video_generator_service.generate_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name
            )

        db_summary.title = result["title"]
        db_summary.s3_key = result["s3_key"]
        db_summary.s3_url = result["s3_url"]
        db_summary.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "video",
            "id": db_summary.id,
            "status": "completed",
            "title": db_summary.title,
            "public_url": db_summary.s3_url
        })
    except Exception as e:
        await _fail_task(connection_id, db, db_summary, e, "Video")


async def handle_report_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for report generation.

    Creates a Report row, generates the content and stores it. When the
    title was not derived from a file name it is replaced by the first line
    of the generated content (with ``#`` characters stripped).
    """
    db_report = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = await _authorize_source(connection_id, file_key, current_user, db)
            if not source:
                return
            source_id = source.id

        # NOTE(review): the stored record defaults format_key to "custom"
        # while the generation call below defaults to "briefing_doc". Both
        # defaults are kept as they were — confirm which one is intended.
        record_format_key = data.get("format_key", "custom")

        file_base = _file_base(file_key)
        if file_base:
            title = f"Report-{file_base}"
        else:
            title = f"Report {record_format_key} {_timestamp()}"

        db_report = db_models.Report(
            title=title,
            format_key=record_format_key,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_report)
        db.commit()
        db.refresh(db_report)

        await manager.send_progress(connection_id, 15, "processing", "Generating report content...")

        content = await report_service.generate_report(
            file_key=file_key,
            text_input=data.get("text_input"),
            format_key=data.get("format_key", "briefing_doc"),
            custom_prompt=data.get("custom_prompt"),
            language=data.get("language", "Japanese")
        )
        if not content:
            raise Exception("AI failed to generate report content")

        if not db_report.title or "Report-" not in db_report.title:
            derived_title = content.split('\n')[0].replace('#', '').strip()
            if not derived_title or len(derived_title) < 3:
                # Use the record's format key so the title never reads "Report None".
                derived_title = f"Report {record_format_key}"
            db_report.title = derived_title

        db_report.content = content
        db_report.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "report",
            "id": db_report.id,
            "status": "completed",
            "title": db_report.title
        })
    except Exception as e:
        await _fail_task(connection_id, db, db_report, e, "Report")


async def handle_mindmap_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for mindmap generation.

    Creates a MindMap row, asks the mindmap service for Mermaid code, stores
    it on the row and reports the id and title.
    """
    db_mindmap = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = await _authorize_source(connection_id, file_key, current_user, db)
            if not source:
                return
            source_id = source.id

        file_base = _file_base(file_key)
        requested_title = data.get("title")
        if file_base:
            title = f"Mind Map-{file_base}"
        elif requested_title and requested_title != "string":
            # "string" is skipped as a title (looks like a schema placeholder value — TODO confirm).
            title = requested_title
        else:
            title = f"Mind Map {_timestamp()}"

        db_mindmap = db_models.MindMap(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_mindmap)
        db.commit()
        db.refresh(db_mindmap)

        await manager.send_progress(connection_id, 20, "processing", "Generating mind map visualization...")

        mermaid_code = await mindmap_service.generate_mindmap(
            file_key=file_key,
            text_input=data.get("text_input")
        )
        if not mermaid_code:
            raise Exception("AI failed to generate mind map code")

        db_mindmap.mermaid_code = mermaid_code
        db_mindmap.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "mindmap",
            "id": db_mindmap.id,
            "status": "completed",
            "title": db_mindmap.title
        })
    except Exception as e:
        await _fail_task(connection_id, db, db_mindmap, e, "Mindmap")


# Task "type" value -> handler coroutine. Defined after the handlers so every
# name exists; unified_generate_ws only reads it at call time.
_TASK_HANDLERS: Dict[str, Callable[..., Awaitable[None]]] = {
    "podcast": handle_podcast_task,
    "video": handle_video_task,
    "report": handle_report_task,
    "mindmap": handle_mindmap_task,
    "flashcards": handle_flashcards_task,
    "quiz": handle_quiz_task,
}