"""Session data structures and in-memory storage.

Sessions live in a module-level dict keyed by session ID. Nothing is
persisted, and no locking is performed around the store.
"""

import uuid
import time
from dataclasses import dataclass, field
from typing import Optional

import numpy as np


@dataclass
class StemData:
    """Represents a single audio stem."""

    name: str  # "bass", "drums", "guitar", etc.
    audio: np.ndarray  # shape: (samples,) for mono or (samples, channels) for stereo
    sample_rate: int


@dataclass
class Session:
    """Represents a user session with uploaded audio and analysis results.

    Every container field uses ``default_factory``, so each session gets its
    own dicts rather than sharing one across instances.
    """

    # Random UUID4 string; also the key under which the session is stored.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Creation time as a Unix timestamp (``time.time()``); used for age-based cleanup.
    created_at: float = field(default_factory=time.time)
    # presumably keyed by stem name (e.g. "bass") — verify against callers
    stems: dict[str, StemData] = field(default_factory=dict)
    full_mix: Optional[StemData] = None
    midi_data: Optional[object] = None  # mido.MidiFile if provided
    detected_bpm: Optional[float] = None
    detected_key: Optional[str] = None
    detected_mode: Optional[str] = None
    detection_confidence: Optional[dict] = None
    processed_stems: dict[str, StemData] = field(default_factory=dict)
    region_processed_stems: dict[str, StemData] = field(default_factory=dict)
    generated_continuation: Optional[StemData] = None
    original_sr: int = 44100
    # Cache for encoded WAV bytes to avoid re-encoding on each request
    wav_cache: dict[str, bytes] = field(default_factory=dict)
    # Optional metadata overrides loaded from preset's metadata.json
    preset_metadata: Optional[dict] = None
    # Chord progression and scale suggestions (cached after first detection)
    chord_progression: Optional[list] = None
    scale_suggestions: Optional[list] = None
    chord_source: Optional[str] = None


# In-memory session store: session ID -> Session.
_sessions: dict[str, Session] = {}


def create_session() -> Session:
    """Create a new session, register it in the store, and return it."""
    session = Session()
    _sessions[session.id] = session
    return session


def get_session(session_id: str) -> Optional[Session]:
    """Return the session with ``session_id``, or None if it is not stored."""
    return _sessions.get(session_id)


def delete_session(session_id: str) -> bool:
    """Remove the session with ``session_id`` from the store.

    Returns:
        True if a session was removed, False if no session had that ID.
    """
    # EAFP: a single dict operation instead of a membership test followed by del.
    try:
        del _sessions[session_id]
    except KeyError:
        return False
    return True


def cleanup_old_sessions(max_age_seconds: float = 3600) -> int:
    """Delete sessions whose age exceeds ``max_age_seconds``.

    Age is measured from ``Session.created_at``, not from last access. A
    session that is still being used is removed once it is old enough.

    Args:
        max_age_seconds: Sessions strictly older than this are removed.

    Returns:
        The number of sessions actually deleted.
    """
    now = time.time()
    expired = [
        sid
        for sid, session in _sessions.items()
        if now - session.created_at > max_age_seconds
    ]
    removed = 0
    for sid in expired:
        # pop with a default so an ID removed in the meantime (e.g. via
        # delete_session) is skipped instead of raising KeyError.
        if _sessions.pop(sid, None) is not None:
            removed += 1
    return removed


def get_all_sessions() -> dict[str, Session]:
    """Return a shallow copy of the store (for testing/debugging).

    Adding or removing keys in the returned dict does not affect the store.
    The Session objects themselves are shared, not copied.
    """
    return _sessions.copy()