File size: 8,013 Bytes
792ad00
951d5c6
792ad00
 
951d5c6
792ad00
 
951d5c6
792ad00
 
951d5c6
792ad00
951d5c6
792ad00
 
 
 
 
 
 
 
 
 
 
 
 
acccb85
86eea89
 
792ad00
 
951d5c6
 
 
792ad00
951d5c6
 
 
792ad00
951d5c6
 
 
 
 
 
792ad00
951d5c6
 
792ad00
951d5c6
 
 
 
 
792ad00
 
 
 
 
 
 
951d5c6
792ad00
 
 
 
951d5c6
 
 
792ad00
 
 
 
 
 
 
 
 
 
 
 
951d5c6
 
 
792ad00
 
 
951d5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792ad00
951d5c6
 
 
 
 
 
 
792ad00
 
951d5c6
 
 
 
 
 
 
792ad00
951d5c6
 
 
792ad00
 
951d5c6
 
 
 
 
 
 
 
 
 
 
792ad00
951d5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f281e
 
951d5c6
77f281e
951d5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
792ad00
 
 
 
 
951d5c6
792ad00
 
 
 
 
 
951d5c6
792ad00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951d5c6
 
 
792ad00
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import os
import asyncio
import logging
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Dict, List
from api.websocket_routes import manager

from api.auth import get_current_user
from models.schemas import PodcastGenerateRequest, PodcastResponse
from models import db_models
from core.database import get_db, SessionLocal
from services.podcast_service import podcast_service
from services.s3_service import s3_service
from core import constants

router = APIRouter(prefix="/api/podcast", tags=["podcast"])
logger = logging.getLogger(__name__)

@router.get("/config")
async def get_podcast_config():
    """Expose the podcast generation options defined in ``core.constants``.

    The response has the keys ``voices``, ``bgm``, ``formats``,
    ``tts_models`` and ``models``, each holding the matching constant
    unchanged.
    """
    options = dict(
        voices=constants.PODCAST_VOICES,
        bgm=constants.PODCAST_BGM,
        formats=constants.PODCAST_FORMATS,
    )
    # NOTE(review): the constants are spelled "MODALS"; presumably meant
    # "MODELS". Left as-is because renaming requires changing core.constants.
    options["tts_models"] = constants.PODCAST_TTS_MODALS
    options["models"] = constants.PODCAST_MODALS
    return options

def _upload_file_to_s3(local_path: str, s3_key: str) -> None:
    """Upload the local file at ``local_path`` to the configured bucket under ``s3_key``.

    Blocking (boto3 client calls are synchronous), so callers in async code
    should run it via ``asyncio.to_thread``. The open file handle is passed
    as the request body, so the audio is not read into memory first.
    """
    import boto3
    from core.config import settings

    s3_client = boto3.client(
        's3',
        aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
        region_name=settings.AWS_REGION,
    )
    with open(local_path, "rb") as audio_file:
        s3_client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=audio_file)

async def run_podcast_generation(podcast_id: int, request: PodcastGenerateRequest, user_id: int) -> None:
    """Background task: generate a podcast, upload it, and record the outcome.

    Progress is pushed to the WebSocket connection ``user_{user_id}``. Steps:
      1. if ``request.file_key`` is set, analyze it with ``podcast_service.analyze_pdf``;
      2. generate the script;
      3. synthesize the audio into a local file;
      4. upload it to S3 under ``users/{user_id}/outputs/podcasts/``;
      5. store key, URL and script and mark the record ``"completed"``.

    On any error the record is marked ``"failed"`` with the error text and the
    client is sent an error message. The local audio file, if one was
    produced, is removed whether or not generation succeeded.

    Args:
        podcast_id: Primary key of the ``Podcast`` row to update. If it no
            longer exists, the task returns without doing anything.
        request: The original generation request (models, voices, file key...).
        user_id: Owner of the podcast; used for the S3 key and WebSocket channel.
    """
    # A dedicated session is opened here instead of reusing the request's get_db one.
    db = SessionLocal()
    connection_id = f"user_{user_id}"
    audio_path = None
    try:
        podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
        if not podcast:
            return

        podcast.status = "processing"
        db.commit()

        # Notify via WebSocket if connected
        await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")

        # 1. Analyze first if file is provided
        analysis_report = ""
        if request.file_key:
            analysis_report = await podcast_service.analyze_pdf(
                file_key=request.file_key,
                duration_minutes=request.duration_minutes
            )
            await manager.send_progress(connection_id, 20, "processing", "Generating podcast script...")

        # 2. Generate Script
        script = await podcast_service.generate_script(
            user_prompt=request.user_prompt,
            model=request.model,
            duration_minutes=request.duration_minutes,
            podcast_format=request.podcast_format,
            pdf_suggestions=analysis_report,
            file_key=request.file_key
        )
        if not script:
            raise RuntimeError("Failed to generate script")

        await manager.send_progress(connection_id, 40, "processing", "Generating audio (this may take several minutes)...")

        # 3. Generate Audio
        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=request.tts_model,
            spk1_voice=request.spk1_voice,
            spk2_voice=request.spk2_voice,
            temperature=request.temperature,
            bgm_choice=request.bgm_choice
        )
        if not audio_path:
            raise RuntimeError("Failed to generate audio")

        await manager.send_progress(connection_id, 85, "processing", "Uploading to S3...")

        # 4. Upload to S3 (boto3 blocks, so keep it off the event loop)
        filename = os.path.basename(audio_path)
        s3_key = f"users/{user_id}/outputs/podcasts/{filename}"
        await asyncio.to_thread(_upload_file_to_s3, audio_path, s3_key)

        public_url = s3_service.get_public_url(s3_key)

        # 5. Final update to DB
        podcast.s3_key = s3_key
        podcast.s3_url = public_url
        podcast.script = script
        podcast.status = "completed"
        db.commit()

        # Notify completion
        await manager.send_result(connection_id, {
            "id": podcast.id,
            "status": "completed",
            "title": podcast.title,
            "public_url": public_url
        })

    except Exception as e:
        logger.exception("Background podcast generation failed for ID %s", podcast_id)
        try:
            # The session may hold a failed transaction (e.g. a commit above
            # raised); it must be rolled back before it can be used again.
            db.rollback()
            podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
            if podcast:
                podcast.status = "failed"
                podcast.error_message = str(e)
                db.commit()
        except Exception:
            db.rollback()
            logger.exception("Could not record failure status for podcast %s", podcast_id)

        await manager.send_error(connection_id, f"Generation failed: {str(e)}")
    finally:
        # Clean up the local audio file on every path, not only on success.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                logger.warning("Could not remove temporary audio file %s", audio_path)
        db.close()

@router.post("/generate", response_model=PodcastResponse)
async def generate_podcast(
    request: PodcastGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Start podcast generation in the background and return the new record.

    The ``Podcast`` row is created immediately with status ``"processing"``.
    The client can see it right away; ``run_podcast_generation`` later fills
    in the script and audio fields and sets ``"completed"`` or ``"failed"``.

    Raises:
        HTTPException(403): ``request.file_key`` is set but no ``Source`` with
            that key belongs to the current user.
    """
    from datetime import timezone

    # 1. Verify file ownership if provided
    source_id = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id

    # 2. Create the record, titled after the source file name (without its
    #    extension) or, without a file, after the current UTC time.
    file_base = request.file_key.split('/')[-1].rsplit('.', 1)[0] if request.file_key else None
    if file_base:
        title = f"Podcast-{file_base}"
    else:
        # datetime.utcnow() is deprecated; an aware UTC time formats identically.
        title = f"Podcast {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"
    db_podcast = db_models.Podcast(
        title=title,
        user_id=current_user.id,
        source_id=source_id,
        status="processing"
    )
    db.add(db_podcast)
    db.commit()
    db.refresh(db_podcast)

    # 3. Add to background tasks
    background_tasks.add_task(run_podcast_generation, db_podcast.id, request, current_user.id)

    return db_podcast

@router.get("/list", response_model=List[PodcastResponse])
async def list_podcasts(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Lists all podcasts for the current user including their generation status.

    Results are ordered newest first (by ``created_at``).

    Raises:
        HTTPException(500): the query or response validation failed; the
            underlying error is logged with its traceback.
    """
    try:
        podcasts = db.query(db_models.Podcast).filter(
            db_models.Podcast.user_id == current_user.id
        ).order_by(db_models.Podcast.created_at.desc()).all()

        return [PodcastResponse.model_validate(p) for p in podcasts]
    except Exception as e:
        # Previously the error was only sent to the client and never logged.
        logger.exception("Failed to list podcasts for user %s", current_user.id)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.delete("/{podcast_id}")
async def delete_podcast(
    podcast_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Deletes a specific podcast from database and S3.

    Only podcasts owned by the current user can be deleted.

    Raises:
        HTTPException(404): no podcast with this id belongs to the user.
        HTTPException(500): the DB or S3 deletion failed; the DB transaction is
            rolled back so the record is kept.
    """
    podcast = db.query(db_models.Podcast).filter(
        db_models.Podcast.id == podcast_id,
        db_models.Podcast.user_id == current_user.id
    ).first()

    if not podcast:
        raise HTTPException(status_code=404, detail="Podcast not found")

    s3_key = podcast.s3_key
    try:
        # 1. Stage and flush the DB delete first, so DB-side errors surface
        #    before the audio is removed. Otherwise a failed DB delete would
        #    leave a record pointing at an S3 object that no longer exists.
        db.delete(podcast)
        db.flush()

        # 2. Delete from S3 if it exists
        if s3_key:
            await s3_service.delete_file(s3_key)

        # 3. Make the DB deletion permanent only after S3 succeeded
        db.commit()

        return {"message": "Podcast and associated audio file deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.exception("Failed to delete podcast %s", podcast_id)
        raise HTTPException(status_code=500, detail=f"Deletion failed: {str(e)}") from e