""" Model Manager for InfiniteTalk Handles lazy loading and caching of models from HuggingFace Hub """ import os import torch from huggingface_hub import snapshot_download from pathlib import Path import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class ModelManager: """Manages model loading and caching""" def __init__(self, cache_dir=None): """ Initialize Model Manager Args: cache_dir: Directory for caching models. Defaults to HF_HOME or /data/.huggingface """ if cache_dir is None: cache_dir = os.environ.get("HF_HOME", "/data/.huggingface") self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.models = {} self.model_paths = { "wan": None, "infinitetalk": None, "wav2vec": None } def download_model(self, repo_id, subfolder=None, filename=None): """ Download model from HuggingFace Hub with caching Args: repo_id: HuggingFace repository ID (e.g., "Kijai/WanVideo_comfy") subfolder: Optional subfolder within the repository filename: Optional specific file to download Returns: Path to downloaded model directory """ try: logger.info(f"Downloading {repo_id} from HuggingFace Hub...") download_kwargs = { "repo_id": repo_id, "cache_dir": str(self.cache_dir), "resume_download": True, } if subfolder: download_kwargs["allow_patterns"] = f"{subfolder}/*" if filename: download_kwargs["allow_patterns"] = filename model_path = snapshot_download(**download_kwargs) if subfolder: model_path = os.path.join(model_path, subfolder) logger.info(f"Model downloaded successfully to {model_path}") return model_path except Exception as e: logger.error(f"Error downloading model {repo_id}: {e}") raise def get_wan_model_path(self): """Get or download Wan2.1 I2V model""" if self.model_paths["wan"] is None: logger.info("Downloading Wan2.1-I2V-14B-480P model...") # This will download the full model - adjust repo_id based on actual HF location self.model_paths["wan"] = self.download_model( repo_id="Kijai/WanVideo_comfy", subfolder="wan2_1_i2v_14B_480P" ) return self.model_paths["wan"] def get_infinitetalk_weights_path(self): """Get or download InfiniteTalk weights""" if self.model_paths["infinitetalk"] is None: logger.info("Downloading InfiniteTalk weights...") self.model_paths["infinitetalk"] = self.download_model( repo_id="MeiGen-AI/InfiniteTalk", subfolder="single" ) return self.model_paths["infinitetalk"] def get_wav2vec_model_path(self): """Get or download Wav2Vec2 audio encoder""" if self.model_paths["wav2vec"] is None: logger.info("Downloading Wav2Vec2 audio encoder...") self.model_paths["wav2vec"] = self.download_model( repo_id="TencentGameMate/chinese-wav2vec2-base" ) return self.model_paths["wav2vec"] def load_wan_model(self, size="infinitetalk-480", device="cuda", offload_model=True): """ Load Wan InfiniteTalk pipeline for inference Args: size: Model size configuration (infinitetalk-480 or infinitetalk-720) device: Device to load model on offload_model: Whether to offload model to CPU between forwards Returns: Loaded InfiniteTalkPipeline """ if "wan_pipeline" not in self.models: import wan from wan.configs import WAN_CONFIGS model_path = self.get_wan_model_path() infinitetalk_path = self.get_infinitetalk_weights_path() infinitetalk_weights = os.path.join(infinitetalk_path, "infinitetalk.safetensors") logger.info(f"Loading InfiniteTalk pipeline from {model_path}...") # Get configuration for infinitetalk-14B task = "infinitetalk-14B" cfg = WAN_CONFIGS[task] # Create InfiniteTalk pipeline # This matches the initialization in generate_infinitetalk.py pipeline = wan.InfiniteTalkPipeline( config=cfg, checkpoint_dir=model_path, quant_dir=None, # No quantization for now device_id=device if isinstance(device, int) else 0, rank=0, # Single GPU t5_fsdp=False, dit_fsdp=False, use_usp=False, t5_cpu=False, lora_dir=None, lora_scales=None, quant=None, dit_path=None, infinitetalk_dir=infinitetalk_weights ) # Enable memory management for low VRAM if needed # pipeline.enable_vram_management(num_persistent_param_in_dit=0) self.models["wan_pipeline"] = pipeline logger.info("InfiniteTalk pipeline loaded successfully") return self.models["wan_pipeline"] def load_audio_encoder(self, device="cuda"): """ Load Wav2Vec2 audio encoder Args: device: Device to load model on Returns: Audio encoder model and feature extractor """ if "audio_encoder" not in self.models: from transformers import Wav2Vec2FeatureExtractor from src.audio_analysis.wav2vec2 import Wav2Vec2Model wav2vec_path = self.get_wav2vec_model_path() logger.info(f"Loading audio encoder from {wav2vec_path}...") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path) audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_path) audio_encoder.to(device) audio_encoder.eval() self.models["audio_encoder"] = (audio_encoder, feature_extractor) logger.info("Audio encoder loaded successfully") return self.models["audio_encoder"] def unload_model(self, model_name): """Unload a specific model to free memory""" if model_name in self.models: del self.models[model_name] torch.cuda.empty_cache() logger.info(f"Unloaded {model_name}") def clear_all(self): """Unload all models""" self.models.clear() torch.cuda.empty_cache() logger.info("All models unloaded")