| | import logging |
| | import asyncio |
| | from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends |
| | from sqlalchemy.orm import Session |
| | from datetime import datetime |
| | from typing import Dict, Any |
| |
|
| | from api.auth import get_current_user_ws |
| | from models import db_models |
| | from core.database import get_db |
| | from services.flashcard_service import flashcard_service |
| | from services.quiz_service import quiz_service |
| | from services.report_service import report_service |
| | from services.mindmap_service import mindmap_service |
| | from services.podcast_service import podcast_service |
| | from services.s3_service import s3_service |
| | from services.video_generator_service import video_generator_service |
| | from services.slides_video_service import slides_video_service |
| | from models.schemas import VideoSummaryGenerateRequest, ReportGenerateRequest, MindMapGenerateRequest |
| |
|
router = APIRouter(prefix="/ws", tags=["websockets"])
logger = logging.getLogger(__name__)


class ConnectionManager:
    """Registry of live WebSocket connections, keyed by connection id.

    The ``send_*`` methods are best-effort: they do nothing when the id is not
    registered, and they log (instead of raising) any exception from
    ``send_json``, so a client that drops mid-task does not abort the code that
    is reporting to it.

    Messages sent to the client:

    * progress: ``{"type": "progress", "progress": ..., "status": ..., "message": ...}``
    * result:   ``{"type": "result", "status": "complete", "progress": 100, "data": ...}``
    * error:    ``{"type": "error", "status": "error", "message": ...}``
    """

    def __init__(self):
        # connection_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept *websocket* and register it under *connection_id*."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logger.info(f"WebSocket connected: {connection_id}")

    def disconnect(self, connection_id: str):
        """Unregister *connection_id*; a no-op when it is not registered."""
        if self.active_connections.pop(connection_id, None) is not None:
            logger.info(f"WebSocket disconnected: {connection_id}")

    async def _send(self, connection_id: str, kind: str, payload: Dict[str, Any]):
        """Best-effort ``send_json`` shared by the public ``send_*`` methods.

        Args:
            connection_id: Target connection; silently ignored if unregistered.
            kind: Name of the message kind, used only in the error log line.
            payload: JSON-serialisable message body.
        """
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(payload)
        except Exception as e:
            logger.error(f"Error sending {kind} to {connection_id}: {e}")

    async def send_progress(self, connection_id: str, progress: int, status: str, message: str = ""):
        """Send an intermediate progress update (callers in this module use a 0-100 scale)."""
        await self._send(connection_id, "progress", {
            "type": "progress",
            "progress": progress,
            "status": status,
            "message": message,
        })

    async def send_result(self, connection_id: str, data: Any):
        """Send the final payload *data*, always reported as ``"complete"`` at progress 100."""
        await self._send(connection_id, "result", {
            "type": "result",
            "status": "complete",
            "progress": 100,
            "data": data,
        })

    async def send_error(self, connection_id: str, error: str):
        """Send an error message *error* to the client."""
        await self._send(connection_id, "error", {
            "type": "error",
            "status": "error",
            "message": error,
        })


# Module-level registry shared by the WebSocket endpoint and the task handlers.
manager = ConnectionManager()
| |
|
| |
|
@router.websocket("/generate")
async def unified_generate_ws(
    websocket: WebSocket,
    token: str,
    db: Session = Depends(get_db)):
    """Unified WebSocket gateway for all generation tasks.

    Protocol:
      1. The client connects with an auth ``token`` (query parameter).
      2. The client sends one JSON object:
         ``{"type": "podcast|flashcards|quiz|mindmap|report|video", "data": {...}}``.
      3. The server streams ``progress`` messages and ends with a single
         ``result`` or ``error`` message via ``manager``.

    If authentication fails, the client receives an ``error`` message and the
    socket is closed with code 1008 (policy violation).

    Args:
        websocket: The client connection; accepted immediately.
        token: Auth token, validated by ``get_current_user_ws``.
        db: Session provided by the ``get_db`` dependency.
    """
    from uuid import uuid4

    # Dispatch table. It is kept local because the handler functions are
    # defined later in this module.
    handlers = {
        "podcast": handle_podcast_task,
        "video": handle_video_task,
        "report": handle_report_task,
        "mindmap": handle_mindmap_task,
        "flashcards": handle_flashcards_task,
        "quiz": handle_quiz_task,
    }

    await websocket.accept()
    connection_id = None

    try:
        try:
            current_user = await get_current_user_ws(token, db)
        except Exception as auth_error:
            logger.warning(f"WebSocket authentication failed: {auth_error}")
            current_user = None
        if current_user is None:
            # There is no registered connection yet, so reply on the raw socket.
            await websocket.send_json({
                "type": "error",
                "status": "error",
                "message": "Authentication failed",
            })
            await websocket.close(code=1008)
            return

        # Unique per socket. Keying on the user id alone lets two concurrent
        # connections from the same user overwrite each other in the manager,
        # and whichever finishes first would unregister the other.
        connection_id = f"user_{current_user.id}_{uuid4().hex}"
        manager.active_connections[connection_id] = websocket

        message = await websocket.receive_json()
        if not isinstance(message, dict):
            await manager.send_error(connection_id, "Request must be a JSON object")
            return

        task_type = message.get("type")
        data = message.get("data", {})

        if not task_type:
            await manager.send_error(connection_id, "Missing 'type' in request")
            return
        if not isinstance(data, dict):
            await manager.send_error(connection_id, "'data' must be a JSON object")
            return

        # Guard the lookup: a non-string "type" (e.g. a JSON list) is unhashable.
        handler = handlers.get(task_type) if isinstance(task_type, str) else None
        if handler is None:
            await manager.send_error(connection_id, f"Unsupported task type: {task_type}")
            return

        await manager.send_progress(connection_id, 2, "processing", f"Initializing {task_type} task...")
        await handler(connection_id, data, current_user, db)

    except WebSocketDisconnect:
        logger.info(f"Client disconnected: {connection_id or 'unauthenticated'}")
    except Exception as e:
        logger.exception(f"Unified WebSocket error: {e}")
        if connection_id is not None:
            # manager.send_error catches and logs its own send failures.
            await manager.send_error(connection_id, str(e))
    finally:
        if connection_id is not None:
            manager.disconnect(connection_id)
| |
|
async def handle_podcast_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a podcast for *current_user*, streaming progress over *connection_id*.

    Steps:

    1. Verify that the user owns the optional source file ``data["file_key"]``.
    2. Create a ``Podcast`` row with status ``"processing"``.
    3. Optionally analyze the source file.
    4. Generate the script and the audio.
    5. Upload the audio to S3 under ``users/{user_id}/outputs/podcasts/``.
    6. Mark the row ``"completed"``.

    ``data["user_prompt"]`` is required. Other keys are optional and default as
    follows:

    * ``duration_minutes``: 10
    * ``model``: "gpt-4o"
    * ``podcast_format``: "conversational"
    * ``tts_model``: "gemini-2.0-flash-exp"
    * ``spk1_voice``: "Puck"
    * ``spk2_voice``: "Charon"
    * ``temperature``: 1.0
    * ``bgm_choice``: "No BGM"

    Failures are recorded on the row (``status="failed"``, ``error_message``)
    and reported to the client. Nothing is raised.
    """
    import os

    db_podcast = None
    audio_path = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        if "user_prompt" not in data:
            # Fail fast, before creating a row or running the slow source analysis.
            await manager.send_error(connection_id, "Missing 'user_prompt' in request")
            return

        duration_minutes = data.get("duration_minutes", 10)
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        title = f"Podcast-{file_base}" if file_base else f"Podcast {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        db_podcast = db_models.Podcast(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_podcast)
        db.commit()
        db.refresh(db_podcast)

        analysis_report = ""
        if file_key:
            await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")
            analysis_report = await podcast_service.analyze_pdf(
                file_key=file_key,
                duration_minutes=duration_minutes
            )

        await manager.send_progress(connection_id, 15, "processing", "Generating podcast script...")
        script = await podcast_service.generate_script(
            user_prompt=data["user_prompt"],
            model=data.get("model", "gpt-4o"),
            duration_minutes=duration_minutes,
            podcast_format=data.get("podcast_format", "conversational"),
            pdf_suggestions=analysis_report,
            file_key=file_key
        )
        if not script:
            raise Exception("Failed to generate script")

        await manager.send_progress(connection_id, 45, "processing", "Generating audio...")
        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=data.get("tts_model", "gemini-2.0-flash-exp"),
            spk1_voice=data.get("spk1_voice", "Puck"),
            spk2_voice=data.get("spk2_voice", "Charon"),
            temperature=data.get("temperature", 1.0),
            bgm_choice=data.get("bgm_choice", "No BGM")
        )
        if not audio_path:
            raise Exception("Failed to generate audio")

        await manager.send_progress(connection_id, 90, "processing", "Uploading to S3...")
        s3_key = f"users/{current_user.id}/outputs/podcasts/{os.path.basename(audio_path)}"

        def upload_audio_sync():
            """Blocking S3 upload; run in a worker thread via ``asyncio.to_thread``."""
            import boto3
            from core.config import settings
            client = boto3.client(
                's3',
                aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
                aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
                region_name=settings.AWS_REGION,
            )
            # Stream from the open file instead of reading it fully into memory.
            with open(audio_path, "rb") as f:
                client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=f)

        await asyncio.to_thread(upload_audio_sync)

        public_url = s3_service.get_public_url(s3_key)

        db_podcast.s3_key = s3_key
        db_podcast.s3_url = public_url
        db_podcast.script = script
        db_podcast.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_podcast.id,
            "status": "completed",
            "message": "Podcast generated successfully",
            "public_url": public_url
        })

    except Exception as e:
        logger.exception(f"Podcast task failed: {e}")
        if db_podcast is not None:
            try:
                # The session may be in a failed transaction (e.g. the error came
                # from a commit). Roll back first so recording the failure can commit.
                db.rollback()
                db_podcast.status = "failed"
                db_podcast.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark podcast as failed")
        await manager.send_error(connection_id, str(e))
    finally:
        # Remove the local audio file whether or not the upload/DB update succeeded.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary audio file {audio_path}: {cleanup_error}")
| |
|
async def handle_flashcards_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a flashcard set for *current_user*, streaming progress over *connection_id*.

    Steps:

    1. Verify that the user owns the optional source file ``data["file_key"]``.
    2. Create a ``FlashcardSet`` row with status ``"processing"``.
    3. Ask ``flashcard_service`` for cards, based on the file and/or
       ``data["text_input"]``.
    4. Store each card as a ``Flashcard`` and mark the set ``"completed"``.

    Failures are recorded on the set (``status="failed"``, ``error_message``)
    and reported to the client. Nothing is raised.
    """
    db_set = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Title preference: source file name, then the topic, then a timestamp.
        # NOTE(review): "string" is presumably the OpenAPI/Swagger placeholder value.
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        topic = data.get("topic")
        if file_base:
            title = f"Flashcard-{file_base}"
        elif topic and topic != "string":
            title = topic
        else:
            title = f"Flashcards {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        difficulty = data.get("difficulty", "medium")
        db_set = db_models.FlashcardSet(
            title=title,
            difficulty=difficulty,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")

        pending_progress = set()

        def report_progress(p, m):
            """Schedule a progress message, mapping the service's *p* into the 10-80 band."""
            task = asyncio.create_task(
                manager.send_progress(connection_id, 10 + int(p * 0.7), "processing", m)
            )
            # The event loop keeps only weak references to tasks. Hold a strong
            # one so the send cannot be garbage-collected before it runs.
            pending_progress.add(task)
            task.add_done_callback(pending_progress.discard)
            return task

        try:
            cards_data = await flashcard_service.generate_flashcards(
                file_key=file_key,
                text_input=data.get("text_input"),
                difficulty=difficulty,
                quantity=data.get("quantity", "standard"),
                topic=topic,
                language=data.get("language", "English"),
                progress_callback=report_progress
            )
        finally:
            # Flush queued progress messages so none reach the client after the
            # final result/error message.
            if pending_progress:
                await asyncio.gather(*pending_progress, return_exceptions=True)

        if not cards_data:
            raise Exception("AI returned empty flashcard data")

        await manager.send_progress(connection_id, 85, "processing", "Saving to database...")

        for item in cards_data:
            db.add(db_models.Flashcard(
                flashcard_set_id=db_set.id,
                question=item.get("question", ""),
                answer=item.get("answer", "")
            ))

        db_set.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_set.id,
            "title": db_set.title,
            "flashcards_count": len(db_set.flashcards),
            "status": "completed"
        })
    except Exception as e:
        logger.exception(f"Flashcard task failed: {e}")
        if db_set is not None:
            try:
                # Roll back first. This clears a failed transaction and discards
                # any partially added cards before the failure is recorded.
                db.rollback()
                db_set.status = "failed"
                db_set.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark flashcard set as failed")
        await manager.send_error(connection_id, str(e))
| |
|
async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a quiz set for *current_user*, streaming progress over *connection_id*.

    Steps:

    1. Verify that the user owns the optional source file ``data["file_key"]``.
    2. Create a ``QuizSet`` row with status ``"processing"``.
    3. Ask ``quiz_service`` for questions, based on the file and/or
       ``data["text_input"]``.
    4. Store each question as a ``QuizQuestion`` and mark the set ``"completed"``.

    Failures are recorded on the set (``status="failed"``, ``error_message``)
    and reported to the client. Nothing is raised.
    """
    db_set = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Title preference: source file name, then the topic, then a timestamp.
        # NOTE(review): "string" is presumably the OpenAPI/Swagger placeholder value.
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        topic = data.get("topic")
        if file_base:
            title = f"Quiz-{file_base}"
        elif topic and topic != "string":
            title = topic
        else:
            title = f"Quiz {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        difficulty = data.get("difficulty", "medium")
        db_set = db_models.QuizSet(
            title=title,
            difficulty=difficulty,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")

        pending_progress = set()

        def report_progress(p, m):
            """Schedule a progress message, mapping the service's *p* into the 10-80 band."""
            task = asyncio.create_task(
                manager.send_progress(connection_id, 10 + int(p * 0.7), "processing", m)
            )
            # The event loop keeps only weak references to tasks. Hold a strong
            # one so the send cannot be garbage-collected before it runs.
            pending_progress.add(task)
            task.add_done_callback(pending_progress.discard)
            return task

        try:
            quizzes_data = await quiz_service.generate_quiz(
                file_key=file_key,
                text_input=data.get("text_input"),
                difficulty=difficulty,
                topic=topic,
                language=data.get("language", "English"),
                count_mode=data.get("count", "STANDARD"),
                progress_callback=report_progress
            )
        finally:
            # Flush queued progress messages so none reach the client after the
            # final result/error message.
            if pending_progress:
                await asyncio.gather(*pending_progress, return_exceptions=True)

        if not quizzes_data:
            raise Exception("AI failed to generate quiz data")

        for item in quizzes_data:
            db.add(db_models.QuizQuestion(
                quiz_set_id=db_set.id,
                question=item.get("question", ""),
                choices=item.get("choices", {}),
                answer=str(item.get("answer", "1")),
                explanation=item.get("explanation", "")
            ))

        db_set.status = "completed"
        db.commit()
        await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title, "status": "completed"})
    except Exception as e:
        logger.exception(f"Quiz task failed: {e}")
        if db_set is not None:
            try:
                # Roll back first. This clears a failed transaction and discards
                # any partially added questions before the failure is recorded.
                db.rollback()
                db_set.status = "failed"
                db_set.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark quiz set as failed")
        await manager.send_error(connection_id, str(e))
| |
|
async def handle_video_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a video summary of a source file owned by *current_user*.

    ``data["file_key"]`` is required and must belong to the user. By default
    (``use_slides_transformation`` true) the slides pipeline is used; otherwise
    ``video_generator_service`` is used. Both are given ``language`` (default
    "Japanese") and ``voice_name`` (default "Kore"). The slides pipeline also
    gets ``custom_prompt``.

    The ``VideoSummary`` row takes its title, ``s3_key`` and ``s3_url`` from the
    service result. Failures are recorded on the row (``status="failed"``,
    ``error_message``) and reported to the client. Nothing is raised.
    """
    db_summary = None
    try:
        file_key = data.get("file_key")
        if not file_key:
            # Reject early. Otherwise the ownership filter below compares s3_key
            # with None (SQL IS NULL) and could match an unrelated source row.
            await manager.send_error(connection_id, "Missing 'file_key' in request")
            return

        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not source:
            await manager.send_error(connection_id, "Not authorized to access this file")
            return

        file_base = file_key.split('/')[-1].rsplit('.', 1)[0]
        title = f"Video Summary {file_base}" if file_base else f"Video Summary {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        db_summary = db_models.VideoSummary(
            title=title,
            user_id=current_user.id,
            source_id=source.id,
            status="processing"
        )
        db.add(db_summary)
        db.commit()
        db.refresh(db_summary)

        await manager.send_progress(connection_id, 10, "processing", "Starting video generation...")

        language = data.get("language", "Japanese")
        voice_name = data.get("voice_name", "Kore")
        if data.get("use_slides_transformation", True):
            result = await slides_video_service.generate_transformed_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name,
                custom_prompt=data.get("custom_prompt", "")
            )
        else:
            result = await video_generator_service.generate_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name
            )

        if not result:
            raise Exception("Video generation returned no result")

        db_summary.title = result["title"]
        db_summary.s3_key = result["s3_key"]
        db_summary.s3_url = result["s3_url"]
        db_summary.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "video",
            "id": db_summary.id,
            "status": "completed",
            "title": db_summary.title,
            "public_url": db_summary.s3_url
        })
    except Exception as e:
        logger.exception(f"Video task failed: {e}")
        if db_summary is not None:
            try:
                # The session may be in a failed transaction. Roll back first so
                # recording the failure can commit.
                db.rollback()
                db_summary.status = "failed"
                db_summary.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark video summary as failed")
        await manager.send_error(connection_id, str(e))
| |
|
async def handle_report_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a text report for *current_user*, streaming progress over *connection_id*.

    Steps:

    1. Verify that the user owns the optional source file ``data["file_key"]``.
    2. Create a ``Report`` row with status ``"processing"``.
    3. Generate the content with ``report_service``.
    4. Store the content and mark the row ``"completed"``.

    When there is no source file, the title is derived from the first non-blank
    line of the generated content. Failures are recorded on the row
    (``status="failed"``, ``error_message``) and reported to the client.
    Nothing is raised.
    """
    db_report = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Resolved once so the stored format_key, the title and the generator agree.
        # NOTE(review): confirm "briefing_doc" is the intended default format.
        format_key = data.get("format_key", "briefing_doc")

        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        title = f"Report-{file_base}" if file_base else f"Report {format_key} {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        db_report = db_models.Report(
            title=title,
            format_key=format_key,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_report)
        db.commit()
        db.refresh(db_report)

        await manager.send_progress(connection_id, 15, "processing", "Generating report content...")

        content = await report_service.generate_report(
            file_key=file_key,
            text_input=data.get("text_input"),
            format_key=format_key,
            custom_prompt=data.get("custom_prompt"),
            language=data.get("language", "Japanese")
        )
        if not content:
            raise Exception("AI failed to generate report content")

        if not file_base:
            # Take the first non-blank line. Strip only *leading* Markdown heading
            # markers so a '#' inside the text (e.g. "C#") survives.
            first_line = next((line for line in content.splitlines() if line.strip()), "")
            derived_title = first_line.strip().lstrip('#').strip()
            db_report.title = derived_title if len(derived_title) >= 3 else f"Report {format_key}"
        db_report.content = content
        db_report.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "report",
            "id": db_report.id,
            "status": "completed",
            "title": db_report.title
        })
    except Exception as e:
        logger.exception(f"Report task failed: {e}")
        if db_report is not None:
            try:
                # The session may be in a failed transaction. Roll back first so
                # recording the failure can commit.
                db.rollback()
                db_report.status = "failed"
                db_report.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark report as failed")
        await manager.send_error(connection_id, str(e))
| |
|
async def handle_mindmap_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Generate a Mermaid mind map for *current_user*, streaming progress over *connection_id*.

    Steps:

    1. Verify that the user owns the optional source file ``data["file_key"]``.
    2. Create a ``MindMap`` row with status ``"processing"``.
    3. Generate Mermaid code with ``mindmap_service``, based on the file and/or
       ``data["text_input"]``.
    4. Store the code and mark the row ``"completed"``.

    Failures are recorded on the row (``status="failed"``, ``error_message``)
    and reported to the client. Nothing is raised.
    """
    db_mindmap = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Title preference: source file name, then the client's title, then a timestamp.
        # NOTE(review): "string" is presumably the OpenAPI/Swagger placeholder value.
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        requested_title = data.get("title")
        if file_base:
            title = f"Mind Map-{file_base}"
        elif requested_title and requested_title != "string":
            title = requested_title
        else:
            title = f"Mind Map {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        db_mindmap = db_models.MindMap(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_mindmap)
        db.commit()
        db.refresh(db_mindmap)

        await manager.send_progress(connection_id, 20, "processing", "Generating mind map visualization...")

        mermaid_code = await mindmap_service.generate_mindmap(
            file_key=file_key,
            text_input=data.get("text_input")
        )
        if not mermaid_code:
            raise Exception("AI failed to generate mind map code")

        db_mindmap.mermaid_code = mermaid_code
        db_mindmap.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "mindmap",
            "id": db_mindmap.id,
            "status": "completed",
            "title": db_mindmap.title
        })
    except Exception as e:
        logger.exception(f"Mindmap task failed: {e}")
        if db_mindmap is not None:
            try:
                # The session may be in a failed transaction. Roll back first so
                # recording the failure can commit.
                db.rollback()
                db_mindmap.status = "failed"
                db_mindmap.error_message = str(e)
                db.commit()
            except Exception:
                logger.exception("Could not mark mind map as failed")
        await manager.send_error(connection_id, str(e))
| |
|
| |
|