fix load gpu
Browse files
app.py
CHANGED
|
@@ -168,31 +168,11 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
|
|
| 168 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 169 |
|
| 170 |
|
| 171 |
-
# ---------------------------------------------------------
|
| 172 |
-
# Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
|
| 173 |
-
# 带 fallback:避免 404
|
| 174 |
-
# ---------------------------------------------------------
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def amp_autocast(device):
|
| 178 |
-
import torch
|
| 179 |
-
|
| 180 |
-
if not USE_AMP:
|
| 181 |
-
return torch.autocast("cuda", enabled=False)
|
| 182 |
-
|
| 183 |
-
if device.type != "cuda":
|
| 184 |
-
return torch.autocast("cpu", enabled=False)
|
| 185 |
-
|
| 186 |
-
dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
|
| 187 |
-
return torch.autocast("cuda", dtype=dtype, enabled=True)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
# -----------------------------
|
| 192 |
# ZeroGPU 核心任务
|
| 193 |
# -----------------------------
|
| 194 |
# 学长说的就是这里:所有费资源的操作(加载+推理)都要放在这里面
|
| 195 |
-
@spaces.GPU
|
| 196 |
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
|
| 197 |
# 延迟导入,防止全局污染
|
| 198 |
import torch
|
|
@@ -338,9 +318,10 @@ def build_demo():
|
|
| 338 |
|
| 339 |
if __name__ == "__main__":
|
| 340 |
demo = build_demo()
|
| 341 |
-
|
| 342 |
demo.queue().launch(
|
| 343 |
-
server_name="0.0.0.0",
|
| 344 |
-
server_port=
|
| 345 |
-
|
|
|
|
| 346 |
)
|
|
|
|
| 168 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 169 |
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
# -----------------------------
|
| 172 |
# ZeroGPU 核心任务
|
| 173 |
# -----------------------------
|
| 174 |
# 学长说的就是这里:所有费资源的操作(加载+推理)都要放在这里面
|
| 175 |
+
@spaces.GPU
|
| 176 |
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
|
| 177 |
# 延迟导入,防止全局污染
|
| 178 |
import torch
|
|
|
|
| 318 |
|
| 319 |
if __name__ == "__main__":
|
| 320 |
demo = build_demo()
|
| 321 |
+
port = int(os.environ.get("PORT", "7860"))
|
| 322 |
demo.queue().launch(
|
| 323 |
+
server_name="0.0.0.0",
|
| 324 |
+
server_port=port,
|
| 325 |
+
share=False,
|
| 326 |
+
ssr_mode=False,
|
| 327 |
)
|