Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -232,13 +232,15 @@ if args.remove_pretransform_weight_norm == "post_load":
|
|
| 232 |
ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
|
| 233 |
training_wrapper = create_training_wrapper_from_config(model_config, model)
|
| 234 |
# 加载模型权重时根据设备选择map_location
|
| 235 |
-
training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
|
|
|
|
|
|
|
| 236 |
|
| 237 |
def get_video_duration(video_path):
|
| 238 |
video = VideoFileClip(video_path)
|
| 239 |
return video.duration
|
| 240 |
|
| 241 |
-
@spaces.GPU(duration=
|
| 242 |
@torch.inference_mode()
|
| 243 |
@torch.no_grad()
|
| 244 |
def get_audio(video_path, caption):
|
|
|
|
| 232 |
ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
|
| 233 |
training_wrapper = create_training_wrapper_from_config(model_config, model)
|
| 234 |
# 加载模型权重时根据设备选择map_location
|
| 235 |
+
training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
|
| 236 |
+
|
| 237 |
+
training_wrapper.to("cuda")
|
| 238 |
|
| 239 |
def get_video_duration(video_path):
|
| 240 |
video = VideoFileClip(video_path)
|
| 241 |
return video.duration
|
| 242 |
|
| 243 |
+
@spaces.GPU(duration=60)
|
| 244 |
@torch.inference_mode()
|
| 245 |
@torch.no_grad()
|
| 246 |
def get_audio(video_path, caption):
|