AudioForge / backend /app /services /orchestrator.py
OnyxlMunkey's picture
c618549
"""Orchestration service that coordinates all generation stages."""
import uuid
from pathlib import Path
from typing import Any
import structlog
from datetime import datetime, timezone
from app.core.config import settings
from app.db.models import Generation
from app.schemas.generation import GenerationRequest, GenerationResponse
from app.services.prompt_understanding import get_prompt_service
from app.services.music_generation import get_music_service
from app.services.vocal_generation import get_vocal_service
from app.services.post_processing import get_post_processing_service
# Import connection manager for real-time updates
from app.api.v1.websockets import manager
# Module-level structlog logger for this service module; GenerationOrchestrator
# binds service="orchestrator" onto it for its own log lines.
logger = structlog.get_logger(__name__)
class GenerationOrchestrator:
    """Orchestrates the complete music generation pipeline.

    Coordinates the prompt-understanding, music, vocal and post-processing
    services, updates the ``Generation`` record in place, and pushes progress
    events for each stage to websocket subscribers via ``manager.broadcast``.
    """

    def __init__(self):
        """Initialize the orchestrator and resolve the per-stage services."""
        self.logger = logger.bind(service="orchestrator")
        self.prompt_service = get_prompt_service()
        self.music_service = get_music_service()
        self.vocal_service = get_vocal_service()
        self.post_processing_service = get_post_processing_service()

    @staticmethod
    async def _broadcast_progress(
        gen_id: str,
        stage: str,
        progress: int,
        message: str,
    ) -> None:
        """Send a ``processing`` progress event for one pipeline stage."""
        await manager.broadcast(gen_id, {
            "status": "processing",
            "stage": stage,
            "progress": progress,
            "message": message,
        })

    @staticmethod
    def _new_output_file(subdir: str) -> Path:
        """Return a fresh, uniquely named ``.wav`` path under ``AUDIO_STORAGE_PATH/subdir``.

        The directory is created (with parents) if it does not exist yet.
        """
        directory = Path(settings.AUDIO_STORAGE_PATH) / subdir
        directory.mkdir(parents=True, exist_ok=True)
        return directory / f"{uuid.uuid4()}.wav"

    @staticmethod
    def _volume_or_default(value: float | None, default: float) -> float:
        """Return ``value`` unless it is ``None``, in which case return ``default``.

        Unlike ``value or default``, an explicit ``0.0`` (a muted track) is kept.
        """
        return default if value is None else value

    async def generate(
        self,
        request: GenerationRequest,
        generation_record: Generation,
    ) -> GenerationResponse:
        """
        Execute the complete generation pipeline.

        Stages:
        1. Prompt understanding and analysis
        2. Music generation
        3. Vocal generation (if lyrics are available)
        4. Mixing (if vocals were generated)
        5. Post-processing/mastering
        6. Finalize the record (status, completion time, processing time)

        ``generation_record`` is mutated in place as stages complete; this
        method does not commit it.

        Args:
            request: The user's generation request.
            generation_record: The record to update with paths, metadata and status.

        Returns:
            A ``GenerationResponse`` pointing at the mastered audio file.

        Raises:
            Any exception raised by a stage is re-raised after the record is
            marked ``failed`` and a failure event is broadcast.
        """
        start_time = datetime.now(timezone.utc)
        gen_id = str(generation_record.id)
        self.logger.info(
            "starting_generation",
            generation_id=gen_id,
            prompt=request.prompt[:100],
        )

        try:
            await self._broadcast_progress(
                gen_id, "starting", 0, "Starting generation pipeline..."
            )

            # Stage 1: Prompt Understanding
            self.logger.info("stage_1_prompt_understanding")
            await self._broadcast_progress(
                gen_id, "prompt_analysis", 10, "Analyzing prompt and context..."
            )
            analysis = await self.prompt_service.analyze_prompt(
                request.prompt,
                request.user_context,
            )
            generation_record.generation_metadata = {
                **(generation_record.generation_metadata or {}),
                "analysis": analysis.model_dump(),
            }
            generation_record.style = analysis.style
            # Lyrics from the analysis take precedence over user-supplied lyrics.
            lyrics = analysis.lyrics or request.lyrics
            generation_record.lyrics = lyrics

            # Stage 2: Music Generation
            self.logger.info("stage_2_music_generation")
            await self._broadcast_progress(
                gen_id, "music_generation", 20,
                f"Generating music ({analysis.style})...",
            )
            instrumental_path = await self.music_service.generate(
                prompt=analysis.enriched_prompt,
                duration=request.duration or analysis.duration_hint,
                style=analysis.style,
                tempo=analysis.tempo,
            )
            generation_record.instrumental_path = str(instrumental_path)

            # Stage 3: Vocal Generation (only when there are lyrics to sing)
            vocal_path = None
            if lyrics:
                self.logger.info("stage_3_vocal_generation")
                await self._broadcast_progress(
                    gen_id, "vocal_generation", 60, "Generating vocals..."
                )
                vocal_path = await self.vocal_service.generate(
                    text=lyrics,
                    voice_preset=request.voice_preset,
                )
                generation_record.vocal_path = str(vocal_path)

            # Stage 4: Mixing (only when vocals were produced)
            if vocal_path:
                self.logger.info("stage_4_mixing")
                await self._broadcast_progress(
                    gen_id, "mixing", 80, "Mixing vocals and instrumental..."
                )
                mixed_file = self._new_output_file("mixed")
                await self.post_processing_service.mix_audio(
                    instrumental_path=instrumental_path,
                    vocal_path=vocal_path,
                    output_path=mixed_file,
                    # `is None` check so an explicit 0.0 (mute) is honoured.
                    vocal_volume=self._volume_or_default(request.vocal_volume, 0.7),
                    instrumental_volume=self._volume_or_default(
                        request.instrumental_volume, 0.8
                    ),
                )
                audio_path = mixed_file
            else:
                audio_path = instrumental_path

            # Stage 5: Post-processing/Mastering
            self.logger.info("stage_5_post_processing")
            await self._broadcast_progress(
                gen_id, "mastering", 90, "Mastering final audio..."
            )
            mastered_file = self._new_output_file("mastered")
            await self.post_processing_service.master_audio(
                audio_path=audio_path,
                output_path=mastered_file,
                normalize=True,
                apply_compression=True,
                apply_eq=True,
            )
            generation_record.audio_path = str(mastered_file)

            # Stage 6: Finalize the record. A single timestamp keeps
            # completed_at and processing_time consistent with each other.
            completed_at = datetime.now(timezone.utc)
            processing_time = (completed_at - start_time).total_seconds()
            generation_record.status = "completed"
            generation_record.completed_at = completed_at
            generation_record.processing_time_seconds = processing_time
            self.logger.info(
                "generation_completed",
                generation_id=gen_id,
                processing_time=processing_time,
            )
            await manager.broadcast(gen_id, {
                "status": "completed",
                "stage": "finished",
                "progress": 100,
                "audio_url": f"/api/v1/generations/{gen_id}/audio",
                "message": "Generation complete!",
            })

            return GenerationResponse(
                id=generation_record.id,
                status="completed",
                audio_path=str(mastered_file),
                metadata=generation_record.generation_metadata,
                processing_time_seconds=processing_time,
            )
        except Exception as e:
            self.logger.error(
                "generation_failed",
                generation_id=gen_id,
                exc_info=e,
            )
            generation_record.status = "failed"
            generation_record.error_message = str(e)
            try:
                await manager.broadcast(gen_id, {
                    "status": "failed",
                    "error": str(e),
                    "message": "Generation failed.",
                })
            except Exception:
                # Notifying subscribers is best-effort here: a broadcast error
                # must not replace the pipeline error re-raised below.
                self.logger.warning(
                    "failure_broadcast_failed",
                    generation_id=gen_id,
                    exc_info=True,
                )
            # Bare `raise` re-raises the pipeline error `e` (the inner handler
            # has finished, so the outer exception is active again).
            raise
# Process-wide orchestrator, created lazily by get_orchestrator().
_orchestrator: GenerationOrchestrator | None = None


def get_orchestrator() -> GenerationOrchestrator:
    """Return the shared orchestrator, constructing it on first use.

    Subsequent calls return the same instance.
    """
    global _orchestrator
    if _orchestrator is not None:
        return _orchestrator
    _orchestrator = GenerationOrchestrator()
    return _orchestrator