"""
ShortSmith v2 - Configuration Module
Centralized configuration for all components including model paths,
thresholds, domain presets, and runtime settings.
"""
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
from enum import Enum
class ContentDomain(Enum):
    """Supported content domains with different hype characteristics.

    Each member's value is the lowercase domain string; members serve as
    keys into ``DOMAIN_PRESETS`` to select visual/audio/motion scoring
    weights for that kind of content.
    """
    SPORTS = "sports"
    VLOGS = "vlogs"
    MUSIC = "music"
    PODCASTS = "podcasts"
    GAMING = "gaming"
    GENERAL = "general"  # presumably the fallback when no specific domain applies — confirm at call sites
@dataclass
class DomainWeights:
    """Weight configuration for visual vs audio scoring per domain.

    The three weights are rescaled on construction so they sum to 1.0;
    when every weight is zero (or the sum is non-positive) the raw values
    are kept untouched.
    """

    visual_weight: float
    audio_weight: float
    motion_weight: float = 0.0

    def __post_init__(self):
        """Rescale the weights to a unit sum (no-op for a non-positive total)."""
        combined = self.visual_weight + self.audio_weight + self.motion_weight
        if combined <= 0:
            return  # nothing sensible to normalize against
        self.visual_weight = self.visual_weight / combined
        self.audio_weight = self.audio_weight / combined
        self.motion_weight = self.motion_weight / combined
# Domain-specific weight presets as (visual, audio, motion) triples;
# DomainWeights.__post_init__ normalizes each triple to sum to 1.0.
DOMAIN_PRESETS: Dict[ContentDomain, DomainWeights] = {
    domain: DomainWeights(*raw)
    for domain, raw in [
        (ContentDomain.SPORTS, (0.35, 0.50, 0.15)),
        (ContentDomain.VLOGS, (0.70, 0.20, 0.10)),
        (ContentDomain.MUSIC, (0.40, 0.50, 0.10)),
        (ContentDomain.PODCASTS, (0.10, 0.85, 0.05)),
        (ContentDomain.GAMING, (0.50, 0.35, 0.15)),
        (ContentDomain.GENERAL, (0.50, 0.40, 0.10)),
    ]
}
@dataclass
class ModelConfig:
    """Configuration for AI models.

    Holds model identifiers and thresholds for every recognition stage,
    plus the compute device; an unavailable accelerator silently degrades
    to CPU at construction time.
    """

    # Visual model (Qwen2-VL)
    visual_model_id: str = "Qwen/Qwen2-VL-2B-Instruct"
    visual_model_quantization: str = "int4"  # Options: "int4", "int8", "none"
    visual_max_frames: int = 32
    # Audio model
    audio_model_id: str = "facebook/wav2vec2-base-960h"
    use_advanced_audio: bool = False  # Use Wav2Vec2 instead of just Librosa
    # Face recognition (InsightFace)
    face_detection_model: str = "buffalo_l"  # SCRFD model
    face_similarity_threshold: float = 0.4
    # Body recognition (OSNet)
    body_model_name: str = "osnet_x1_0"
    body_similarity_threshold: float = 0.5
    # Motion detection (RAFT)
    motion_model: str = "raft-things"
    motion_threshold: float = 5.0
    # Device settings
    device: str = "cuda"  # Options: "cuda", "cpu", "mps"

    def __post_init__(self):
        """Fall back to CPU when the requested accelerator is not available."""
        import torch  # deferred import: keeps module import light

        cuda_ready = torch.cuda.is_available()
        mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        wants_missing_cuda = self.device == "cuda" and not cuda_ready
        wants_missing_mps = self.device == "mps" and not mps_ready
        if wants_missing_cuda or wants_missing_mps:
            self.device = "cpu"
@dataclass
class ProcessingConfig:
    """Configuration for video processing pipeline.

    Durations, intervals and gaps are expressed in seconds (per the
    per-field comments below).
    """
    # Sampling settings
    coarse_sample_interval: float = 5.0  # Seconds between frames in first pass
    dense_sample_fps: float = 3.0  # FPS for dense sampling on candidates
    min_motion_for_dense: float = 2.0  # Threshold to trigger dense sampling
    # Clip settings
    min_clip_duration: float = 10.0  # Minimum clip length in seconds
    max_clip_duration: float = 20.0  # Maximum clip length in seconds
    default_clip_duration: float = 15.0  # Default clip length
    min_gap_between_clips: float = 30.0  # Minimum gap between clip starts
    # Output settings
    default_num_clips: int = 3  # Clips produced when the caller gives no count
    max_num_clips: int = 10  # Hard cap on requested clips
    output_format: str = "mp4"
    output_codec: str = "libx264"  # H.264 video encoder
    output_audio_codec: str = "aac"
    # Scene detection
    scene_threshold: float = 27.0  # PySceneDetect threshold
    # Hype scoring
    hype_threshold: float = 0.3  # Minimum normalized score to consider
    diversity_weight: float = 0.2  # Weight for temporal diversity in ranking
    # Performance
    batch_size: int = 8  # Frames per batch for model inference
    max_video_duration: float = 12000.0  # presumably seconds (~3.3 h), matching the other duration fields — confirm at usage site
    # Temporary files
    temp_dir: Optional[str] = None  # None presumably means "use system temp dir" — confirm at usage site
    cleanup_temp: bool = True  # Remove intermediate files when done
@dataclass
class AppConfig:
    """Main application configuration.

    Aggregates the model and processing sub-configs plus app-level
    settings (logging, API, UI). Use :meth:`from_env` to build an
    instance with environment-variable overrides applied.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    processing: ProcessingConfig = field(default_factory=ProcessingConfig)
    # Logging
    log_level: str = "INFO"
    log_file: Optional[str] = "shortsmith.log"
    log_to_console: bool = True
    # API settings (for future extensibility)
    api_key: Optional[str] = None
    # HuggingFace token for accessing gated models
    hf_token: Optional[str] = None
    # UI settings
    share_gradio: bool = False
    server_port: int = 7860

    @classmethod
    def from_env(cls) -> "AppConfig":
        """Create configuration from environment variables.

        Recognized variables:
            SHORTSMITH_LOG_LEVEL -> log_level
            SHORTSMITH_DEVICE    -> model.device
            SHORTSMITH_API_KEY   -> api_key
            HF_TOKEN             -> hf_token

        Unset or empty variables leave the corresponding defaults intact.
        """
        config = cls()
        # Override from environment
        if os.environ.get("SHORTSMITH_LOG_LEVEL"):
            config.log_level = os.environ["SHORTSMITH_LOG_LEVEL"]
        if os.environ.get("SHORTSMITH_DEVICE"):
            config.model.device = os.environ["SHORTSMITH_DEVICE"]
        if os.environ.get("SHORTSMITH_API_KEY"):
            config.api_key = os.environ["SHORTSMITH_API_KEY"]
        if os.environ.get("HF_TOKEN"):
            # Bug fix: the token was previously read here and discarded;
            # store it so model loaders can authenticate to gated repos.
            config.hf_token = os.environ["HF_TOKEN"]
        return config
# Global configuration instance — lazily created by get_config() and
# replaceable via set_config(); kept private to this module.
_config: Optional[AppConfig] = None
def get_config() -> AppConfig:
    """Return the process-wide configuration, building it from the
    environment on first use and caching it for subsequent calls."""
    global _config
    cached = _config
    if cached is None:
        cached = AppConfig.from_env()
        _config = cached
    return cached
def set_config(config: AppConfig) -> None:
    """Set the global configuration instance.

    Replaces whatever get_config() would return from now on; useful for
    injecting a custom AppConfig (e.g. in tests) instead of the
    environment-derived default.
    """
    global _config
    _config = config
# Export commonly used items. The private _config holder is deliberately
# omitted — consumers go through get_config()/set_config().
__all__ = [
    "ContentDomain",
    "DomainWeights",
    "DOMAIN_PRESETS",
    "ModelConfig",
    "ProcessingConfig",
    "AppConfig",
    "get_config",
    "set_config",
]