# infinitetalk/utils/model_loader.py
"""
Model Manager for InfiniteTalk
Handles lazy loading and caching of models from HuggingFace Hub
"""
import os
import torch
from huggingface_hub import snapshot_download
from pathlib import Path
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages model loading and caching."""

    def __init__(self, cache_dir=None):
        """
        Initialize the model manager.

        Args:
            cache_dir: Directory for caching models. Defaults to HF_HOME
                or /data/.huggingface.
        """
        if cache_dir is None:
            cache_dir = os.environ.get("HF_HOME", "/data/.huggingface")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.models = {}
        self.model_paths = {
            "wan": None,
            "infinitetalk": None,
            "wav2vec": None,
        }
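
    # Example (illustrative): point the cache at an explicit volume. The
    # path below is an assumption for the example, not part of this module.
    #   manager = ModelManager(cache_dir="/tmp/hf-cache")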

    def download_model(self, repo_id, subfolder=None, filename=None):
        """
        Download a model from the HuggingFace Hub with caching.

        Args:
            repo_id: HuggingFace repository ID (e.g., "Kijai/WanVideo_comfy").
            subfolder: Optional subfolder within the repository.
            filename: Optional specific file to download.

        Returns:
            Path to the downloaded model directory.
        """
        try:
            logger.info(f"Downloading {repo_id} from HuggingFace Hub...")
            # Note: interrupted downloads resume automatically in current
            # huggingface_hub releases; the resume_download flag is deprecated.
            download_kwargs = {
                "repo_id": repo_id,
                "cache_dir": str(self.cache_dir),
            }
            # Restrict the snapshot to the requested subfolder and/or file;
            # if both are given, scope the filename to the subfolder.
            if filename:
                download_kwargs["allow_patterns"] = (
                    f"{subfolder}/{filename}" if subfolder else filename
                )
            elif subfolder:
                download_kwargs["allow_patterns"] = f"{subfolder}/*"
            model_path = snapshot_download(**download_kwargs)
            if subfolder:
                model_path = os.path.join(model_path, subfolder)
            logger.info(f"Model downloaded successfully to {model_path}")
            return model_path
        except Exception as e:
            logger.error(f"Error downloading model {repo_id}: {e}")
            raise
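
    # Example (illustrative): fetch one specific weights file. The
    # "infinitetalk.safetensors" name is the one load_wan_model() expects.
    #   manager = ModelManager()
    #   path = manager.download_model(
    #       repo_id="MeiGen-AI/InfiniteTalk",
    #       subfolder="single",
    #       filename="infinitetalk.safetensors",
    #   )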

    def get_wan_model_path(self):
        """Get or download the Wan2.1 I2V model."""
        if self.model_paths["wan"] is None:
            logger.info("Downloading Wan2.1-I2V-14B-480P model...")
            # This downloads the full model; adjust repo_id to the actual HF location.
            self.model_paths["wan"] = self.download_model(
                repo_id="Kijai/WanVideo_comfy",
                subfolder="wan2_1_i2v_14B_480P",
            )
        return self.model_paths["wan"]

    def get_infinitetalk_weights_path(self):
        """Get or download the InfiniteTalk weights."""
        if self.model_paths["infinitetalk"] is None:
            logger.info("Downloading InfiniteTalk weights...")
            self.model_paths["infinitetalk"] = self.download_model(
                repo_id="MeiGen-AI/InfiniteTalk",
                subfolder="single",
            )
        return self.model_paths["infinitetalk"]

    def get_wav2vec_model_path(self):
        """Get or download the Wav2Vec2 audio encoder."""
        if self.model_paths["wav2vec"] is None:
            logger.info("Downloading Wav2Vec2 audio encoder...")
            self.model_paths["wav2vec"] = self.download_model(
                repo_id="TencentGameMate/chinese-wav2vec2-base"
            )
        return self.model_paths["wav2vec"]
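
    # Note: the get_*_path() helpers above resolve lazily; the first call
    # downloads the weights, later calls return the cached path.
    #   wav2vec_dir = manager.get_wav2vec_model_path()  # downloads once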

    def load_wan_model(self, size="infinitetalk-480", device="cuda", offload_model=True):
        """
        Load the Wan InfiniteTalk pipeline for inference.

        Args:
            size: Model size configuration (infinitetalk-480 or infinitetalk-720).
            device: Device to load the model on.
            offload_model: Whether to offload the model to CPU between forwards.

        Returns:
            Loaded InfiniteTalkPipeline.
        """
        if "wan_pipeline" not in self.models:
            import wan
            from wan.configs import WAN_CONFIGS

            model_path = self.get_wan_model_path()
            infinitetalk_path = self.get_infinitetalk_weights_path()
            infinitetalk_weights = os.path.join(infinitetalk_path, "infinitetalk.safetensors")
            logger.info(f"Loading InfiniteTalk pipeline from {model_path}...")
            # Get the configuration for infinitetalk-14B
            task = "infinitetalk-14B"
            cfg = WAN_CONFIGS[task]
            # Create the InfiniteTalk pipeline; this matches the
            # initialization in generate_infinitetalk.py.
            pipeline = wan.InfiniteTalkPipeline(
                config=cfg,
                checkpoint_dir=model_path,
                quant_dir=None,  # No quantization for now
                device_id=device if isinstance(device, int) else 0,
                rank=0,  # Single GPU
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
                t5_cpu=False,
                lora_dir=None,
                lora_scales=None,
                quant=None,
                dit_path=None,
                infinitetalk_dir=infinitetalk_weights,
            )
            # Enable memory management for low VRAM if needed:
            # pipeline.enable_vram_management(num_persistent_param_in_dit=0)
            self.models["wan_pipeline"] = pipeline
            logger.info("InfiniteTalk pipeline loaded successfully")
        return self.models["wan_pipeline"]
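
    # Example (illustrative): the pipeline is built once and memoized in
    # self.models, so repeated calls are cheap. Passing an int selects the
    # CUDA device id directly; a string falls back to device 0.
    #   pipeline = manager.load_wan_model(size="infinitetalk-480", device=0)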

    def load_audio_encoder(self, device="cuda"):
        """
        Load the Wav2Vec2 audio encoder.

        Args:
            device: Device to load the model on.

        Returns:
            Tuple of (audio encoder model, feature extractor).
        """
        if "audio_encoder" not in self.models:
            from transformers import Wav2Vec2FeatureExtractor
            from src.audio_analysis.wav2vec2 import Wav2Vec2Model

            wav2vec_path = self.get_wav2vec_model_path()
            logger.info(f"Loading audio encoder from {wav2vec_path}...")
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
            audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_path)
            audio_encoder.to(device)
            audio_encoder.eval()
            self.models["audio_encoder"] = (audio_encoder, feature_extractor)
            logger.info("Audio encoder loaded successfully")
        return self.models["audio_encoder"]
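
    # Example (illustrative): embed 16 kHz mono audio. This assumes the
    # custom Wav2Vec2Model keeps the standard transformers forward signature;
    # `waveform` is a hypothetical 1-D float array of samples.
    #   encoder, extractor = manager.load_audio_encoder(device="cpu")
    #   inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
    #   with torch.no_grad():
    #       features = encoder(inputs.input_values)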

    def unload_model(self, model_name):
        """Unload a specific model to free memory."""
        if model_name in self.models:
            del self.models[model_name]
            # Guard the cache flush so CPU-only environments don't touch CUDA.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Unloaded {model_name}")

    def clear_all(self):
        """Unload all models."""
        self.models.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("All models unloaded")
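

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: network access is available,
    # the src.audio_analysis package is importable, and the first run
    # downloads the Wav2Vec2 weights into the cache directory.
    manager = ModelManager()
    print("Cache dir:", manager.cache_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder, extractor = manager.load_audio_encoder(device=device)
    print("Loaded audio encoder:", type(encoder).__name__)
    manager.clear_all()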