Muhammadidrees commited on
Commit
29531a5
·
verified ·
1 Parent(s): 40745cf

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +11 -16
  2. infer.py +14 -61
app.py CHANGED
@@ -14,21 +14,6 @@ except Exception as e:
14
  print(f"GPU warmup failed: {e}")
15
  os.environ["GRADIO_TEMP_DIR"] = "./tmp"
16
 
17
- try:
18
- # You can define minimal dummy args here for initialization
19
- class DummyArgs:
20
- wan_model_dir = "./models/Wan2.1-I2V-14B-720P"
21
- fantasytalking_model_path = "./models/fantasytalking_model.ckpt"
22
- wav2vec_model_dir = "./models/wav2vec2-base-960h"
23
-
24
- args = DummyArgs()
25
- pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
26
- print("✅ Models loaded successfully.")
27
- except Exception as e:
28
- print(f"❌ Error loading models: {e}")
29
- pipe = fantasytalking = wav2vec_processor = wav2vec = None
30
- raise e # fail fast on startup if models can't load
31
-
32
  pipe,fantasytalking,wav2vec_processor,wav2vec = None,None,None,None
33
  @spaces.GPU(duration=1200)
34
  def generate_video(
@@ -68,7 +53,17 @@ def generate_video(
68
  seed=seed,
69
  )
70
 
71
-
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  def create_args(
 
14
  print(f"GPU warmup failed: {e}")
15
  os.environ["GRADIO_TEMP_DIR"] = "./tmp"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  pipe,fantasytalking,wav2vec_processor,wav2vec = None,None,None,None
18
  @spaces.GPU(duration=1200)
19
  def generate_video(
 
53
  seed=seed,
54
  )
55
 
56
+ try:
57
+ global pipe, fantasytalking, wav2vec_processor, wav2vec
58
+ if pipe is None:
59
+ pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args)
60
+ output_path=main(
61
+ args,pipe,fantasytalking,wav2vec_processor,wav2vec
62
+ )
63
+ return output_path # Ensure the output path is returned
64
+ except Exception as e:
65
+ print(f"Error during processing: {str(e)}")
66
+ raise gr.Error(f"Error during processing: {str(e)}")
67
 
68
 
69
  def create_args(
infer.py CHANGED
@@ -122,30 +122,16 @@ def parse_args():
122
  args = parser.parse_args()
123
  return args
124
 
125
- import torch
126
- from huggingface_hub import snapshot_download
127
- from diffusers import WanVideoPipeline
128
- from transformers import Wav2Vec2Processor, Wav2Vec2Model
129
- from models import FantasyTalkingAudioConditionModel # adjust import if needed
130
- from model_manager import ModelManager # assuming this exists in your repo
131
-
132
-
133
  def load_models(args):
134
- print("🚀 [Startup] Initializing all models (compile-time preloading)...")
135
 
136
- # --------------------------------------------
137
- # STEP 1 β€” Ensure all model files are cached
138
- # --------------------------------------------
139
  snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P")
140
  snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h")
141
  snapshot_download("acvlab/FantasyTalking", local_dir="./models")
142
 
143
- # --------------------------------------------
144
- # STEP 2 β€” Initialize ModelManager (core loader)
145
- # --------------------------------------------
146
- print("🔧 Loading Wan I2V model checkpoints via ModelManager...")
147
- model_manager = ModelManager(device="cuda")
148
 
 
 
149
  model_manager.load_models(
150
  [
151
  [
@@ -161,56 +147,23 @@ def load_models(args):
161
  f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
162
  f"{args.wan_model_dir}/Wan2.1_VAE.pth",
163
  ],
164
- torch_dtype=torch.bfloat16,
165
- )
166
-
167
- pipe = WanVideoPipeline.from_model_manager(
168
- model_manager, torch_dtype=torch.bfloat16, device="cuda"
169
  )
 
170
 
171
- pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
172
- pipe.to("cuda")
173
- pipe.eval()
174
-
175
- # --------------------------------------------
176
- # STEP 3 β€” Load FantasyTalking model
177
- # --------------------------------------------
178
- print("🧠 Loading FantasyTalking model...")
179
- fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda").eval()
180
  fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
 
 
 
181
 
182
- # --------------------------------------------
183
- # STEP 4 β€” Load Wav2Vec2 model + processor
184
- # --------------------------------------------
185
- print("🎙️ Loading Wav2Vec2 model...")
186
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
187
- wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda").eval()
188
-
189
- # --------------------------------------------
190
- # STEP 5 β€” FORCE preload (compile-time warmup)
191
- # --------------------------------------------
192
- print("🔥 Preloading all models into GPU memory (forcing weight instantiation)...")
193
- with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
194
- # Wav2Vec2 warmup
195
- dummy_audio = torch.zeros(1, 16000).to("cuda")
196
- _ = wav2vec(dummy_audio)
197
-
198
- # Diffusion UNet warmup
199
- dummy_latent = torch.randn(1, pipe.unet.in_channels, 64, 64, device="cuda", dtype=torch.bfloat16)
200
- _ = pipe.unet(dummy_latent, 0.5)
201
-
202
- # FantasyTalking warmup
203
- try:
204
- dummy_feat = torch.randn(1, 256).to("cuda")
205
- _ = fantasytalking(dummy_feat)
206
- except Exception as e:
207
- print(f"⚠️ FantasyTalking warmup skipped: {e}")
208
-
209
- torch.cuda.synchronize()
210
- print("✅ [Ready] All models fully loaded and warmed up in GPU memory.")
211
-
212
- return pipe, fantasytalking, wav2vec_processor, wav2vec
213
 
 
214
 
215
 
216
 
 
122
  args = parser.parse_args()
123
  return args
124
 
 
 
 
 
 
 
 
 
125
  def load_models(args):
126
+ # Load Wan I2V models
127
 
 
 
 
128
  snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P")
129
  snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h")
130
  snapshot_download("acvlab/FantasyTalking", local_dir="./models")
131
 
 
 
 
 
 
132
 
133
+ model_manager = ModelManager(device="cpu")
134
+
135
  model_manager.load_models(
136
  [
137
  [
 
147
  f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
148
  f"{args.wan_model_dir}/Wan2.1_VAE.pth",
149
  ],
150
+ # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
151
+ torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
 
 
 
152
  )
153
+ pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
154
 
155
+ # Load FantasyTalking weights
156
+ fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
 
 
 
 
 
 
 
157
  fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
158
+
159
+ # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
160
+ pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
161
 
162
+ # Load wav2vec models
 
 
 
163
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
164
+ wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ return pipe,fantasytalking,wav2vec_processor,wav2vec
167
 
168
 
169