Spaces:
Build error
Build error
| """Generation endpoints.""" | |
| from typing import Any | |
| from uuid import UUID | |
| import structlog | |
| from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query | |
| from fastapi.responses import FileResponse | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy import select, func | |
| from pathlib import Path | |
| from app.db.database import get_db | |
| from app.db.models import Generation | |
| from app.schemas.generation import ( | |
| GenerationRequest, | |
| GenerationResponse, | |
| GenerationListResponse, | |
| ) | |
| from app.services.orchestrator import get_orchestrator | |
| from app.core.metrics import http_requests_total, http_request_duration | |
| from app.core.config import settings | |
# API router for this module (presumably where the generation endpoints are registered).
router = APIRouter()
# Structured logger named after this module.
logger = structlog.get_logger(__name__)
async def create_generation(
    request: GenerationRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
) -> GenerationResponse:
    """
    Create a new music generation request.

    Persists a ``Generation`` row with status ``"pending"`` and schedules
    ``process_generation_task`` as a background task, so the call returns
    immediately with the generation ID while processing happens afterwards.

    Args:
        request: Generation parameters (prompt, lyrics, style, optional duration).
        background_tasks: FastAPI background-task collector used to schedule processing.
        db: Async database session injected via ``get_db``.

    Returns:
        A ``GenerationResponse`` with the new ID, ``"pending"`` status, the
        prompt, and the creation timestamp.

    Raises:
        HTTPException: 500 if the generation could not be created.
    """
    with http_request_duration.labels(method="POST", endpoint="/generations").time():
        try:
            generation = Generation(
                prompt=request.prompt,
                lyrics=request.lyrics,
                style=request.style,
                # A missing/falsy duration falls back to the configured default.
                duration=request.duration or settings.MUSICGEN_DURATION,
                status="pending",
            )
            db.add(generation)
            await db.commit()
            # Reload attributes after commit so id/created_at can be read below.
            await db.refresh(generation)

            # Processing starts only after the record is durably committed.
            background_tasks.add_task(process_generation_task, generation.id, request)

            http_requests_total.labels(
                method="POST", endpoint="/generations", status="202"
            ).inc()
            logger.info(
                "generation_created",
                generation_id=str(generation.id),
                prompt=request.prompt[:100],
            )
            return GenerationResponse(
                id=generation.id,
                status="pending",
                prompt=generation.prompt,
                created_at=generation.created_at,
            )
        except Exception as e:
            logger.error("failed_to_create_generation", exc_info=e)
            # Discard the failed transaction so the session is not left in an
            # unusable state; a rollback failure must not mask the original error.
            try:
                await db.rollback()
            except Exception as rollback_error:
                logger.warning(
                    "create_generation_rollback_failed", exc_info=rollback_error
                )
            http_requests_total.labels(
                method="POST", endpoint="/generations", status="500"
            ).inc()
            raise HTTPException(
                status_code=500, detail="Failed to create generation"
            ) from e
async def get_generation(
    generation_id: UUID,
    db: AsyncSession = Depends(get_db),
) -> GenerationResponse:
    """
    Get a generation by ID.

    Args:
        generation_id: ID of the generation to fetch.
        db: Async database session injected via ``get_db``.

    Returns:
        The generation's current state. ``audio_path`` carries the API download
        URL (not the stored filesystem path) and is set only when the
        generation is ``"completed"`` and has an audio path.

    Raises:
        HTTPException: 404 if no generation has this ID.
    """
    with http_request_duration.labels(
        method="GET", endpoint="/generations/{id}"
    ).time():
        result = await db.execute(
            select(Generation).where(Generation.id == generation_id)
        )
        generation = result.scalar_one_or_none()
        if not generation:
            http_requests_total.labels(
                method="GET", endpoint="/generations/{id}", status="404"
            ).inc()
            raise HTTPException(status_code=404, detail="Generation not found")

        http_requests_total.labels(
            method="GET", endpoint="/generations/{id}", status="200"
        ).inc()

        audio_url = None
        if generation.audio_path and generation.status == "completed":
            audio_url = f"/api/v1/generations/{generation.id}/audio"
        # Diagnostic output only: logged at DEBUG so that every GET does not
        # write the internal filesystem path into the INFO log.
        logger.debug(
            "debug_get_generation",
            audio_url=audio_url,
            original_path=generation.audio_path,
        )

        return GenerationResponse(
            id=generation.id,
            status=generation.status,
            prompt=generation.prompt,
            audio_path=audio_url,
            metadata=generation.generation_metadata,
            processing_time_seconds=generation.processing_time_seconds,
            error_message=generation.error_message,
            created_at=generation.created_at,
            completed_at=generation.completed_at,
        )
async def get_generation_audio(
    generation_id: UUID,
    db: AsyncSession = Depends(get_db),
) -> FileResponse:
    """
    Serve the generated audio file for a generation.

    Args:
        generation_id: ID of the generation whose audio is requested.
        db: Async database session injected via ``get_db``.

    Returns:
        A ``FileResponse`` for the stored file, sent as ``audio/wav`` with the
        download name ``generation-<id>.wav``.

    Raises:
        HTTPException: 404 if the generation does not exist, is not
            ``"completed"`` with an audio path, or the file is missing on disk.
    """
    endpoint = "/generations/{id}/audio"

    def _count(status: str) -> None:
        # Record the response status, matching the other endpoints' metrics.
        http_requests_total.labels(
            method="GET", endpoint=endpoint, status=status
        ).inc()

    with http_request_duration.labels(method="GET", endpoint=endpoint).time():
        result = await db.execute(
            select(Generation).where(Generation.id == generation_id)
        )
        generation = result.scalar_one_or_none()
        if not generation:
            _count("404")
            raise HTTPException(status_code=404, detail="Generation not found")

        # Same rule get_generation/list_generations use for exposing the audio
        # URL: a failed run may still have an audio_path recorded, and that
        # file must not be served as a finished result.
        if not generation.audio_path or generation.status != "completed":
            _count("404")
            raise HTTPException(
                status_code=404, detail="Audio not yet generated"
            )

        audio_path = Path(generation.audio_path)
        # is_file() rather than exists(): a directory path would pass exists()
        # and then fail inside FileResponse instead of returning a clean 404.
        if not audio_path.is_file():
            _count("404")
            raise HTTPException(status_code=404, detail="Audio file not found")

        _count("200")
        return FileResponse(
            path=str(audio_path),
            media_type="audio/wav",
            filename=f"generation-{generation_id}.wav",
        )
async def list_generations(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
) -> GenerationListResponse:
    """
    Return one page of generations, newest first.

    Args:
        page: 1-based page number.
        page_size: Maximum number of items on the page (1-100).
        db: Async database session injected via ``get_db``.

    Returns:
        A ``GenerationListResponse`` holding the page's items, the total number
        of generations, and the requested ``page``/``page_size``.
    """

    def _to_response(row: Generation) -> GenerationResponse:
        # The download URL is exposed only for completed generations with audio.
        download_url = None
        if row.audio_path and row.status == "completed":
            download_url = f"/api/v1/generations/{row.id}/audio"
        return GenerationResponse(
            id=row.id,
            status=row.status,
            prompt=row.prompt,
            audio_path=download_url,
            metadata=row.generation_metadata,
            processing_time_seconds=row.processing_time_seconds,
            error_message=row.error_message,
            created_at=row.created_at,
            completed_at=row.completed_at,
        )

    with http_request_duration.labels(method="GET", endpoint="/generations").time():
        count_query = select(func.count(Generation.id))
        total = (await db.execute(count_query)).scalar_one()

        rows_to_skip = (page - 1) * page_size
        page_query = (
            select(Generation)
            .order_by(Generation.created_at.desc())
            .offset(rows_to_skip)
            .limit(page_size)
        )
        rows = (await db.execute(page_query)).scalars().all()

        http_requests_total.labels(
            method="GET", endpoint="/generations", status="200"
        ).inc()

        page_items = [_to_response(row) for row in rows]
        return GenerationListResponse(
            items=page_items,
            total=total,
            page=page,
            page_size=page_size,
        )
async def process_generation_task(generation_id: UUID, request: GenerationRequest) -> None:
    """
    Process a generation in the background and persist its outcome.

    Opens its own database session, marks the generation ``"processing"``,
    runs the orchestrator, and commits the record afterwards. On any error the
    generation is marked ``"failed"`` with the error message. Failures while
    recording that status are logged rather than raised, because there is no
    caller left to handle them.

    Args:
        generation_id: ID of the ``Generation`` row to process.
        request: The original request, passed through to the orchestrator.
    """
    # Imported locally as in the original; presumably to avoid an import
    # cycle with app.db.database (unverified).
    from app.db.database import AsyncSessionLocal

    async with AsyncSessionLocal() as db:
        try:
            result = await db.execute(
                select(Generation).where(Generation.id == generation_id)
            )
            generation = result.scalar_one()

            generation.status = "processing"
            await db.commit()

            orchestrator = get_orchestrator()
            # Presumably the orchestrator updates `generation` in place; the
            # commit below persists whatever state it leaves on the record.
            await orchestrator.generate(request, generation)
            await db.commit()
        except Exception as e:
            logger.error(
                "background_generation_failed",
                generation_id=str(generation_id),
                exc_info=e,
            )
            try:
                # Roll back first: after a failed flush/commit the session
                # rejects new statements until rolled back, and this also
                # discards partial, uncommitted changes made to the record.
                await db.rollback()
                result = await db.execute(
                    select(Generation).where(Generation.id == generation_id)
                )
                generation = result.scalar_one()
                generation.status = "failed"
                # str() of some exceptions is empty; keep at least the type name.
                generation.error_message = str(e) or type(e).__name__
                await db.commit()
            except Exception as status_error:
                # Best effort, but never silent: otherwise the row would stay
                # in "processing" with no record of why.
                logger.error(
                    "failed_to_mark_generation_failed",
                    generation_id=str(generation_id),
                    exc_info=status_error,
                )