ShalomKing committed on
Commit
2c73ba8
·
verified ·
1 Parent(s): bc5110c

Upload utils/model_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +35 -31
utils/model_loader.py CHANGED
@@ -104,54 +104,58 @@ class ModelManager:
104
  )
105
  return self.model_paths["wav2vec"]
106
 
107
- def load_wan_model(self, size="infinitetalk-480", device="cuda"):
108
  """
109
- Load Wan model for inference
110
 
111
  Args:
112
- size: Model size configuration
113
  device: Device to load model on
 
114
 
115
  Returns:
116
- Loaded model
117
  """
118
- if "wan_model" not in self.models:
119
  import wan
120
- from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
121
 
122
  model_path = self.get_wan_model_path()
123
  infinitetalk_path = self.get_infinitetalk_weights_path()
 
124
 
125
- logger.info(f"Loading Wan model from {model_path}...")
126
 
127
- # Initialize model based on InfiniteTalk's approach
128
  task = "infinitetalk-14B"
129
- args_dict = {
130
- "ckpt_dir": model_path,
131
- "infinitetalk_dir": os.path.join(infinitetalk_path, "infinitetalk.safetensors"),
132
- "task": task,
133
- "size": size,
134
- "sample_steps": 40,
135
- "sample_shift": 7 if size == "infinitetalk-480" else 11,
136
- }
137
-
138
- # Create a simple namespace object for args
139
- class Args:
140
- def __init__(self, **kwargs):
141
- self.__dict__.update(kwargs)
142
-
143
- args = Args(**args_dict)
 
 
 
 
 
144
 
145
- # Load model (simplified - actual loading would use wan.load_model())
146
- # This is a placeholder - actual implementation would call the wan library
147
- model = wan.WanModel(args)
148
- model.to(device)
149
- model.eval()
150
 
151
- self.models["wan_model"] = model
152
- logger.info("Wan model loaded successfully")
153
 
154
- return self.models["wan_model"]
155
 
156
  def load_audio_encoder(self, device="cuda"):
157
  """
 
104
  )
105
  return self.model_paths["wav2vec"]
106
 
107
+ def load_wan_model(self, size="infinitetalk-480", device="cuda", offload_model=True):
108
  """
109
+ Load Wan InfiniteTalk pipeline for inference
110
 
111
  Args:
112
+ size: Model size configuration (infinitetalk-480 or infinitetalk-720)
113
  device: Device to load model on
114
+ offload_model: Whether to offload model to CPU between forwards
115
 
116
  Returns:
117
+ Loaded InfiniteTalkPipeline
118
  """
119
+ if "wan_pipeline" not in self.models:
120
  import wan
121
+ from wan.configs import WAN_CONFIGS
122
 
123
  model_path = self.get_wan_model_path()
124
  infinitetalk_path = self.get_infinitetalk_weights_path()
125
+ infinitetalk_weights = os.path.join(infinitetalk_path, "infinitetalk.safetensors")
126
 
127
+ logger.info(f"Loading InfiniteTalk pipeline from {model_path}...")
128
 
129
+ # Get configuration for infinitetalk-14B
130
  task = "infinitetalk-14B"
131
+ cfg = WAN_CONFIGS[task]
132
+
133
+ # Create InfiniteTalk pipeline
134
+ # This matches the initialization in generate_infinitetalk.py
135
+ pipeline = wan.InfiniteTalkPipeline(
136
+ config=cfg,
137
+ checkpoint_dir=model_path,
138
+ quant_dir=None, # No quantization for now
139
+ device_id=device if isinstance(device, int) else 0,
140
+ rank=0, # Single GPU
141
+ t5_fsdp=False,
142
+ dit_fsdp=False,
143
+ use_usp=False,
144
+ t5_cpu=False,
145
+ lora_dir=None,
146
+ lora_scales=None,
147
+ quant=None,
148
+ dit_path=None,
149
+ infinitetalk_dir=infinitetalk_weights
150
+ )
151
 
152
+ # Enable memory management for low VRAM if needed
153
+ # pipeline.enable_vram_management(num_persistent_param_in_dit=0)
 
 
 
154
 
155
+ self.models["wan_pipeline"] = pipeline
156
+ logger.info("InfiniteTalk pipeline loaded successfully")
157
 
158
+ return self.models["wan_pipeline"]
159
 
160
  def load_audio_encoder(self, device="cuda"):
161
  """