BoxOfColors committed on
Commit
8635f79
·
1 Parent(s): 3c63946

Add granular step-by-step logging in _taro_gpu_infer to find exact GPU abort point

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -889,17 +889,23 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
889
  total_dur_s = ctx["total_dur_s"]
890
  print(f"[_taro_gpu_infer] tmp_dir={tmp_dir!r} silent_video={silent_video!r} segments={segments} total_dur_s={total_dur_s}")
891
 
 
892
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
 
893
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
 
894
  # Onset features depend only on the video — extract once for all samples
895
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
 
896
 
897
  # Free feature extractors before loading the heavier inference models
898
  del extract_cavp, onset_model
899
  if torch.cuda.is_available():
900
  torch.cuda.empty_cache()
901
 
 
902
  model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
 
903
 
904
  results = [] # list of (wavs, onset_feats) per sample
905
  for sample_idx in range(num_samples):
 
889
  total_dur_s = ctx["total_dur_s"]
890
  print(f"[_taro_gpu_infer] tmp_dir={tmp_dir!r} silent_video={silent_video!r} segments={segments} total_dur_s={total_dur_s}")
891
 
892
+ print(f"[_taro_gpu_infer] calling _load_taro_feature_extractors")
893
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
894
+ print(f"[_taro_gpu_infer] extractors loaded, calling extract_cavp")
895
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
896
+ print(f"[_taro_gpu_infer] cavp done, calling extract_onset")
897
  # Onset features depend only on the video — extract once for all samples
898
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
899
+ print(f"[_taro_gpu_infer] onset done, freeing extractors")
900
 
901
  # Free feature extractors before loading the heavier inference models
902
  del extract_cavp, onset_model
903
  if torch.cuda.is_available():
904
  torch.cuda.empty_cache()
905
 
906
+ print(f"[_taro_gpu_infer] calling _load_taro_models")
907
  model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
908
+ print(f"[_taro_gpu_infer] models loaded")
909
 
910
  results = [] # list of (wavs, onset_feats) per sample
911
  for sample_idx in range(num_samples):