Spaces:
Build error
Build error
| """Orchestration service that coordinates all generation stages.""" | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import structlog | |
| from datetime import datetime, timezone | |
| from app.core.config import settings | |
| from app.db.models import Generation | |
| from app.schemas.generation import GenerationRequest, GenerationResponse | |
| from app.services.prompt_understanding import get_prompt_service | |
| from app.services.music_generation import get_music_service | |
| from app.services.vocal_generation import get_vocal_service | |
| from app.services.post_processing import get_post_processing_service | |
| # Import connection manager for real-time updates | |
| from app.api.v1.websockets import manager | |
| logger = structlog.get_logger(__name__) | |
class GenerationOrchestrator:
    """Orchestrates the complete music generation pipeline.

    Holds one instance of each stage service (prompt understanding, music,
    vocals, post-processing), drives them in sequence from ``generate`` and
    pushes progress updates to websocket subscribers through ``manager``.
    """

    # Mix levels used when the request leaves the corresponding volume unset.
    DEFAULT_VOCAL_VOLUME = 0.7
    DEFAULT_INSTRUMENTAL_VOLUME = 0.8

    def __init__(self):
        """Initialize the orchestrator and resolve the stage services."""
        self.logger = logger.bind(service="orchestrator")
        self.prompt_service = get_prompt_service()
        self.music_service = get_music_service()
        self.vocal_service = get_vocal_service()
        self.post_processing_service = get_post_processing_service()

    @staticmethod
    def _new_output_file(subdir: str) -> Path:
        """Return a fresh, uniquely named ``.wav`` path under AUDIO_STORAGE_PATH/subdir.

        The directory is created if it does not exist; the file itself is not
        created here.
        """
        directory = Path(settings.AUDIO_STORAGE_PATH) / subdir
        directory.mkdir(parents=True, exist_ok=True)
        return directory / f"{uuid.uuid4()}.wav"

    @staticmethod
    def _volume_or_default(value: float | None, default: float) -> float:
        """Return ``value`` unless it is ``None``, in which case return ``default``.

        Unlike ``value or default`` this keeps an explicit ``0.0`` (a muted
        track) instead of silently replacing it with the default.
        """
        return default if value is None else value

    @staticmethod
    async def _report_progress(
        gen_id: str,
        stage: str,
        progress: int,
        message: str,
    ) -> None:
        """Broadcast a ``"processing"`` status update for ``gen_id``."""
        await manager.broadcast(gen_id, {
            "status": "processing",
            "stage": stage,
            "progress": progress,
            "message": message,
        })

    async def generate(
        self,
        request: GenerationRequest,
        generation_record: Generation,
    ) -> GenerationResponse:
        """
        Execute the complete generation pipeline.

        Stages:
        1. Prompt understanding and analysis
        2. Music generation
        3. Vocal generation (if lyrics are available)
        4. Mixing (if vocals were generated)
        5. Post-processing/mastering
        6. Recording completion status and timing on the record

        ``generation_record`` is mutated in place (style, lyrics, file paths,
        status, timings). This method does not persist it; presumably the
        caller commits it (verify against callers).

        Args:
            request: The user's generation request.
            generation_record: The record tracking this generation.

        Returns:
            A ``GenerationResponse`` pointing at the mastered audio file.

        Raises:
            Any exception from a stage service, re-raised unchanged after the
            record is marked ``"failed"`` and subscribers are notified.
        """
        start_time = datetime.now(timezone.utc)
        gen_id = str(generation_record.id)
        # Bind the id once so every stage log line can be correlated with
        # this generation when several run concurrently.
        log = self.logger.bind(generation_id=gen_id)
        log.info("starting_generation", prompt=request.prompt[:100])

        try:
            await self._report_progress(
                gen_id, "starting", 0, "Starting generation pipeline..."
            )

            # Stage 1: Prompt Understanding
            log.info("stage_1_prompt_understanding")
            await self._report_progress(
                gen_id, "prompt_analysis", 10, "Analyzing prompt and context..."
            )
            analysis = await self.prompt_service.analyze_prompt(
                request.prompt,
                request.user_context,
            )
            generation_record.generation_metadata = {
                **(generation_record.generation_metadata or {}),
                "analysis": analysis.model_dump(),
            }
            generation_record.style = analysis.style
            # Lyrics from the analysis take precedence over user-supplied ones.
            lyrics = analysis.lyrics or request.lyrics
            generation_record.lyrics = lyrics

            # Stage 2: Music Generation
            log.info("stage_2_music_generation")
            await self._report_progress(
                gen_id,
                "music_generation",
                20,
                f"Generating music ({analysis.style})...",
            )
            instrumental_path = await self.music_service.generate(
                prompt=analysis.enriched_prompt,
                duration=request.duration or analysis.duration_hint,
                style=analysis.style,
                tempo=analysis.tempo,
            )
            generation_record.instrumental_path = str(instrumental_path)

            # Stage 3: Vocal Generation (only when there are lyrics to sing)
            vocal_path = None
            if lyrics:
                log.info("stage_3_vocal_generation")
                await self._report_progress(
                    gen_id, "vocal_generation", 60, "Generating vocals..."
                )
                vocal_path = await self.vocal_service.generate(
                    text=lyrics,
                    voice_preset=request.voice_preset,
                )
                generation_record.vocal_path = str(vocal_path)

            # Stage 4: Mixing (if vocals); otherwise master the instrumental.
            if vocal_path:
                log.info("stage_4_mixing")
                await self._report_progress(
                    gen_id, "mixing", 80, "Mixing vocals and instrumental..."
                )
                mixed_file = self._new_output_file("mixed")
                await self.post_processing_service.mix_audio(
                    instrumental_path=instrumental_path,
                    vocal_path=vocal_path,
                    output_path=mixed_file,
                    vocal_volume=self._volume_or_default(
                        request.vocal_volume, self.DEFAULT_VOCAL_VOLUME
                    ),
                    instrumental_volume=self._volume_or_default(
                        request.instrumental_volume,
                        self.DEFAULT_INSTRUMENTAL_VOLUME,
                    ),
                )
                audio_path = mixed_file
            else:
                audio_path = instrumental_path

            # Stage 5: Post-processing/Mastering
            log.info("stage_5_post_processing")
            await self._report_progress(
                gen_id, "mastering", 90, "Mastering final audio..."
            )
            mastered_file = self._new_output_file("mastered")
            await self.post_processing_service.master_audio(
                audio_path=audio_path,
                output_path=mastered_file,
                normalize=True,
                apply_compression=True,
                apply_eq=True,
            )
            generation_record.audio_path = str(mastered_file)

            # Stage 6: Update metadata. A single timestamp keeps completed_at
            # and processing_time_seconds mutually consistent.
            completed_at = datetime.now(timezone.utc)
            processing_time = (completed_at - start_time).total_seconds()
            generation_record.status = "completed"
            generation_record.completed_at = completed_at
            generation_record.processing_time_seconds = processing_time

            log.info("generation_completed", processing_time=processing_time)
            await manager.broadcast(gen_id, {
                "status": "completed",
                "stage": "finished",
                "progress": 100,
                "audio_url": f"/api/v1/generations/{gen_id}/audio",
                "message": "Generation complete!",
            })

            return GenerationResponse(
                id=generation_record.id,
                status="completed",
                audio_path=str(mastered_file),
                metadata=generation_record.generation_metadata,
                processing_time_seconds=processing_time,
            )

        except Exception as e:
            log.error("generation_failed", exc_info=e)
            generation_record.status = "failed"
            generation_record.error_message = str(e)
            try:
                await manager.broadcast(gen_id, {
                    "status": "failed",
                    "error": str(e),
                    "message": "Generation failed.",
                })
            except Exception:
                # A notification failure must not replace the pipeline error
                # that is re-raised below.
                log.warning("failure_broadcast_failed", exc_info=True)
            # Bare raise re-raises ``e``: the nested handler above has exited.
            raise
# Module-level cache: created on the first get_orchestrator() call, then reused.
_orchestrator: GenerationOrchestrator | None = None


def get_orchestrator() -> GenerationOrchestrator:
    """Return the shared orchestrator, constructing it on first use.

    Every later call returns the same module-level instance.
    """
    global _orchestrator
    instance = _orchestrator
    if instance is None:
        instance = _orchestrator = GenerationOrchestrator()
    return instance