""" ShortSmith v2 - Configuration Module Centralized configuration for all components including model paths, thresholds, domain presets, and runtime settings. """ import os from dataclasses import dataclass, field from typing import Dict, Optional from enum import Enum class ContentDomain(Enum): """Supported content domains with different hype characteristics.""" SPORTS = "sports" VLOGS = "vlogs" MUSIC = "music" PODCASTS = "podcasts" GAMING = "gaming" GENERAL = "general" @dataclass class DomainWeights: """Weight configuration for visual vs audio scoring per domain.""" visual_weight: float audio_weight: float motion_weight: float = 0.0 def __post_init__(self): """Normalize weights to sum to 1.0.""" total = self.visual_weight + self.audio_weight + self.motion_weight if total > 0: self.visual_weight /= total self.audio_weight /= total self.motion_weight /= total # Domain-specific weight presets DOMAIN_PRESETS: Dict[ContentDomain, DomainWeights] = { ContentDomain.SPORTS: DomainWeights(visual_weight=0.35, audio_weight=0.50, motion_weight=0.15), ContentDomain.VLOGS: DomainWeights(visual_weight=0.70, audio_weight=0.20, motion_weight=0.10), ContentDomain.MUSIC: DomainWeights(visual_weight=0.40, audio_weight=0.50, motion_weight=0.10), ContentDomain.PODCASTS: DomainWeights(visual_weight=0.10, audio_weight=0.85, motion_weight=0.05), ContentDomain.GAMING: DomainWeights(visual_weight=0.50, audio_weight=0.35, motion_weight=0.15), ContentDomain.GENERAL: DomainWeights(visual_weight=0.50, audio_weight=0.40, motion_weight=0.10), } @dataclass class ModelConfig: """Configuration for AI models.""" # Visual model (Qwen2-VL) visual_model_id: str = "Qwen/Qwen2-VL-2B-Instruct" visual_model_quantization: str = "int4" # Options: "int4", "int8", "none" visual_max_frames: int = 32 # Audio model audio_model_id: str = "facebook/wav2vec2-base-960h" use_advanced_audio: bool = False # Use Wav2Vec2 instead of just Librosa # Face recognition (InsightFace) face_detection_model: str = "buffalo_l" # SCRFD model face_similarity_threshold: float = 0.4 # Body recognition (OSNet) body_model_name: str = "osnet_x1_0" body_similarity_threshold: float = 0.5 # Motion detection (RAFT) motion_model: str = "raft-things" motion_threshold: float = 5.0 # Device settings device: str = "cuda" # Options: "cuda", "cpu", "mps" def __post_init__(self): """Validate and adjust device based on availability.""" import torch if self.device == "cuda" and not torch.cuda.is_available(): self.device = "cpu" elif self.device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): self.device = "cpu" @dataclass class ProcessingConfig: """Configuration for video processing pipeline.""" # Sampling settings coarse_sample_interval: float = 5.0 # Seconds between frames in first pass dense_sample_fps: float = 3.0 # FPS for dense sampling on candidates min_motion_for_dense: float = 2.0 # Threshold to trigger dense sampling # Clip settings min_clip_duration: float = 10.0 # Minimum clip length in seconds max_clip_duration: float = 20.0 # Maximum clip length in seconds default_clip_duration: float = 15.0 # Default clip length min_gap_between_clips: float = 30.0 # Minimum gap between clip starts # Output settings default_num_clips: int = 3 max_num_clips: int = 10 output_format: str = "mp4" output_codec: str = "libx264" output_audio_codec: str = "aac" # Scene detection scene_threshold: float = 27.0 # PySceneDetect threshold # Hype scoring hype_threshold: float = 0.3 # Minimum normalized score to consider diversity_weight: float = 0.2 # Weight for temporal diversity in ranking # Performance batch_size: int = 8 # Frames per batch for model inference max_video_duration: float = 3600.0 # Maximum video length (1 hour) # Temporary files temp_dir: Optional[str] = None cleanup_temp: bool = True @dataclass class AppConfig: """Main application configuration.""" model: ModelConfig = field(default_factory=ModelConfig) processing: ProcessingConfig = field(default_factory=ProcessingConfig) # Logging log_level: str = "INFO" log_file: Optional[str] = "shortsmith.log" log_to_console: bool = True # API settings (for future extensibility) api_key: Optional[str] = None # UI settings share_gradio: bool = False server_port: int = 7860 @classmethod def from_env(cls) -> "AppConfig": """Create configuration from environment variables.""" config = cls() # Override from environment if os.environ.get("SHORTSMITH_LOG_LEVEL"): config.log_level = os.environ["SHORTSMITH_LOG_LEVEL"] if os.environ.get("SHORTSMITH_DEVICE"): config.model.device = os.environ["SHORTSMITH_DEVICE"] if os.environ.get("SHORTSMITH_API_KEY"): config.api_key = os.environ["SHORTSMITH_API_KEY"] if os.environ.get("HF_TOKEN"): # HuggingFace token for accessing gated models pass return config # Global configuration instance _config: Optional[AppConfig] = None def get_config() -> AppConfig: """Get the global configuration instance.""" global _config if _config is None: _config = AppConfig.from_env() return _config def set_config(config: AppConfig) -> None: """Set the global configuration instance.""" global _config _config = config # Export commonly used items __all__ = [ "ContentDomain", "DomainWeights", "DOMAIN_PRESETS", "ModelConfig", "ProcessingConfig", "AppConfig", "get_config", "set_config", ]