| """ | |
| Model Manager for InfiniteTalk | |
| Handles lazy loading and caching of models from HuggingFace Hub | |
| """ | |
| import os | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from pathlib import Path | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class ModelManager:
    """Manages model loading and caching."""

    def __init__(self, cache_dir=None):
        """
        Initialize the model manager.

        Args:
            cache_dir: Directory for caching models. Defaults to HF_HOME
                or /data/.huggingface.
        """
        if cache_dir is None:
            cache_dir = os.environ.get("HF_HOME", "/data/.huggingface")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.models = {}
        self.model_paths = {
            "wan": None,
            "infinitetalk": None,
            "wav2vec": None,
        }

    def download_model(self, repo_id, subfolder=None, filename=None):
        """
        Download a model from the HuggingFace Hub with caching.

        Args:
            repo_id: HuggingFace repository ID (e.g., "Kijai/WanVideo_comfy")
            subfolder: Optional subfolder within the repository
            filename: Optional specific file to download (takes precedence
                over the subfolder pattern when both are given)

        Returns:
            Path to the downloaded model directory
        """
        try:
            logger.info(f"Downloading {repo_id} from HuggingFace Hub...")
            download_kwargs = {
                "repo_id": repo_id,
                "cache_dir": str(self.cache_dir),
                # resume_download is deprecated (and a no-op) on recent
                # huggingface_hub releases, where downloads always resume.
                "resume_download": True,
            }
            if subfolder:
                download_kwargs["allow_patterns"] = f"{subfolder}/*"
            if filename:
                download_kwargs["allow_patterns"] = filename
            model_path = snapshot_download(**download_kwargs)
            if subfolder:
                model_path = os.path.join(model_path, subfolder)
            logger.info(f"Model downloaded successfully to {model_path}")
            return model_path
        except Exception as e:
            logger.error(f"Error downloading model {repo_id}: {e}")
            raise

    def get_wan_model_path(self):
        """Get or download the Wan2.1 I2V model."""
        if self.model_paths["wan"] is None:
            logger.info("Downloading Wan2.1-I2V-14B-480P model...")
            # This downloads the full model - adjust repo_id based on the
            # actual HF location
            self.model_paths["wan"] = self.download_model(
                repo_id="Kijai/WanVideo_comfy",
                subfolder="wan2_1_i2v_14B_480P"
            )
        return self.model_paths["wan"]

    def get_infinitetalk_weights_path(self):
        """Get or download the InfiniteTalk weights."""
        if self.model_paths["infinitetalk"] is None:
            logger.info("Downloading InfiniteTalk weights...")
            self.model_paths["infinitetalk"] = self.download_model(
                repo_id="MeiGen-AI/InfiniteTalk",
                subfolder="single"
            )
        return self.model_paths["infinitetalk"]

    def get_wav2vec_model_path(self):
        """Get or download the Wav2Vec2 audio encoder."""
        if self.model_paths["wav2vec"] is None:
            logger.info("Downloading Wav2Vec2 audio encoder...")
            self.model_paths["wav2vec"] = self.download_model(
                repo_id="TencentGameMate/chinese-wav2vec2-base"
            )
        return self.model_paths["wav2vec"]

    def load_wan_model(self, size="infinitetalk-480", device="cuda", offload_model=True):
        """
        Load the Wan InfiniteTalk pipeline for inference.

        Args:
            size: Model size configuration ("infinitetalk-480" or
                "infinitetalk-720"). Currently informational only; the task
                config below is fixed to "infinitetalk-14B".
            device: Device to load the model on
            offload_model: Whether to offload the model to CPU between
                forward passes (currently unused here)

        Returns:
            Loaded InfiniteTalkPipeline
        """
        if "wan_pipeline" not in self.models:
            import wan
            from wan.configs import WAN_CONFIGS

            model_path = self.get_wan_model_path()
            infinitetalk_path = self.get_infinitetalk_weights_path()
            infinitetalk_weights = os.path.join(infinitetalk_path, "infinitetalk.safetensors")

            logger.info(f"Loading InfiniteTalk pipeline from {model_path}...")

            # Get the configuration for infinitetalk-14B
            task = "infinitetalk-14B"
            cfg = WAN_CONFIGS[task]

            # Create the InfiniteTalk pipeline; this matches the
            # initialization in generate_infinitetalk.py
            pipeline = wan.InfiniteTalkPipeline(
                config=cfg,
                checkpoint_dir=model_path,
                quant_dir=None,  # no quantization for now
                device_id=device if isinstance(device, int) else 0,
                rank=0,  # single GPU
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
                t5_cpu=False,
                lora_dir=None,
                lora_scales=None,
                quant=None,
                dit_path=None,
                infinitetalk_dir=infinitetalk_weights,
            )
            # Enable memory management for low VRAM if needed:
            # pipeline.enable_vram_management(num_persistent_param_in_dit=0)
            self.models["wan_pipeline"] = pipeline
            logger.info("InfiniteTalk pipeline loaded successfully")
        return self.models["wan_pipeline"]

    def load_audio_encoder(self, device="cuda"):
        """
        Load the Wav2Vec2 audio encoder.

        Args:
            device: Device to load the model on

        Returns:
            Tuple of (audio encoder model, feature extractor)
        """
        if "audio_encoder" not in self.models:
            from transformers import Wav2Vec2FeatureExtractor
            from src.audio_analysis.wav2vec2 import Wav2Vec2Model

            wav2vec_path = self.get_wav2vec_model_path()
            logger.info(f"Loading audio encoder from {wav2vec_path}...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
            audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_path)
            audio_encoder.to(device)
            audio_encoder.eval()
            self.models["audio_encoder"] = (audio_encoder, feature_extractor)
            logger.info("Audio encoder loaded successfully")
        return self.models["audio_encoder"]

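    # A minimal usage sketch for the returned (encoder, extractor) pair,
    # assuming `audio` is a 1-D array of 16 kHz mono samples. The
    # project-local Wav2Vec2Model may post-process hidden states differently
    # from the stock transformers model, so this shows only the standard
    # HuggingFace-style call:
    #
    #   encoder, extractor = manager.load_audio_encoder(device="cuda")
    #   inputs = extractor(audio, sampling_rate=16000, return_tensors="pt")
    #   with torch.no_grad():
    #       hidden = encoder(inputs.input_values.to("cuda")).last_hidden_state
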
    def unload_model(self, model_name):
        """Unload a specific model to free memory."""
        if model_name in self.models:
            del self.models[model_name]
            torch.cuda.empty_cache()
            logger.info(f"Unloaded {model_name}")

    def clear_all(self):
        """Unload all models."""
        self.models.clear()
        torch.cuda.empty_cache()
        logger.info("All models unloaded")