Muhammadidrees commited on
Commit
5f538d3
·
verified ·
1 Parent(s): f162467

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +61 -14
infer.py CHANGED
@@ -122,16 +122,30 @@ def parse_args():
122
  args = parser.parse_args()
123
  return args
124
 
 
 
 
 
 
 
 
 
125
  def load_models(args):
126
- # Load Wan I2V models
127
 
 
 
 
128
  snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P")
129
  snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h")
130
  snapshot_download("acvlab/FantasyTalking", local_dir="./models")
131
 
 
 
 
 
 
132
 
133
- model_manager = ModelManager(device="cpu")
134
-
135
  model_manager.load_models(
136
  [
137
  [
@@ -147,23 +161,56 @@ def load_models(args):
147
  f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
148
  f"{args.wan_model_dir}/Wan2.1_VAE.pth",
149
  ],
150
- # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
151
- torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
 
 
 
152
  )
153
- pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
154
 
155
- # Load FantasyTalking weights
156
- fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
157
- fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
158
-
159
- # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
160
  pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
 
 
 
 
 
 
 
 
 
161
 
162
- # Load wav2vec models
 
 
 
163
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
164
- wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- return pipe,fantasytalking,wav2vec_processor,wav2vec
167
 
168
 
169
 
 
122
  args = parser.parse_args()
123
  return args
124
 
125
+ import torch
126
+ from huggingface_hub import snapshot_download
127
+ from diffusers import WanVideoPipeline
128
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
129
+ from models import FantasyTalkingAudioConditionModel # adjust import if needed
130
+ from model_manager import ModelManager # assuming this exists in your repo
131
+
132
+
133
  def load_models(args):
134
+ print("🚀 [Startup] Initializing all models (compile-time preloading)...")
135
 
136
+ # --------------------------------------------
137
+ # STEP 1 — Ensure all model files are cached
138
+ # --------------------------------------------
139
  snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P")
140
  snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h")
141
  snapshot_download("acvlab/FantasyTalking", local_dir="./models")
142
 
143
+ # --------------------------------------------
144
+ # STEP 2 — Initialize ModelManager (core loader)
145
+ # --------------------------------------------
146
+ print("🔧 Loading Wan I2V model checkpoints via ModelManager...")
147
+ model_manager = ModelManager(device="cuda")
148
 
 
 
149
  model_manager.load_models(
150
  [
151
  [
 
161
  f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
162
  f"{args.wan_model_dir}/Wan2.1_VAE.pth",
163
  ],
164
+ torch_dtype=torch.bfloat16,
165
+ )
166
+
167
+ pipe = WanVideoPipeline.from_model_manager(
168
+ model_manager, torch_dtype=torch.bfloat16, device="cuda"
169
  )
 
170
 
 
 
 
 
 
171
  pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
172
+ pipe.to("cuda")
173
+ pipe.eval()
174
+
175
+ # --------------------------------------------
176
+ # STEP 3 — Load FantasyTalking model
177
+ # --------------------------------------------
178
+ print("🧠 Loading FantasyTalking model...")
179
+ fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda").eval()
180
+ fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
181
 
182
+ # --------------------------------------------
183
+ # STEP 4 — Load Wav2Vec2 model + processor
184
+ # --------------------------------------------
185
+ print("🎙️ Loading Wav2Vec2 model...")
186
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
187
+ wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda").eval()
188
+
189
+ # --------------------------------------------
190
+ # STEP 5 — FORCE preload (compile-time warmup)
191
+ # --------------------------------------------
192
+ print("🔥 Preloading all models into GPU memory (forcing weight instantiation)...")
193
+ with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
194
+ # Wav2Vec2 warmup
195
+ dummy_audio = torch.zeros(1, 16000).to("cuda")
196
+ _ = wav2vec(dummy_audio)
197
+
198
+ # Diffusion UNet warmup
199
+ dummy_latent = torch.randn(1, pipe.unet.in_channels, 64, 64, device="cuda", dtype=torch.bfloat16)
200
+ _ = pipe.unet(dummy_latent, 0.5)
201
+
202
+ # FantasyTalking warmup
203
+ try:
204
+ dummy_feat = torch.randn(1, 256).to("cuda")
205
+ _ = fantasytalking(dummy_feat)
206
+ except Exception as e:
207
+ print(f"⚠️ FantasyTalking warmup skipped: {e}")
208
+
209
+ torch.cuda.synchronize()
210
+ print("✅ [Ready] All models fully loaded and warmed up in GPU memory.")
211
+
212
+ return pipe, fantasytalking, wav2vec_processor, wav2vec
213
 
 
214
 
215
 
216