megalado commited on
Commit
d345528
·
1 Parent(s): 41660e5

use correct CLI flags (--text_prompt, --cuda -1, --output_dir)

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -14,37 +14,40 @@ DEVICE = "cpu" # free HF Spaces have no GPU
14
 
15
  def generate_motion(prompt: str) -> str:
16
  """
17
- Runs the MDM sampling script in a subprocess and returns the BVH
18
- file path so Gradio can hand it to the user.
19
  """
20
- out_file = Path("/tmp") / f"{uuid.uuid4().hex}.bvh"
 
 
21
 
22
  cmd = [
23
  "python",
24
  "-m",
25
  "motion_diffusion_model.sample.generate",
26
  "--model_path", str(CKPT_PATH),
27
- "--prompt", prompt,
28
- "--output", str(out_file),
29
- "--device", DEVICE,
30
- "--num_steps", "50", # matches the checkpoint
31
  ]
32
 
33
- # --- make sure the local repo root is on PYTHONPATH so
34
- # 'utils.*' imports inside the script can be resolved
35
  env = os.environ.copy()
36
  root = Path(__file__).parent
37
  repo_inner = root / "motion_diffusion_model"
38
- env["PYTHONPATH"] = (
39
- f"{env.get('PYTHONPATH', '')}:{root}:{repo_inner}"
40
- )
41
-
42
 
43
  completed = subprocess.run(cmd, env=env, capture_output=True, text=True)
44
  if completed.returncode != 0:
45
  raise RuntimeError(f"Inference failed:\n{completed.stderr}")
46
 
47
- return str(out_file)
 
 
 
 
 
 
48
 
49
 
50
  # ----------------------- Gradio UI ----------------------------------
 
14
 
15
  def generate_motion(prompt: str) -> str:
16
  """
17
+ Calls the MDM sampling script and returns the generated BVH path.
 
18
  """
19
+ # create a unique temp directory for this run
20
+ out_dir = Path("/tmp") / f"mdm_{uuid.uuid4().hex}"
21
+ out_dir.mkdir(parents=True, exist_ok=True)
22
 
23
  cmd = [
24
  "python",
25
  "-m",
26
  "motion_diffusion_model.sample.generate",
27
  "--model_path", str(CKPT_PATH),
28
+ "--text_prompt", prompt,
29
+ "--cuda", "-1", # -1 = cpu
30
+ "--num_samples", "1",
31
+ "--output_dir", str(out_dir),
32
  ]
33
 
34
+ # add both repo roots to PYTHONPATH
 
35
  env = os.environ.copy()
36
  root = Path(__file__).parent
37
  repo_inner = root / "motion_diffusion_model"
38
+ env["PYTHONPATH"] = f"{env.get('PYTHONPATH', '')}:{root}:{repo_inner}"
 
 
 
39
 
40
  completed = subprocess.run(cmd, env=env, capture_output=True, text=True)
41
  if completed.returncode != 0:
42
  raise RuntimeError(f"Inference failed:\n{completed.stderr}")
43
 
44
+ # find the .bvh the script just wrote
45
+ bvh_files = list(out_dir.rglob("*.bvh"))
46
+ if not bvh_files:
47
+ raise RuntimeError("No BVH file produced.")
48
+ return str(bvh_files[0])
49
+
50
+
51
 
52
 
53
  # ----------------------- Gradio UI ----------------------------------