| | import os |
| | import asyncio |
| | import logging |
| | from datetime import datetime |
| | from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks |
| | from sqlalchemy.orm import Session |
| | from typing import Dict, List |
| | from api.websocket_routes import manager |
| |
|
| | from api.auth import get_current_user |
| | from models.schemas import PodcastGenerateRequest, PodcastResponse |
| | from models import db_models |
| | from core.database import get_db, SessionLocal |
| | from services.podcast_service import podcast_service |
| | from services.s3_service import s3_service |
| | from core import constants |
| |
|
# Module-wide logger, plus the router that every endpoint below registers on
# (all routes are served under /api/podcast and grouped under the "podcast" tag).
logger = logging.getLogger(__name__)
router = APIRouter(
    prefix="/api/podcast",
    tags=["podcast"],
)
| |
|
@router.get("/config")
async def get_podcast_config():
    """Expose the option lists a client may choose from when generating a podcast.

    Returns:
        A dict with the keys ``voices``, ``bgm``, ``formats``, ``tts_models``
        and ``models``, each taken verbatim from ``core.constants``.
    """
    options = dict(
        voices=constants.PODCAST_VOICES,
        bgm=constants.PODCAST_BGM,
        formats=constants.PODCAST_FORMATS,
        tts_models=constants.PODCAST_TTS_MODALS,
        models=constants.PODCAST_MODALS,
    )
    return options
| |
|
def _upload_audio_to_s3(audio_path: str, s3_key: str) -> None:
    """Upload the local file at ``audio_path`` to the configured bucket under ``s3_key``.

    This is blocking (boto3 is synchronous), so async callers should run it via
    ``asyncio.to_thread``. The open file object is passed as the request body,
    so the audio is streamed rather than read fully into memory first.
    """
    # Imported lazily, as in the original code, so they are only needed when an upload happens.
    import boto3
    from core.config import settings

    s3_client = boto3.client(
        "s3",
        aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
        region_name=settings.AWS_REGION,
    )
    with open(audio_path, "rb") as audio_file:
        s3_client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=audio_file)


def _remove_local_file(path) -> None:
    """Best-effort removal of a temporary local file; a falsy or already-missing path is a no-op."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError:
        # Cleanup failure must not mask the generation outcome; just record it.
        logger.warning("Could not remove temporary audio file %s", path, exc_info=True)


async def run_podcast_generation(podcast_id: int, request: PodcastGenerateRequest, user_id: int):
    """Background task: generate a podcast end-to-end and persist its status.

    Steps:
    1. Optional analysis of the source file.
    2. Script generation.
    3. TTS audio synthesis.
    4. Upload to S3.
    5. Mark the ``Podcast`` row ``"completed"`` with its S3 key/URL and script.

    Progress is pushed over the websocket connection ``user_{user_id}``. On
    any exception the row is marked ``"failed"`` (with ``error_message``) and
    the user is sent an error; the exception is not re-raised. The local
    audio file, if one was produced, is removed whether or not the run succeeds.

    Args:
        podcast_id: Primary key of the ``Podcast`` row created by the endpoint.
        request: The original generation request (prompt, model, voices, ...).
        user_id: Owner of the podcast; used for the S3 key prefix and connection id.
    """
    connection_id = f"user_{user_id}"
    audio_path = None  # set once TTS produces a file; cleaned up in `finally`
    db = SessionLocal()
    try:
        podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
        if not podcast:
            logger.warning("Podcast %s not found; skipping background generation", podcast_id)
            return

        podcast.status = "processing"
        db.commit()
        await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")

        analysis_report = ""
        if request.file_key:
            analysis_report = await podcast_service.analyze_pdf(
                file_key=request.file_key,
                duration_minutes=request.duration_minutes,
            )
        await manager.send_progress(connection_id, 20, "processing", "Generating podcast script...")

        script = await podcast_service.generate_script(
            user_prompt=request.user_prompt,
            model=request.model,
            duration_minutes=request.duration_minutes,
            podcast_format=request.podcast_format,
            pdf_suggestions=analysis_report,
            file_key=request.file_key,
        )
        if not script:
            raise RuntimeError("Failed to generate script")

        await manager.send_progress(connection_id, 40, "processing", "Generating audio (this may take several minutes)...")

        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=request.tts_model,
            spk1_voice=request.spk1_voice,
            spk2_voice=request.spk2_voice,
            temperature=request.temperature,
            bgm_choice=request.bgm_choice,
        )
        if not audio_path:
            raise RuntimeError("Failed to generate audio")

        await manager.send_progress(connection_id, 85, "processing", "Uploading to S3...")

        s3_key = f"users/{user_id}/outputs/podcasts/{os.path.basename(audio_path)}"
        # boto3 is blocking; keep it off the event loop.
        await asyncio.to_thread(_upload_audio_to_s3, audio_path, s3_key)
        public_url = s3_service.get_public_url(s3_key)

        podcast.s3_key = s3_key
        podcast.s3_url = public_url
        podcast.script = script
        podcast.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": podcast.id,
            "status": "completed",
            "title": podcast.title,
            "public_url": public_url,
        })

    except Exception as e:
        logger.exception("Background podcast generation failed for ID %s", podcast_id)
        try:
            # A failed commit/flush leaves the session unusable until it is
            # rolled back; without this the status update below would raise too.
            db.rollback()
            failed = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
            if failed:
                failed.status = "failed"
                failed.error_message = str(e)
                db.commit()
        except Exception:
            logger.exception("Could not mark podcast %s as failed", podcast_id)

        await manager.send_error(connection_id, f"Generation failed: {str(e)}")
    finally:
        # Runs on success and failure alike, so temp audio never leaks on disk.
        _remove_local_file(audio_path)
        db.close()
| |
|
@router.post("/generate", response_model=PodcastResponse)
async def generate_podcast(
    request: PodcastGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Start podcast generation in the background and return the new record at once.

    If ``request.file_key`` is given, it must match a ``Source`` owned by the
    current user; otherwise the request is rejected with 403. A ``Podcast`` row
    is created immediately with status ``"processing"`` and returned.
    ``run_podcast_generation`` is scheduled as a background task to produce the
    audio and set the final status.
    """
    from datetime import timezone

    source_id = None
    title = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id,
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id

        # Title from the key's basename minus its last extension,
        # e.g. "users/1/notes.v2.pdf" -> "Podcast-notes.v2".
        file_base = request.file_key.split("/")[-1].rsplit(".", 1)[0]
        if file_base:
            title = f"Podcast-{file_base}"

    if title is None:
        # datetime.utcnow() is deprecated (Python 3.12); this renders the same UTC time.
        title = f"Podcast {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

    db_podcast = db_models.Podcast(
        title=title,
        user_id=current_user.id,
        source_id=source_id,
        status="processing",
    )
    db.add(db_podcast)
    db.commit()
    db.refresh(db_podcast)

    background_tasks.add_task(run_podcast_generation, db_podcast.id, request, current_user.id)

    return db_podcast
| |
|
@router.get("/list", response_model=List[PodcastResponse])
async def list_podcasts(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's podcasts, newest first, including generation status.

    Raises:
        HTTPException: 500 if the query or serialization fails. Details are
            logged server-side rather than echoed to the client.
    """
    try:
        podcasts = (
            db.query(db_models.Podcast)
            .filter(db_models.Podcast.user_id == current_user.id)
            .order_by(db_models.Podcast.created_at.desc())
            .all()
        )
        return [PodcastResponse.model_validate(p) for p in podcasts]
    except Exception as e:
        # Previously this was neither logged nor safe: raw DB/driver error text
        # was returned to the client.
        logger.exception("Failed to list podcasts for user %s", current_user.id)
        raise HTTPException(status_code=500, detail="Failed to list podcasts") from e
| |
|
@router.delete("/{podcast_id}")
async def delete_podcast(
    podcast_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's podcasts from the database and S3.

    Raises:
        HTTPException: 404 if no podcast with this id belongs to the user; 500
            if the deletion fails. In that case the DB change is rolled back and
            the cause is logged server-side.
    """
    podcast = db.query(db_models.Podcast).filter(
        db_models.Podcast.id == podcast_id,
        db_models.Podcast.user_id == current_user.id,
    ).first()

    if not podcast:
        raise HTTPException(status_code=404, detail="Podcast not found")

    s3_key = podcast.s3_key
    try:
        db.delete(podcast)
        # Emit the DELETE before touching S3 so DB-side failures (e.g. FK
        # constraints) abort the request while the audio file still exists.
        # Nothing is committed until the S3 delete has succeeded.
        db.flush()

        if s3_key:
            await s3_service.delete_file(s3_key)

        db.commit()
        return {"message": "Podcast and associated audio file deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.exception("Failed to delete podcast %s", podcast_id)
        raise HTTPException(status_code=500, detail="Deletion failed") from e