megalado commited on
Commit
4496399
·
1 Parent(s): 667d24f

Wire up real model inference in app.py

Browse files
Files changed (2) hide show
  1. app.py +38 -6
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,14 +1,46 @@
 
 
 
 
1
  import gradio as gr
2
 
3
- def generate_motion(prompt):
4
- return f"This will generate motion for: '{prompt}'"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  demo = gr.Interface(
7
  fn=generate_motion,
8
  inputs=gr.Textbox(label="Enter a text prompt"),
9
- outputs="text",
10
- title="Motion Diffusion Model Demo",
11
- description="Generate 3D human motion from text."
12
  )
13
 
14
- demo.launch()
 
 
1
+ import os
2
+ import uuid
3
+ import subprocess
4
+ import shlex
5
  import gradio as gr
6
 
7
+ MODEL_DIR = "motion-diffusion-model" # where you copied the repo
8
+ OUTPUT_DIR = "outputs"
9
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
10
+
11
+ def generate_motion(prompt: str):
12
+ # Unique filename
13
+ run_id = uuid.uuid4().hex[:8]
14
+ out_file = f"{run_id}.mp4"
15
+ out_path = os.path.join(OUTPUT_DIR, out_file)
16
+
17
+ # Command-line call to the sample script
18
+ # Adjust flags if needed; this is the typical pattern
19
+ cmd = f"python {MODEL_DIR}/sample/predict.py " \
20
+ f"--prompt \"{prompt}\" " \
21
+ f"--results_dir {OUTPUT_DIR} " \
22
+ f"--num_samples 1"
23
+
24
+ # Run the model
25
+ subprocess.run(shlex.split(cmd), check=True)
26
+
27
+ # The script will dump something like outputs/sample_0.mp4
28
+ # Rename it to our unique filename
29
+ default_out = os.path.join(OUTPUT_DIR, "sample_0.mp4")
30
+ if os.path.exists(default_out):
31
+ os.replace(default_out, out_path)
32
+ else:
33
+ raise FileNotFoundError(default_out)
34
+
35
+ return out_path
36
 
37
  demo = gr.Interface(
38
  fn=generate_motion,
39
  inputs=gr.Textbox(label="Enter a text prompt"),
40
+ outputs=gr.Video(label="Generated Motion"),
41
+ title="Motion Diffusion Model",
42
+ description="Type some text and watch the 3D human motion!"
43
  )
44
 
45
+ if __name__ == "__main__":
46
+ demo.launch()
requirements.txt CHANGED
@@ -1 +1,9 @@
1
  gradio
 
 
 
 
 
 
 
 
 
1
  gradio
2
+ torch
3
+ transformers
4
+ omegaconf
5
+ einops
6
+ scipy
7
+ numpy
8
+ opencv-python
9
+ imageio