JackIsNotInTheBox commited on
Commit
e9d7eba
·
1 Parent(s): c5a6704

Fix torch.load: add weights_only=False for PyTorch 2.6 compat

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -49,7 +49,7 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
49
  from samplers import euler_sampler, euler_maruyama_sampler
50
  from diffusers import AudioLDM2Pipeline
51
  extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
52
- state_dict = torch.load(onset_ckpt_path, map_location=device)["state_dict"]
53
  new_state_dict = {}
54
  for key, value in state_dict.items():
55
  if "model.net.model" in key:
@@ -63,7 +63,7 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
63
  onset_model.load_state_dict(new_state_dict)
64
  onset_model.eval()
65
  model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
66
- ckpt = torch.load(taro_ckpt_path, map_location=device)["ema"]
67
  model.load_state_dict(ckpt)
68
  model.eval()
69
  model.to(weight_dtype)
 
49
  from samplers import euler_sampler, euler_maruyama_sampler
50
  from diffusers import AudioLDM2Pipeline
51
  extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
52
+ state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
53
  new_state_dict = {}
54
  for key, value in state_dict.items():
55
  if "model.net.model" in key:
 
63
  onset_model.load_state_dict(new_state_dict)
64
  onset_model.eval()
65
  model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
66
+ ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
67
  model.load_state_dict(ckpt)
68
  model.eval()
69
  model.to(weight_dtype)