# config.py — ShortSmith v2 (dev_caio)
"""
ShortSmith v2 - Configuration Module
Centralized configuration for all components including model paths,
thresholds, domain presets, and runtime settings.
"""
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
from enum import Enum
class ContentDomain(Enum):
    """Supported content domains with different hype characteristics.

    Each member serves as a key into ``DOMAIN_PRESETS`` to select the
    visual/audio/motion scoring weights appropriate for that kind of video.
    """
    SPORTS = "sports"
    VLOGS = "vlogs"
    MUSIC = "music"
    PODCASTS = "podcasts"
    GAMING = "gaming"
    # Fallback domain when the content type is unknown or mixed.
    GENERAL = "general"
@dataclass
class DomainWeights:
    """Relative importance of visual, audio, and motion signals for a domain.

    The three weights are rescaled on construction so they sum to 1.0.
    If all weights are zero (or negative totals), they are left untouched.
    """
    visual_weight: float
    audio_weight: float
    motion_weight: float = 0.0

    def __post_init__(self):
        """Rescale the weights in place so that they sum to 1.0."""
        total = sum((self.visual_weight, self.audio_weight, self.motion_weight))
        # Guard against division by zero when every weight is 0.
        if total > 0:
            self.visual_weight = self.visual_weight / total
            self.audio_weight = self.audio_weight / total
            self.motion_weight = self.motion_weight / total
# Domain-specific weight presets.
# Weights express the relative emphasis of each signal and are normalized
# to sum to 1.0 by DomainWeights.__post_init__.
DOMAIN_PRESETS: Dict[ContentDomain, DomainWeights] = {
    # Sports: audio-heavy, with the largest motion share of any preset.
    ContentDomain.SPORTS: DomainWeights(visual_weight=0.35, audio_weight=0.50, motion_weight=0.15),
    # Vlogs: predominantly visual.
    ContentDomain.VLOGS: DomainWeights(visual_weight=0.70, audio_weight=0.20, motion_weight=0.10),
    # Music: audio slightly outweighs visual.
    ContentDomain.MUSIC: DomainWeights(visual_weight=0.40, audio_weight=0.50, motion_weight=0.10),
    # Podcasts: almost entirely audio-driven.
    ContentDomain.PODCASTS: DomainWeights(visual_weight=0.10, audio_weight=0.85, motion_weight=0.05),
    # Gaming: visual-leaning with a notable motion component.
    ContentDomain.GAMING: DomainWeights(visual_weight=0.50, audio_weight=0.35, motion_weight=0.15),
    # General: balanced default for unknown/mixed content.
    ContentDomain.GENERAL: DomainWeights(visual_weight=0.50, audio_weight=0.40, motion_weight=0.10),
}
@dataclass
class ModelConfig:
    """Configuration for AI models.

    Holds model identifiers and thresholds for each perception component,
    plus the compute device. After construction the device is downgraded
    to "cpu" when the requested accelerator is not available.
    """
    # Visual model (Qwen2-VL)
    visual_model_id: str = "Qwen/Qwen2-VL-2B-Instruct"
    visual_model_quantization: str = "int4"  # Options: "int4", "int8", "none"
    visual_max_frames: int = 32
    # Audio model
    audio_model_id: str = "facebook/wav2vec2-base-960h"
    use_advanced_audio: bool = False  # Use Wav2Vec2 instead of just Librosa
    # Face recognition (InsightFace)
    face_detection_model: str = "buffalo_l"  # SCRFD model
    face_similarity_threshold: float = 0.4
    # Body recognition (OSNet)
    body_model_name: str = "osnet_x1_0"
    body_similarity_threshold: float = 0.5
    # Motion detection (RAFT)
    motion_model: str = "raft-things"
    motion_threshold: float = 5.0
    # Device settings
    device: str = "cuda"  # Options: "cuda", "cpu", "mps"

    def __post_init__(self):
        """Fall back to CPU when the requested accelerator is unavailable."""
        # Imported lazily so merely importing this module stays cheap.
        import torch

        if self.device == "cuda":
            if not torch.cuda.is_available():
                self.device = "cpu"
        elif self.device == "mps":
            mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
            if not mps_ready:
                self.device = "cpu"
@dataclass
class ProcessingConfig:
    """Configuration for the video processing pipeline.

    Groups tunables for frame sampling, clip selection, output encoding,
    scene detection, hype scoring, and runtime performance limits.
    All durations and intervals are in seconds unless noted otherwise.
    """
    # Sampling settings
    coarse_sample_interval: float = 5.0  # Seconds between frames in first pass
    dense_sample_fps: float = 3.0  # FPS for dense sampling on candidates
    min_motion_for_dense: float = 2.0  # Threshold to trigger dense sampling
    # Clip settings
    min_clip_duration: float = 10.0  # Minimum clip length in seconds
    max_clip_duration: float = 20.0  # Maximum clip length in seconds
    default_clip_duration: float = 15.0  # Default clip length
    min_gap_between_clips: float = 30.0  # Minimum gap between clip starts
    # Output settings
    default_num_clips: int = 3
    max_num_clips: int = 10
    output_format: str = "mp4"
    output_codec: str = "libx264"  # Video codec passed to the encoder
    output_audio_codec: str = "aac"
    # Scene detection
    scene_threshold: float = 27.0  # PySceneDetect threshold
    # Hype scoring
    hype_threshold: float = 0.3  # Minimum normalized score to consider
    diversity_weight: float = 0.2  # Weight for temporal diversity in ranking
    # Performance
    batch_size: int = 8  # Frames per batch for model inference
    max_video_duration: float = 3600.0  # Maximum video length (1 hour)
    # Temporary files
    temp_dir: Optional[str] = None  # None = use the system default temp dir
    cleanup_temp: bool = True  # Remove temp artifacts when processing ends
@dataclass
class AppConfig:
    """Main application configuration.

    Aggregates model and processing settings alongside logging, API,
    and UI options. Use :meth:`from_env` to build an instance with
    environment-variable overrides applied.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    processing: ProcessingConfig = field(default_factory=ProcessingConfig)
    # Logging
    log_level: str = "INFO"
    log_file: Optional[str] = "shortsmith.log"
    log_to_console: bool = True
    # API settings (for future extensibility)
    api_key: Optional[str] = None
    # UI settings
    share_gradio: bool = False
    server_port: int = 7860

    @classmethod
    def from_env(cls) -> "AppConfig":
        """Create a configuration with overrides from environment variables.

        Recognized variables: SHORTSMITH_LOG_LEVEL, SHORTSMITH_DEVICE,
        SHORTSMITH_API_KEY. Empty values are treated as unset.

        Returns:
            AppConfig: a fresh instance with any overrides applied.
        """
        config = cls()
        # Single lookup per variable; falsy (empty) values leave defaults.
        log_level = os.environ.get("SHORTSMITH_LOG_LEVEL")
        if log_level:
            config.log_level = log_level
        device = os.environ.get("SHORTSMITH_DEVICE")
        if device:
            config.model.device = device
        api_key = os.environ.get("SHORTSMITH_API_KEY")
        if api_key:
            config.api_key = api_key
        # NOTE: HF_TOKEN (HuggingFace token for gated models) is deliberately
        # not copied here — the huggingface_hub library reads it from the
        # environment on its own.
        return config
# Global configuration instance, lazily created by get_config() and
# replaceable via set_config().
_config: Optional[AppConfig] = None
def get_config() -> AppConfig:
    """Return the process-wide configuration, creating it on first use.

    The instance is built from environment variables the first time this
    is called and cached in the module-level ``_config`` thereafter.
    """
    global _config
    if _config is not None:
        return _config
    _config = AppConfig.from_env()
    return _config
def set_config(config: AppConfig) -> None:
    """Install *config* as the process-wide configuration instance.

    Subsequent calls to get_config() return this object.
    """
    global _config
    _config = config
# Public API of this module — keep in sync with the definitions above.
__all__ = [
    "ContentDomain",
    "DomainWeights",
    "DOMAIN_PRESETS",
    "ModelConfig",
    "ProcessingConfig",
    "AppConfig",
    "get_config",
    "set_config",
]