"""
ShortSmith v2 - Configuration Module

Centralized configuration for all components including model paths,
thresholds, domain presets, and runtime settings.
"""
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Dict, Optional | |
| from enum import Enum | |
class ContentDomain(Enum):
    """Closed set of supported content domains.

    Each domain carries its own hype characteristics and maps to a
    weight preset in ``DOMAIN_PRESETS``.
    """

    SPORTS = "sports"
    VLOGS = "vlogs"
    MUSIC = "music"
    PODCASTS = "podcasts"
    GAMING = "gaming"
    GENERAL = "general"  # fallback when no specific domain applies
@dataclass
class DomainWeights:
    """Weight configuration for visual vs audio scoring per domain.

    Weights are normalized in ``__post_init__`` to sum to 1.0, so callers
    may supply any non-negative magnitudes.

    Note: the original class was missing the ``@dataclass`` decorator, so
    keyword construction (``DomainWeights(visual_weight=...)``) raised
    ``TypeError`` and normalization never ran.
    """

    visual_weight: float          # relative importance of visual-model scores
    audio_weight: float           # relative importance of audio-analysis scores
    motion_weight: float = 0.0    # relative importance of motion scores

    def __post_init__(self):
        """Normalize weights in place to sum to 1.0 (no-op if all are zero)."""
        total = self.visual_weight + self.audio_weight + self.motion_weight
        if total > 0:
            self.visual_weight /= total
            self.audio_weight /= total
            self.motion_weight /= total
# Domain-specific weight presets: (visual, audio, motion) per domain.
DOMAIN_PRESETS: Dict[ContentDomain, DomainWeights] = {
    domain: DomainWeights(visual_weight=v, audio_weight=a, motion_weight=m)
    for domain, (v, a, m) in [
        (ContentDomain.SPORTS, (0.35, 0.50, 0.15)),
        (ContentDomain.VLOGS, (0.70, 0.20, 0.10)),
        (ContentDomain.MUSIC, (0.40, 0.50, 0.10)),
        (ContentDomain.PODCASTS, (0.10, 0.85, 0.05)),
        (ContentDomain.GAMING, (0.50, 0.35, 0.15)),
        (ContentDomain.GENERAL, (0.50, 0.40, 0.10)),
    ]
}
@dataclass
class ModelConfig:
    """Configuration for AI models used across the pipeline.

    Note: the original class was missing ``@dataclass``, so the defaults
    were plain class attributes, keyword construction failed, and the
    ``__post_init__`` device fallback never ran.
    """

    # Visual model (Qwen2-VL)
    visual_model_id: str = "Qwen/Qwen2-VL-2B-Instruct"
    visual_model_quantization: str = "int4"  # Options: "int4", "int8", "none"
    visual_max_frames: int = 32

    # Audio model
    audio_model_id: str = "facebook/wav2vec2-base-960h"
    use_advanced_audio: bool = False  # Use Wav2Vec2 instead of just Librosa

    # Face recognition (InsightFace)
    face_detection_model: str = "buffalo_l"  # SCRFD model
    face_similarity_threshold: float = 0.4

    # Body recognition (OSNet)
    body_model_name: str = "osnet_x1_0"
    body_similarity_threshold: float = 0.5

    # Motion detection (RAFT)
    motion_model: str = "raft-things"
    motion_threshold: float = 5.0

    # Device settings
    device: str = "cuda"  # Options: "cuda", "cpu", "mps"

    def __post_init__(self):
        """Fall back to CPU when the requested accelerator is unavailable."""
        # Imported lazily so importing this module does not require torch.
        import torch

        if self.device == "cuda" and not torch.cuda.is_available():
            self.device = "cpu"
        elif self.device == "mps" and not (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        ):
            self.device = "cpu"
@dataclass
class ProcessingConfig:
    """Configuration for the video processing pipeline.

    Note: the original class was missing ``@dataclass``, so per-run
    keyword construction (``ProcessingConfig(batch_size=...)``) raised
    ``TypeError``.
    """

    # Sampling settings
    coarse_sample_interval: float = 5.0  # Seconds between frames in first pass
    dense_sample_fps: float = 3.0        # FPS for dense sampling on candidates
    min_motion_for_dense: float = 2.0    # Threshold to trigger dense sampling

    # Clip settings
    min_clip_duration: float = 10.0      # Minimum clip length in seconds
    max_clip_duration: float = 20.0      # Maximum clip length in seconds
    default_clip_duration: float = 15.0  # Default clip length
    min_gap_between_clips: float = 30.0  # Minimum gap between clip starts

    # Output settings
    default_num_clips: int = 3
    max_num_clips: int = 10
    output_format: str = "mp4"
    output_codec: str = "libx264"
    output_audio_codec: str = "aac"

    # Scene detection
    scene_threshold: float = 27.0  # PySceneDetect threshold

    # Hype scoring
    hype_threshold: float = 0.3    # Minimum normalized score to consider
    diversity_weight: float = 0.2  # Weight for temporal diversity in ranking

    # Performance
    batch_size: int = 8                  # Frames per batch for model inference
    max_video_duration: float = 7200.0   # Maximum video length (2 hours)

    # Temporary files
    temp_dir: Optional[str] = None  # None -> let the pipeline pick a location
    cleanup_temp: bool = True
@dataclass
class AppConfig:
    """Main application configuration.

    Notes on fixes: the original class was missing ``@dataclass`` (so the
    ``field(default_factory=...)`` sentinels leaked as plain class
    attributes) and ``from_env`` lacked ``@classmethod``, making the
    ``AppConfig.from_env()`` call in ``get_config`` raise ``TypeError``.
    """

    # Sub-configurations; default_factory gives each instance its own copy.
    model: ModelConfig = field(default_factory=ModelConfig)
    processing: ProcessingConfig = field(default_factory=ProcessingConfig)

    # Logging
    log_level: str = "INFO"
    log_file: Optional[str] = "shortsmith.log"
    log_to_console: bool = True

    # API settings (for future extensibility)
    api_key: Optional[str] = None

    # UI settings
    share_gradio: bool = False
    server_port: int = 7860

    @classmethod
    def from_env(cls) -> "AppConfig":
        """Create configuration from environment variables.

        Recognized variables: SHORTSMITH_LOG_LEVEL, SHORTSMITH_DEVICE,
        SHORTSMITH_API_KEY, HF_TOKEN (currently acknowledged but unused).

        Returns:
            AppConfig: a fresh instance with environment overrides applied.
        """
        config = cls()
        # Override from environment (empty-string values are ignored).
        if os.environ.get("SHORTSMITH_LOG_LEVEL"):
            config.log_level = os.environ["SHORTSMITH_LOG_LEVEL"]
        if os.environ.get("SHORTSMITH_DEVICE"):
            config.model.device = os.environ["SHORTSMITH_DEVICE"]
        if os.environ.get("SHORTSMITH_API_KEY"):
            config.api_key = os.environ["SHORTSMITH_API_KEY"]
        if os.environ.get("HF_TOKEN"):
            # HuggingFace token for accessing gated models.
            # TODO(review): currently a no-op — presumably downstream
            # libraries read HF_TOKEN from the environment directly; confirm.
            pass
        return config
# Process-wide configuration singleton, created lazily by get_config().
_config: Optional[AppConfig] = None


def get_config() -> AppConfig:
    """Return the global configuration, building it from the environment on first use.

    Returns:
        AppConfig: the cached process-wide configuration instance.
    """
    global _config
    if _config is not None:
        return _config
    _config = AppConfig.from_env()
    return _config
def set_config(config: AppConfig) -> None:
    """Replace the process-wide configuration instance.

    Args:
        config: the configuration object that later ``get_config()``
            calls should return.
    """
    global _config
    _config = config
# Public API of this module; names importable via `from config import *`.
__all__ = [
    "ContentDomain",
    "DomainWeights",
    "DOMAIN_PRESETS",
    "ModelConfig",
    "ProcessingConfig",
    "AppConfig",
    "get_config",
    "set_config",
]