# creatorstudio-ai-backend-develop / api / websocket_routes.py
# matsuap's picture
# Upload folder using huggingface_hub
# e768b43 verified
import asyncio
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from sqlalchemy.orm import Session

from api.auth import get_current_user_ws
from models import db_models
from core.database import get_db
from services.flashcard_service import flashcard_service
from services.quiz_service import quiz_service
from services.report_service import report_service
from services.mindmap_service import mindmap_service
from services.podcast_service import podcast_service
from services.s3_service import s3_service
from services.video_generator_service import video_generator_service
from services.slides_video_service import slides_video_service
from models.schemas import VideoSummaryGenerateRequest, ReportGenerateRequest, MindMapGenerateRequest
# Router for this module; every route declared on it is served under the "/ws" prefix.
router = APIRouter(prefix="/ws", tags=["websockets"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manages WebSocket connections for parallel execution.

    Sockets are stored in ``active_connections`` keyed by a caller-chosen
    connection id. Every ``send_*`` method is best-effort: it silently does
    nothing for an unknown id, and logs (rather than raises) when the
    underlying ``send_json`` call fails.
    """

    def __init__(self):
        # connection id -> registered WebSocket
        self.active_connections: Dict[str, "WebSocket"] = {}

    async def connect(self, websocket: "WebSocket", connection_id: str):
        """Accept *websocket* and register it under *connection_id*."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logger.info("WebSocket connected: %s", connection_id)

    def disconnect(self, connection_id: str):
        """Unregister *connection_id*; a no-op when it is not registered."""
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            logger.info("WebSocket disconnected: %s", connection_id)

    async def _send_json(self, connection_id: str, payload: Dict[str, Any], kind: str):
        """Send *payload* to *connection_id*, logging instead of raising on failure.

        *kind* only labels the log line ("progress", "result" or "error").
        """
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(payload)
        except Exception as e:
            # Deliberately broad: a failed notification must not abort the
            # task that is reporting its state.
            logger.error("Error sending %s to %s: %s", kind, connection_id, e)

    async def send_progress(self, connection_id: str, progress: int, status: str, message: str = ""):
        """Send a ``"progress"`` message carrying *progress*, *status* and *message*."""
        await self._send_json(
            connection_id,
            {
                "type": "progress",
                "progress": progress,
                "status": status,
                "message": message,
            },
            "progress",
        )

    async def send_result(self, connection_id: str, data: Any):
        """Send the final ``"result"`` message (status "complete", progress 100) with *data*."""
        await self._send_json(
            connection_id,
            {
                "type": "result",
                "status": "complete",
                "progress": 100,
                "data": data,
            },
            "result",
        )

    async def send_error(self, connection_id: str, error: str):
        """Send an ``"error"`` message whose ``message`` field is *error*."""
        await self._send_json(
            connection_id,
            {
                "type": "error",
                "status": "error",
                "message": error,
            },
            "error",
        )
# Single module-level registry shared by the route and the task handlers in this module.
manager = ConnectionManager()
@router.websocket("/generate")
async def unified_generate_ws(
    websocket: WebSocket,
    token: str,
    db: Session = Depends(get_db),
):
    """
    Unified WebSocket gateway for all generation tasks.

    Client sends JSON: { "type": "podcast|flashcards|quiz|mindmap|report|video", "data": { ... } }

    Flow: accept the socket, authenticate ``token`` through
    ``get_current_user_ws``, read exactly one JSON request, and dispatch it to
    the matching ``handle_*_task`` coroutine. Progress, result and error
    messages go back through the module-level ``manager``.

    Args:
        websocket: The incoming connection.
        token: Credential handed to ``get_current_user_ws`` (presumably
            supplied as a query parameter - confirm against the client).
        db: Database session provided by ``get_db``.
    """
    await websocket.accept()
    # Stays None until authentication succeeds, so the error/cleanup paths
    # can tell whether the socket was ever registered.
    connection_id = None
    try:
        current_user = await get_current_user_ws(token, db)
        # One registry key per socket: with a bare "user_<id>" key, a second
        # socket from the same user overwrote the first, and the first
        # socket's cleanup then unregistered the second.
        connection_id = f"user_{current_user.id}_{uuid.uuid4().hex}"
        manager.active_connections[connection_id] = websocket

        # Receive the task specification
        message = await websocket.receive_json()
        if not isinstance(message, dict):
            await manager.send_error(connection_id, "Request must be a JSON object")
            return

        task_type = message.get("type")
        if not task_type:
            await manager.send_error(connection_id, "Missing 'type' in request")
            return

        data = message.get("data", {})
        if data is None:
            data = {}
        if not isinstance(data, dict):
            await manager.send_error(connection_id, "'data' must be a JSON object")
            return

        await manager.send_progress(connection_id, 2, "processing", f"Initializing {task_type} task...")

        # --- ROUTING LOGIC ---
        handlers = {
            "podcast": handle_podcast_task,
            "video": handle_video_task,
            "report": handle_report_task,
            "mindmap": handle_mindmap_task,
            "flashcards": handle_flashcards_task,
            "quiz": handle_quiz_task,
        }
        # Guard the lookup: a non-string "type" (e.g. a list) is unhashable.
        handler = handlers.get(task_type) if isinstance(task_type, str) else None
        if handler is None:
            await manager.send_error(connection_id, f"Unsupported task type: {task_type}")
            return
        await handler(connection_id, data, current_user, db)
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error("Unified WebSocket error: %s", e)
        if connection_id is not None:
            # send_error already swallows and logs transport failures.
            await manager.send_error(connection_id, str(e))
        else:
            # Failure before registration (e.g. authentication): report it on
            # the raw socket, best-effort.
            try:
                await websocket.send_json({
                    "type": "error",
                    "status": "error",
                    "message": str(e),
                })
            except Exception as send_failure:
                logger.debug("Could not report error to client: %s", send_failure)
    finally:
        if connection_id is not None:
            manager.disconnect(connection_id)
async def handle_podcast_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for podcast generation.

    Steps: verify ownership of ``data["file_key"]`` (when given), create a
    ``Podcast`` row in "processing" state, optionally analyze the source file,
    generate the script and the audio, upload the audio to S3, mark the row
    "completed" and send the result. On any failure the row (if created) is
    marked "failed" and an error message is sent to the client.

    Keys read from *data*: ``user_prompt`` (required), ``file_key``,
    ``duration_minutes``, ``model``, ``podcast_format``, ``tts_model``,
    ``spk1_voice``, ``spk2_voice``, ``temperature``, ``bgm_choice``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated podcast.
        db: Database session used for all reads and writes.
    """
    db_podcast = None
    audio_path = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id,
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Fail fast: without this check the missing key only surfaced as a
        # bare KeyError after the record was created and the file analyzed.
        if "user_prompt" not in data:
            await manager.send_error(connection_id, "Missing 'user_prompt' in request data")
            return

        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        if file_base:
            title = f"Podcast-{file_base}"
        else:
            title = f"Podcast {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_podcast = db_models.Podcast(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing",
        )
        db.add(db_podcast)
        db.commit()
        db.refresh(db_podcast)

        duration_minutes = data.get("duration_minutes", 10)

        analysis_report = ""
        if file_key:
            await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")
            analysis_report = await podcast_service.analyze_pdf(
                file_key=file_key,
                duration_minutes=duration_minutes,
            )

        await manager.send_progress(connection_id, 15, "processing", "Generating podcast script...")
        script = await podcast_service.generate_script(
            user_prompt=data["user_prompt"],
            model=data.get("model", "gpt-4o"),
            duration_minutes=duration_minutes,
            podcast_format=data.get("podcast_format", "conversational"),
            pdf_suggestions=analysis_report,
            file_key=file_key,
        )
        if not script:
            raise RuntimeError("Failed to generate script")

        await manager.send_progress(connection_id, 45, "processing", "Generating audio...")
        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=data.get("tts_model", "gemini-2.0-flash-exp"),
            spk1_voice=data.get("spk1_voice", "Puck"),
            spk2_voice=data.get("spk2_voice", "Charon"),
            temperature=data.get("temperature", 1.0),
            bgm_choice=data.get("bgm_choice", "No BGM"),
        )
        if not audio_path:
            raise RuntimeError("Failed to generate audio")

        await manager.send_progress(connection_id, 90, "processing", "Uploading to S3...")
        filename = os.path.basename(audio_path)
        s3_key = f"users/{current_user.id}/outputs/podcasts/{filename}"

        def upload_audio_sync():
            """Blocking upload; executed in a worker thread via asyncio.to_thread."""
            import boto3
            from core.config import settings

            with open(audio_path, "rb") as f:
                content = f.read()
            client = boto3.client(
                's3',
                aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
                aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
                region_name=settings.AWS_REGION,
            )
            client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=content)

        await asyncio.to_thread(upload_audio_sync)

        public_url = s3_service.get_public_url(s3_key)
        db_podcast.s3_key = s3_key
        db_podcast.s3_url = public_url
        db_podcast.script = script
        db_podcast.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_podcast.id,
            "status": "completed",
            "message": "Podcast generated successfully",
            "public_url": public_url,
        })
    except Exception as e:
        logger.exception("Podcast task failed: %s", e)
        if db_podcast is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; do that before recording the failure.
                db.rollback()
                db_podcast.status = "failed"
                db_podcast.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record podcast failure: %s", db_error)
        await manager.send_error(connection_id, str(e))
    finally:
        # Remove the local audio file on success AND failure so failed
        # uploads do not leave files behind.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError as cleanup_error:
                logger.warning("Could not remove local audio file %s: %s", audio_path, cleanup_error)
async def handle_flashcards_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for flashcard generation.

    Steps: verify ownership of ``data["file_key"]`` (when given), create a
    ``FlashcardSet`` row in "processing" state, generate the cards, store one
    ``Flashcard`` row per generated item, mark the set "completed" and send
    the result. On any failure the set (if created) is marked "failed" and an
    error message is sent to the client.

    Keys read from *data*: ``file_key``, ``text_input``, ``topic``,
    ``difficulty``, ``quantity``, ``language``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated set.
        db: Database session used for all reads and writes.
    """
    db_set = None
    # Strong references to in-flight progress notifications: the event loop
    # only keeps weak references to tasks, so an unreferenced task may be
    # garbage-collected before it runs.
    progress_tasks = set()

    def report_progress(percent, message):
        """Schedule a progress message, mapping the service's percentage onto 10-80."""
        task = asyncio.create_task(
            manager.send_progress(connection_id, 10 + int(percent * 0.7), "processing", message)
        )
        progress_tasks.add(task)
        task.add_done_callback(progress_tasks.discard)
        return task

    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id,
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Create initial processing record
        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        topic = data.get("topic")
        difficulty = data.get("difficulty", "medium")
        if file_base:
            title = f"Flashcard-{file_base}"
        elif topic and topic != "string":
            # "string" is skipped as a title (presumably an API-docs placeholder value).
            title = topic
        else:
            title = f"Flashcards {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_set = db_models.FlashcardSet(
            title=title,
            difficulty=difficulty,
            user_id=current_user.id,
            source_id=source_id,
            status="processing",
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")
        cards_data = await flashcard_service.generate_flashcards(
            file_key=file_key,
            text_input=data.get("text_input"),
            difficulty=difficulty,
            quantity=data.get("quantity", "standard"),
            topic=topic,
            language=data.get("language", "English"),
            progress_callback=report_progress,
        )
        if not cards_data:
            raise RuntimeError("AI returned empty flashcard data")

        # Let queued progress messages go out before the later-stage messages.
        if progress_tasks:
            await asyncio.gather(*list(progress_tasks), return_exceptions=True)

        await manager.send_progress(connection_id, 85, "processing", "Saving to database...")
        db.add_all([
            db_models.Flashcard(
                flashcard_set_id=db_set.id,
                question=item.get("question", ""),
                answer=item.get("answer", ""),
            )
            for item in cards_data
        ])
        db_set.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_set.id,
            "title": db_set.title,
            "flashcards_count": len(db_set.flashcards),
            "status": "completed",
        })
    except Exception as e:
        logger.exception("Flashcard task failed: %s", e)
        if db_set is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; this also discards half-added cards.
                db.rollback()
                db_set.status = "failed"
                db_set.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record flashcard failure: %s", db_error)
        await manager.send_error(connection_id, str(e))
async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for quiz generation.

    Steps: verify ownership of ``data["file_key"]`` (when given), create a
    ``QuizSet`` row in "processing" state, generate the questions, store one
    ``QuizQuestion`` row per generated item, mark the set "completed" and
    send the result. On any failure the set (if created) is marked "failed"
    and an error message is sent to the client.

    Keys read from *data*: ``file_key``, ``text_input``, ``topic``,
    ``difficulty``, ``language``, ``count``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated set.
        db: Database session used for all reads and writes.
    """
    db_set = None
    # Strong references to in-flight progress notifications: the event loop
    # only keeps weak references to tasks, so an unreferenced task may be
    # garbage-collected before it runs.
    progress_tasks = set()

    def report_progress(percent, message):
        """Schedule a progress message, mapping the service's percentage onto 10-80."""
        task = asyncio.create_task(
            manager.send_progress(connection_id, 10 + int(percent * 0.7), "processing", message)
        )
        progress_tasks.add(task)
        task.add_done_callback(progress_tasks.discard)
        return task

    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id,
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # Create initial processing record
        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        topic = data.get("topic")
        difficulty = data.get("difficulty", "medium")
        if file_base:
            title = f"Quiz-{file_base}"
        elif topic and topic != "string":
            # "string" is skipped as a title (presumably an API-docs placeholder value).
            title = topic
        else:
            title = f"Quiz {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_set = db_models.QuizSet(
            title=title,
            difficulty=difficulty,
            user_id=current_user.id,
            source_id=source_id,
            status="processing",
        )
        db.add(db_set)
        db.commit()
        db.refresh(db_set)

        await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")
        quizzes_data = await quiz_service.generate_quiz(
            file_key=file_key,
            text_input=data.get("text_input"),
            difficulty=difficulty,
            topic=topic,
            language=data.get("language", "English"),
            count_mode=data.get("count", "STANDARD"),
            progress_callback=report_progress,
        )
        if not quizzes_data:
            raise RuntimeError("AI failed to generate quiz data")

        # Let queued progress messages go out before the final result.
        if progress_tasks:
            await asyncio.gather(*list(progress_tasks), return_exceptions=True)

        db.add_all([
            db_models.QuizQuestion(
                quiz_set_id=db_set.id,
                question=item.get("question", ""),
                choices=item.get("choices", {}),
                answer=str(item.get("answer", "1")),
                explanation=item.get("explanation", ""),
            )
            for item in quizzes_data
        ])
        db_set.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title, "status": "completed"})
    except Exception as e:
        logger.exception("Quiz task failed: %s", e)
        if db_set is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; this also discards half-added questions.
                db.rollback()
                db_set.status = "failed"
                db_set.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record quiz failure: %s", db_error)
        await manager.send_error(connection_id, str(e))
async def handle_video_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for video summary generation.

    Steps: require and verify ownership of ``data["file_key"]``, create a
    ``VideoSummary`` row in "processing" state, generate the video through
    ``slides_video_service`` (default) or ``video_generator_service``
    (when ``use_slides_transformation`` is falsy), copy ``title``, ``s3_key``
    and ``s3_url`` from the service result onto the row, mark it "completed"
    and send the result. On any failure the row (if created) is marked
    "failed" and an error message is sent to the client.

    Keys read from *data*: ``file_key`` (required),
    ``use_slides_transformation``, ``language``, ``voice_name``,
    ``custom_prompt``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated summary.
        db: Database session used for all reads and writes.
    """
    db_summary = None
    try:
        file_key = data.get("file_key")
        if not file_key:
            # Without this guard the ownership query compared s3_key against
            # None and a later data["file_key"] lookup could raise KeyError.
            await manager.send_error(connection_id, "Missing 'file_key' in request data")
            return

        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == file_key,
            db_models.Source.user_id == current_user.id,
        ).first()
        if not source:
            await manager.send_error(connection_id, "Not authorized to access this file")
            return

        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0]
        if file_base:
            title = f"Video Summary {file_base}"
        else:
            title = f"Video Summary {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_summary = db_models.VideoSummary(
            title=title,
            user_id=current_user.id,
            source_id=source.id,
            status="processing",
        )
        db.add(db_summary)
        db.commit()
        db.refresh(db_summary)

        await manager.send_progress(connection_id, 10, "processing", "Starting video generation...")

        language = data.get("language", "Japanese")
        voice_name = data.get("voice_name", "Kore")
        if data.get("use_slides_transformation", True):
            result = await slides_video_service.generate_transformed_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name,
                custom_prompt=data.get("custom_prompt", ""),
            )
        else:
            result = await video_generator_service.generate_video_summary(
                file_key=file_key,
                language=language,
                voice_name=voice_name,
            )

        # The service-provided title replaces the provisional one.
        db_summary.title = result["title"]
        db_summary.s3_key = result["s3_key"]
        db_summary.s3_url = result["s3_url"]
        db_summary.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "video",
            "id": db_summary.id,
            "status": "completed",
            "title": db_summary.title,
            "public_url": db_summary.s3_url,
        })
    except Exception as e:
        logger.exception("Video task failed: %s", e)
        if db_summary is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; do that before recording the failure.
                db.rollback()
                db_summary.status = "failed"
                db_summary.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record video failure: %s", db_error)
        await manager.send_error(connection_id, str(e))
async def handle_report_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for report generation.

    Steps: verify ownership of ``data["file_key"]`` (when given), create a
    ``Report`` row in "processing" state, generate the report content, derive
    a title from the first content line when the title was not built from a
    file name, mark the row "completed" and send the result. On any failure
    the row (if created) is marked "failed" and an error message is sent to
    the client.

    Keys read from *data*: ``file_key``, ``text_input``, ``format_key``,
    ``custom_prompt``, ``language``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated report.
        db: Database session used for all reads and writes.
    """
    db_report = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id,
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # NOTE(review): the stored record and the titles default format_key to
        # "custom", while generation below defaults to "briefing_doc". Both
        # defaults are kept as they were - confirm which one is intended.
        format_key = data.get("format_key", "custom")

        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        if file_base:
            title = f"Report-{file_base}"
        else:
            title = f"Report {format_key} {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_report = db_models.Report(
            title=title,
            format_key=format_key,
            user_id=current_user.id,
            source_id=source_id,
            status="processing",
        )
        db.add(db_report)
        db.commit()
        db.refresh(db_report)

        await manager.send_progress(connection_id, 15, "processing", "Generating report content...")
        content = await report_service.generate_report(
            file_key=file_key,
            text_input=data.get("text_input"),
            format_key=data.get("format_key", "briefing_doc"),
            custom_prompt=data.get("custom_prompt"),
            language=data.get("language", "Japanese"),
        )
        if not content:
            raise RuntimeError("AI failed to generate report content")

        # Titles that were not built from a file name are replaced by the
        # first line of the generated content, stripped of '#' characters.
        if not db_report.title or "Report-" not in db_report.title:
            derived_title = content.split('\n')[0].replace('#', '').strip()
            if len(derived_title) < 3:
                # Use the resolved format_key so the title never reads "Report None".
                derived_title = f"Report {format_key}"
            db_report.title = derived_title

        db_report.content = content
        db_report.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "report",
            "id": db_report.id,
            "status": "completed",
            "title": db_report.title,
        })
    except Exception as e:
        logger.exception("Report task failed: %s", e)
        if db_report is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; do that before recording the failure.
                db.rollback()
                db_report.status = "failed"
                db_report.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record report failure: %s", db_error)
        await manager.send_error(connection_id, str(e))
async def handle_mindmap_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """Internal handler for mindmap generation.

    Steps: verify ownership of ``data["file_key"]`` (when given), create a
    ``MindMap`` row in "processing" state, generate the Mermaid code, store
    it, mark the row "completed" and send the result. On any failure the row
    (if created) is marked "failed" and an error message is sent to the
    client.

    Keys read from *data*: ``file_key``, ``text_input``, ``title``.

    Args:
        connection_id: Key of the client socket in ``manager``.
        data: Request payload sent by the client.
        current_user: Authenticated user who owns the generated mind map.
        db: Database session used for all reads and writes.
    """
    db_mindmap = None
    try:
        file_key = data.get("file_key")
        source_id = None
        if file_key:
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == file_key,
                db_models.Source.user_id == current_user.id,
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        # File name without directory or extension, e.g. "a/b/notes.pdf" -> "notes".
        file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
        requested_title = data.get("title")
        if file_base:
            title = f"Mind Map-{file_base}"
        elif requested_title and requested_title != "string":
            # "string" is skipped as a title (presumably an API-docs placeholder value).
            title = requested_title
        else:
            title = f"Mind Map {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

        db_mindmap = db_models.MindMap(
            title=title,
            user_id=current_user.id,
            source_id=source_id,
            status="processing",
        )
        db.add(db_mindmap)
        db.commit()
        db.refresh(db_mindmap)

        await manager.send_progress(connection_id, 20, "processing", "Generating mind map visualization...")
        mermaid_code = await mindmap_service.generate_mindmap(
            file_key=file_key,
            text_input=data.get("text_input"),
        )
        if not mermaid_code:
            raise RuntimeError("AI failed to generate mind map code")

        db_mindmap.mermaid_code = mermaid_code
        db_mindmap.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "type": "mindmap",
            "id": db_mindmap.id,
            "status": "completed",
            "title": db_mindmap.title,
        })
    except Exception as e:
        logger.exception("Mindmap task failed: %s", e)
        if db_mindmap is not None:
            try:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; do that before recording the failure.
                db.rollback()
                db_mindmap.status = "failed"
                db_mindmap.error_message = str(e)
                db.commit()
            except Exception as db_error:
                # Never let bookkeeping stop the client from being told.
                logger.error("Could not record mindmap failure: %s", db_error)
        await manager.send_error(connection_id, str(e))