CocoBro committed on
Commit
caea508
·
1 Parent(s): 77f1338

fix load gpu

Browse files
Files changed (1) hide show
  1. app.py +6 -25
app.py CHANGED
@@ -168,31 +168,11 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
168
  exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
169
 
170
 
171
- # ---------------------------------------------------------
172
- # Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
173
- # 带 fallback:避免 404
174
- # ---------------------------------------------------------
175
-
176
-
177
- def amp_autocast(device):
178
- import torch
179
-
180
- if not USE_AMP:
181
- return torch.autocast("cuda", enabled=False)
182
-
183
- if device.type != "cuda":
184
- return torch.autocast("cpu", enabled=False)
185
-
186
- dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
187
- return torch.autocast("cuda", dtype=dtype, enabled=True)
188
-
189
-
190
-
191
  # -----------------------------
192
  # ZeroGPU 核心任务
193
  # -----------------------------
194
  # 学长说的就是这里:所有费资源的操作(加载+推理)都要放在这里面
195
- @spaces.GPU(duration=150)
196
  def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
197
  # 延迟导入,防止全局污染
198
  import torch
@@ -338,9 +318,10 @@ def build_demo():
338
 
339
  if __name__ == "__main__":
340
  demo = build_demo()
341
- # 必须 ssr_mode=False
342
  demo.queue().launch(
343
- server_name="0.0.0.0",
344
- server_port=int(os.environ.get("PORT", 7860)),
345
- ssr_mode=False
 
346
  )
 
168
  exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
169
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # -----------------------------
172
  # ZeroGPU 核心任务
173
  # -----------------------------
174
  # 学长说的就是这里:所有费资源的操作(加载+推理)都要放在这里面
175
+ @spaces.GPU
176
  def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
177
  # 延迟导入,防止全局污染
178
  import torch
 
318
 
319
  if __name__ == "__main__":
320
  demo = build_demo()
321
+ port = int(os.environ.get("PORT", "7860"))
322
  demo.queue().launch(
323
+ server_name="0.0.0.0",
324
+ server_port=port,
325
+ share=False,
326
+ ssr_mode=False,
327
  )