jake committed on
Commit 05fc139 · 1 Parent(s): 1d0c879
Files changed (1)
  1. app.py +680 -369
app.py CHANGED
@@ -1,398 +1,709 @@
  import os
  import sys
  from pathlib import Path
  import spaces

- # === Import project modules ===
  PROJECT_ROOT = Path(__file__).resolve().parent
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
  if str(MMADA_ROOT) not in sys.path:
      sys.path.insert(0, str(MMADA_ROOT))

- from inference.gradio_multimodal_demo_inst import OmadaDemo
- import gradio as gr
-
-
- # ----------------------------------------------------------------------
- # 1. Asset Loading (Downloaded by entrypoint)
- # ----------------------------------------------------------------------

- ASSET_ROOT = PROJECT_ROOT / "_asset_cache" / "AIDAS-Omni-Modal-Diffusion-assets"
- DEMO_ROOT = ASSET_ROOT  # asset repo already modality-split


- # ----------------------------------------------------------------------
- # 2. GPU Handler Wrapper
- # ----------------------------------------------------------------------
-
- def gpu_handler(fn):
      """
-     Wrap an inference function using ZeroGPU.
      """
-     @spaces.GPU
-     def inner(*args, **kwargs):
-         return fn(*args, **kwargs)
-     return inner


- # ----------------------------------------------------------------------
- # 3. Build Demo UI With Examples
- # ----------------------------------------------------------------------

- def build_zero_gpu_demo(app: OmadaDemo):

-     with gr.Blocks(title="AIDAS Omni-Modal Diffusion (ZeroGPU)") as demo:

-         # ---------------- Header ----------------
-         gr.Markdown(
-             "<h1 style='text-align:center'>AIDAS Omni-Modal Diffusion Model</h1>"
          )

-         try:
-             logo_path = "/mnt/data/A2E36E9F-F389-487D-9984-FFF21C9228E3.png"
-             gr.Image(logo_path, elem_id="logo", show_label=False, height=120)
-         except:
-             pass
-
-         gr.Markdown("### Multimodal Inference Demo (ZeroGPU Optimized)")
-         gr.Markdown("---")
-
-         # ---------------- Tabs ----------------
-
-         with gr.Tabs():
-
-             # ============================================================
-             # 1) TEXT → SPEECH (T2S)
-             # ============================================================
-             with gr.Tab("Text → Speech (T2S)"):
-
-                 t2s_in = gr.Textbox(label="Input Text")
-                 t2s_btn = gr.Button("Generate")
-                 t2s_audio = gr.Audio(label="Speech Output")
-                 t2s_status = gr.Textbox(label="Status", interactive=False)
-
-                 t2s_examples = []
-                 t2s_dir = DEMO_ROOT / "t2s"
-                 if t2s_dir.exists():
-                     for f in t2s_dir.glob("*.txt"):
-                         txt = f.read_text().strip()
-                         t2s_examples.append([txt])
-
-                 if len(t2s_examples) > 0:
-                     gr.Examples(
-                         examples=t2s_examples,
-                         inputs=[t2s_in],
-                         outputs=[t2s_audio, t2s_status],
-                         fn=gpu_handler(app.run_t2s),
-                     )
-
-                 t2s_btn.click(
-                     gpu_handler(app.run_t2s),
-                     inputs=[t2s_in],
-                     outputs=[t2s_audio, t2s_status],
-                 )
-
-             # ============================================================
-             # 2) SPEECH → SPEECH (S2S)
-             # ============================================================
-             with gr.Tab("Speech → Speech (S2S)"):
-
-                 s2s_in = gr.Audio(type="filepath", label="Input Speech")
-                 s2s_btn = gr.Button("Generate")
-                 s2s_audio = gr.Audio(label="Output Speech")
-                 s2s_status = gr.Textbox(label="Status", interactive=False)
-
-                 s2s_examples = []
-                 s2s_dir = DEMO_ROOT / "s2s"
-                 if s2s_dir.exists():
-                     for f in s2s_dir.glob("*.wav"):
-                         s2s_examples.append([str(f)])
-
-                 if len(s2s_examples) > 0:
-                     gr.Examples(
-                         examples=s2s_examples,
-                         inputs=[s2s_in],
-                         outputs=[s2s_audio, s2s_status],
-                         fn=gpu_handler(app.run_s2s),
-                     )
-
-                 s2s_btn.click(
-                     gpu_handler(app.run_s2s),
-                     inputs=[s2s_in],
-                     outputs=[s2s_audio, s2s_status]
-                 )
-
-             # ============================================================
-             # 3) SPEECH → TEXT (S2T)
-             # ============================================================
-             with gr.Tab("Speech → Text (S2T)"):
-
-                 s2t_in = gr.Audio(type="filepath", label="Input Speech")
-                 s2t_btn = gr.Button("Transcribe")
-                 s2t_text = gr.Textbox(label="Transcribed Text")
-                 s2t_status = gr.Textbox(label="Status", interactive=False)
-
-                 s2t_examples = []
-                 s2t_dir = DEMO_ROOT / "s2t"
-                 if s2t_dir.exists():
-                     for f in s2t_dir.glob("*.wav"):
-                         s2t_examples.append([str(f)])
-
-                 if len(s2t_examples) > 0:
-                     gr.Examples(
-                         examples=s2t_examples,
-                         inputs=[s2t_in],
-                         outputs=[s2t_text, s2t_status],
-                         fn=gpu_handler(app.run_s2t),
-                     )
-
-                 s2t_btn.click(
-                     gpu_handler(app.run_s2t),
-                     inputs=[s2t_in],
-                     outputs=[s2t_text, s2t_status],
-                 )
-
-             # ============================================================
-             # 4) VIDEO → TEXT (V2T)
-             # ============================================================
-             with gr.Tab("Video → Text (V2T)"):
-
-                 v2t_in = gr.Video(type="filepath", label="Input Video")
-                 v2t_btn = gr.Button("Generate Caption")
-                 v2t_text = gr.Textbox(label="Caption")
-                 v2t_status = gr.Textbox(label="Status")
-
-                 v2t_examples = []
-                 v2t_dir = DEMO_ROOT / "v2t"
-                 if v2t_dir.exists():
-                     for f in v2t_dir.glob("*.mp4"):
-                         v2t_examples.append([str(f)])
-
-                 if len(v2t_examples) > 0:
-                     gr.Examples(
-                         examples=v2t_examples,
-                         inputs=[v2t_in],
-                         outputs=[v2t_text, v2t_status],
-                         fn=gpu_handler(app.run_v2t),
-                     )
-
-                 v2t_btn.click(
-                     gpu_handler(app.run_v2t),
-                     inputs=[v2t_in],
-                     outputs=[v2t_text, v2t_status],
-                 )
-
-             # ============================================================
-             # 5) VIDEO → SPEECH (V2S)
-             # ============================================================
-             with gr.Tab("Video → Speech (V2S)"):
-
-                 v2s_in = gr.Video(type="filepath", label="Input Video")
-                 v2s_btn = gr.Button("Generate Speech")
-                 v2s_audio = gr.Audio(label="Speech Output")
-                 v2s_status = gr.Textbox(label="Status")
-
-                 v2s_examples = []
-                 v2s_dir = DEMO_ROOT / "v2s"
-                 if v2s_dir.exists():
-                     for f in v2s_dir.glob("*.mp4"):
-                         v2s_examples.append([str(f)])
-
-                 if len(v2s_examples) > 0:
-                     gr.Examples(
-                         examples=v2s_examples,
-                         inputs=[v2s_in],
-                         outputs=[v2s_audio, v2s_status],
-                         fn=gpu_handler(app.run_v2s),
-                     )
-
-                 v2s_btn.click(
-                     gpu_handler(app.run_v2s),
-                     inputs=[v2s_in],
-                     outputs=[v2s_audio, v2s_status],
-                 )
-
-             # ============================================================
-             # 6) IMAGE → SPEECH (I2S)
-             # ============================================================
-             with gr.Tab("Image → Speech (I2S)"):
-
-                 i2s_in = gr.Image(type="filepath", label="Input Image")
-                 i2s_btn = gr.Button("Generate Speech")
-                 i2s_audio = gr.Audio(label="Speech")
-                 i2s_status = gr.Textbox(label="Status")
-
-                 # Only if folder exists
-                 i2s_examples = []
-                 i2s_dir = DEMO_ROOT / "i2s"
-                 if i2s_dir.exists():
-                     for f in i2s_dir.glob("*.*"):
-                         i2s_examples.append([str(f)])
-
-                 if len(i2s_examples) > 0:
-                     gr.Examples(
-                         examples=i2s_examples,
-                         inputs=[i2s_in],
-                         outputs=[i2s_audio, i2s_status],
-                         fn=gpu_handler(app.run_i2s),
-                     )
-
-                 i2s_btn.click(
-                     gpu_handler(app.run_i2s),
-                     inputs=[i2s_in],
-                     outputs=[i2s_audio, i2s_status],
-                 )
-
-             # ============================================================
-             # 7) CHAT
-             # ============================================================
-             with gr.Tab("Chat (Text)"):
-
-                 chat_in = gr.Textbox(label="Message")
-                 chat_btn = gr.Button("Send")
-                 chat_out = gr.Textbox(label="Response")
-                 chat_status = gr.Textbox(label="Status")
-
-                 chat_examples = []
-                 chat_dir = DEMO_ROOT / "chat"
-                 if chat_dir.exists():
-                     for f in chat_dir.glob("*.txt"):
-                         txt = f.read_text().strip()
-                         chat_examples.append([txt])
-
-                 if len(chat_examples) > 0:
-                     gr.Examples(
-                         examples=chat_examples,
-                         inputs=[chat_in],
-                         outputs=[chat_out, chat_status],
-                         fn=gpu_handler(app.run_chat),
-                     )
-
-                 chat_btn.click(
-                     gpu_handler(app.run_chat),
-                     inputs=[chat_in],
-                     outputs=[chat_out, chat_status],
-                 )
-
-             # ============================================================
-             # 8) MMU (2 images → text)
-             # ============================================================
-             with gr.Tab("MMU (Dual-Image Reasoning)"):
-
-                 mmu_img1 = gr.Image(type="filepath", label="Image 1")
-                 mmu_img2 = gr.Image(type="filepath", label="Image 2")
-                 mmu_prompt = gr.Textbox(label="Prompt")
-                 mmu_btn = gr.Button("Run MMU")
-                 mmu_out = gr.Textbox(label="Output")
-                 mmu_status = gr.Textbox(label="Status")
-
-                 mmu_examples = []
-                 mmu_dir = DEMO_ROOT / "mmu"
-                 if mmu_dir.exists():
-                     imgs = list(mmu_dir.glob("*.png"))
-                     if len(imgs) >= 2:
-                         mmu_examples.append([
-                             str(imgs[0]),
-                             str(imgs[1]),
-                             "Describe the relation between two objects."
-                         ])
-
-                 if len(mmu_examples) > 0:
-                     gr.Examples(
-                         examples=mmu_examples,
-                         inputs=[mmu_img1, mmu_img2, mmu_prompt],
-                         outputs=[mmu_out, mmu_status],
-                         fn=gpu_handler(app.run_mmu_dual),
-                     )
-
-                 mmu_btn.click(
-                     gpu_handler(app.run_mmu_dual),
-                     inputs=[mmu_img1, mmu_img2, mmu_prompt],
-                     outputs=[mmu_out, mmu_status]
-                 )
-
-             # ============================================================
-             # 9) TEXT → IMAGE (T2I)
-             # ============================================================
-             with gr.Tab("Text → Image (T2I)"):
-
-                 t2i_in = gr.Textbox(label="Prompt")
-                 t2i_btn = gr.Button("Generate Image")
-                 t2i_img = gr.Image(label="Generated Image")
-                 t2i_status = gr.Textbox(label="Status")
-
-                 t2i_examples = []
-                 t2i_dir = DEMO_ROOT / "t2i"
-                 if t2i_dir.exists():
-                     for f in t2i_dir.glob("*.txt"):
-                         txt = f.read_text().strip()
-                         t2i_examples.append([txt])
-
-                 if len(t2i_examples) > 0:
-                     gr.Examples(
-                         examples=t2i_examples,
-                         inputs=[t2i_in],
-                         outputs=[t2i_img, t2i_status],
-                         fn=gpu_handler(app.run_t2i),
-                     )
-
-                 t2i_btn.click(
-                     gpu_handler(app.run_t2i),
-                     inputs=[t2i_in],
-                     outputs=[t2i_img, t2i_status],
-                 )
-
-             # ============================================================
-             # 10) IMAGE EDITING (I2I)
-             # ============================================================
-             with gr.Tab("Image Editing (I2I)"):
-
-                 i2i_in = gr.Image(type="filepath", label="Input Image")
-                 i2i_prompt = gr.Textbox(label="Edit Instruction")
-                 i2i_btn = gr.Button("Apply Edit")
-                 i2i_img = gr.Image(label="Edited Image")
-                 i2i_status = gr.Textbox(label="Status")
-
-                 i2i_examples = []
-                 i2i_dir = DEMO_ROOT / "i2i"
-                 if i2i_dir.exists():
-                     for f in i2i_dir.glob("*.*"):
-                         i2i_examples.append([str(f), "Make it more vibrant."])
-
-                 if len(i2i_examples) > 0:
-                     gr.Examples(
-                         examples=i2i_examples,
-                         inputs=[i2i_in, i2i_prompt],
-                         outputs=[i2i_img, i2i_status],
-                         fn=gpu_handler(app.run_i2i),
-                     )
-
-                 i2i_btn.click(
-                     gpu_handler(app.run_i2i),
-                     inputs=[i2i_in, i2i_prompt],
-                     outputs=[i2i_img, i2i_status]
-                 )
-
-         # End Tabs
-
-     return demo
-
-
- # ----------------------------------------------------------------------
- # 4. Entry Point for Space
- # ----------------------------------------------------------------------

  @spaces.GPU
- def main():
-     app = OmadaDemo(
-         train_config=str(MMADA_ROOT / "inference/demo/demo.yaml"),
-         checkpoint=os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada"),
-         device="cpu"
      )

-     demo = build_zero_gpu_demo(app)
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)


  if __name__ == "__main__":
-     main()
+ """
+ ZeroGPU-friendly Gradio entrypoint for OMada demo.
+
+ - Downloads checkpoint + assets + style centroids from Hugging Face Hub
+ - Instantiates OmadaDemo once (global)
+ - Exposes 10 modalities via Gradio tabs
+ - Uses @spaces.GPU only on inference handlers so GPU is allocated per request
+
+ Environment overrides:
+     MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
+     MODEL_REVISION (default: main)
+     ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets)
+     ASSET_REVISION (default: main)
+     STYLE_REPO_ID (default: jaeikkim/aidas-style-centroid)
+     STYLE_REVISION (default: main)
+     HF_TOKEN (optional, for private model/dataset)
+     TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
+     DEVICE (default: cuda)
+ """
+
  import os
  import sys
+ import subprocess
+ import importlib
  from pathlib import Path
+
+ import gradio as gr
  import spaces
+ from packaging.version import parse as parse_version
+
+ # ---------------------------
+ # Project roots & sys.path
+ # ---------------------------

  PROJECT_ROOT = Path(__file__).resolve().parent
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
  if str(MMADA_ROOT) not in sys.path:
      sys.path.insert(0, str(MMADA_ROOT))

+ EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
+ if str(EMOVA_ROOT) not in sys.path:
+     sys.path.insert(0, str(EMOVA_ROOT))


+ # ---------------------------
+ # HuggingFace Hub helper
+ # ---------------------------

+ def ensure_hf_hub(target: str = "0.36.0"):
      """
+     Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
+
+     The Spaces base image may pull in a newer version via gradio, so we pin it.
      """
+     try:
+         import huggingface_hub as hub
+     except ImportError:
+         subprocess.check_call(
+             [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
+         )
+         import huggingface_hub as hub

+     if parse_version(hub.__version__) >= parse_version("1.0.0"):
+         subprocess.check_call(
+             [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
+         )
+         hub = importlib.reload(hub)
+
+     # Backfill missing constants in older hub versions to avoid AttributeError.
+     try:
+         import huggingface_hub.constants as hub_consts  # type: ignore
+     except Exception:
+         hub_consts = None
+     if hub_consts and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"):
+         setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False)
+     return hub
+
+
+ snapshot_download = ensure_hf_hub().snapshot_download
+
+
+ # ---------------------------
+ # Imports from OMada demo
+ # ---------------------------
+
+ from inference.gradio_multimodal_demo_inst import (  # noqa: E402
+     OmadaDemo,
+     CUSTOM_CSS,
+     FORCE_LIGHT_MODE_JS,
+ )
+
+
+ # ---------------------------
+ # HF download helpers
+ # ---------------------------
+
+ def download_assets() -> Path:
+     """Download demo assets (logo + sample prompts/media) and return the root path."""
+     repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
+     revision = os.getenv("ASSET_REVISION", "main")
+     token = os.getenv("HF_TOKEN")
+     cache_dir = PROJECT_ROOT / "_asset_cache"
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     return Path(
+         snapshot_download(
+             repo_id=repo_id,
+             revision=revision,
+             repo_type="dataset",
+             local_dir=cache_dir,
+             local_dir_use_symlinks=False,
+             token=token,
+         )
+     )


+ def download_style() -> Path:
+     """Download style centroid dataset and return the root path."""
+     repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
+     revision = os.getenv("STYLE_REVISION", "main")
+     token = os.getenv("HF_TOKEN")
+     cache_dir = PROJECT_ROOT / "_style_cache"
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     return Path(
+         snapshot_download(
+             repo_id=repo_id,
+             revision=revision,
+             repo_type="dataset",
+             local_dir=cache_dir,
+             local_dir_use_symlinks=False,
+             token=token,
+         )
+     )


+ def download_checkpoint() -> Path:
+     """Download checkpoint snapshot and return an `unwrapped_model` directory."""
+     repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
+     revision = os.getenv("MODEL_REVISION", "main")
+     token = os.getenv("HF_TOKEN")
+     cache_dir = PROJECT_ROOT / "_ckpt_cache"
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     snapshot_path = Path(
+         snapshot_download(
+             repo_id=repo_id,
+             revision=revision,
+             repo_type="model",
+             local_dir=cache_dir,
+             local_dir_use_symlinks=False,
+             token=token,
          )
+     )
+
+     # If snapshot itself is unwrapped_model, return it; otherwise look for nested dir,
+     # and finally alias via symlink.
+     if snapshot_path.name == "unwrapped_model":
+         return snapshot_path
+
+     nested = snapshot_path / "unwrapped_model"
+     if nested.is_dir():
+         return nested
+
+     aliased = snapshot_path.parent / "unwrapped_model"
+     if not aliased.exists():
+         aliased.symlink_to(snapshot_path, target_is_directory=True)
+     return aliased
+
+
+ # ---------------------------
+ # Global OmadaDemo instance
+ # ---------------------------
+
+ APP = None  # type: ignore
+
+
+ def get_app() -> OmadaDemo:
+     global APP
+     if APP is not None:
+         return APP
+
+     # Download everything once
+     ckpt_dir = download_checkpoint()
+     asset_root = download_assets()
+     style_root = download_style()
+
+     # Wire style centroids to expected locations
+     style_targets = [
+         MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
+         PROJECT_ROOT
+         / "EMOVA_speech_tokenizer"
+         / "emova_speech_tokenizer"
+         / "speech_tokenization"
+         / "condition_style_centroid",
+     ]
+     for starget in style_targets:
+         if not starget.exists():
+             starget.parent.mkdir(parents=True, exist_ok=True)
+             starget.symlink_to(style_root, target_is_directory=True)
+
+     # Choose train config
+     default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
+     legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
+     train_config = os.getenv("TRAIN_CONFIG_PATH")
+     if not train_config:
+         train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
+
+     # Device: in ZeroGPU environment, "cuda" is virtualized and only actually
+     # attached inside @spaces.GPU handlers.
+     device = os.getenv("DEVICE", "cuda")
+
+     APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
+     return APP
+
+
+ # ---------------------------
+ # ZeroGPU-wrapped handlers
+ # ---------------------------
+
+ @spaces.GPU
+ def t2s_handler(
+     text,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+     cfg_scale,
+     gender,
+     emotion,
+     speed,
+     pitch,
+ ):
+     app = get_app()
+     audio, status = app.run_t2s(
+         text=text,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+         cfg_scale=float(cfg_scale),
+         gender_choice=gender,
+         emotion_choice=emotion,
+         speed_choice=speed,
+         pitch_choice=pitch,
+     )
+     return audio, status
+
+
+ @spaces.GPU
+ def s2s_handler(
+     audio_path,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+     cfg_scale,
+ ):
+     app = get_app()
+     audio, status = app.run_s2s(
+         audio_path=audio_path,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+         cfg_scale=float(cfg_scale),
+     )
+     return audio, status
+
+
+ @spaces.GPU
+ def s2t_handler(
+     audio_path,
+     steps,
+     block_len,
+     max_tokens,
+     remasking,
+ ):
+     app = get_app()
+     text, status = app.run_s2t(
+         audio_path=audio_path,
+         steps=int(steps),
+         block_length=int(block_len),
+         max_new_tokens=int(max_tokens),
+         remasking=str(remasking),
+     )
+     return text, status
+
+
+ @spaces.GPU
+ def v2t_handler(
+     video,
+     steps,
+     block_len,
+     max_tokens,
+ ):
+     app = get_app()
+     text, status = app.run_v2t(
+         video_path=video,
+         steps=int(steps),
+         block_length=int(block_len),
+         max_new_tokens=int(max_tokens),
+     )
+     return text, status
+
+
+ @spaces.GPU
+ def v2s_handler(
+     video,
+     message,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+     cfg_scale,
+ ):
+     app = get_app()
+     audio, status = app.run_v2s(
+         video_path=video,
+         message=message,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+         cfg_scale=float(cfg_scale),
+     )
+     return audio, status
+
+
+ @spaces.GPU
+ def i2s_handler(
+     image,
+     message,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+     cfg_scale,
+ ):
+     app = get_app()
+     audio, status = app.run_i2s(
+         image=image,
+         message=message,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+         cfg_scale=float(cfg_scale),
+     )
+     return audio, status
+
+
+ @spaces.GPU
+ def chat_handler(
+     message,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+ ):
+     app = get_app()
+     text, status = app.run_chat(
+         message=message,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+     )
+     return text, status
+
+
+ @spaces.GPU
+ def mmu_handler(
+     image_a,
+     image_b,
+     question,
+     max_tokens,
+     steps,
+     block_len,
+     temperature,
+ ):
+     app = get_app()
+     text, status = app.run_mmu_dual(
+         image_a=image_a,
+         image_b=image_b,
+         message=question,
+         max_new_tokens=int(max_tokens),
+         steps=int(steps),
+         block_length=int(block_len),
+         temperature=float(temperature),
+     )
+     return text, status


  @spaces.GPU
+ def t2i_handler(
+     prompt,
+     timesteps,
+     temperature,
+     guidance,
+ ):
+     app = get_app()
+     image, status = app.run_t2i(
+         prompt=prompt,
+         timesteps=int(timesteps),
+         temperature=float(temperature),
+         guidance_scale=float(guidance),
      )
+     return image, status

+
+ @spaces.GPU
+ def i2i_handler(
+     instruction,
+     image,
+     timesteps,
+     temperature,
+     guidance,
+ ):
+     app = get_app()
+     image_out, status = app.run_i2i(
+         instruction=instruction,
+         source_image=image,
+         timesteps=int(timesteps),
+         temperature=float(temperature),
+         guidance_scale=float(guidance),
+     )
+     return image_out, status
+
+
+ # ---------------------------
+ # Gradio UI (10 tabs)
+ # ---------------------------
+
+ theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
+
+ with gr.Blocks(
+     title="AIDAS Lab @ SNU - OMni-modal Diffusion",
+     css=CUSTOM_CSS,
+     theme=theme,
+     js=FORCE_LIGHT_MODE_JS,
+ ) as demo:
+     gr.Markdown(
+         "## Omni-modal Diffusion Foundation Model\n"
+         "### AIDAS Lab @ SNU"
+     )
+
+     with gr.Tab("Text → Speech (T2S)"):
+         with gr.Row():
+             t2s_text = gr.Textbox(
+                 label="Input text",
+                 lines=4,
+                 placeholder="Type the speech you want to synthesize...",
+             )
+             t2s_audio = gr.Audio(label="Generated speech", type="numpy")
+         t2s_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length")
+             t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps")
+             t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
+             t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+             t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale")
+             with gr.Row():
+                 t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender")
+                 t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion")
+             with gr.Row():
+                 t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
+                 t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
+         t2s_btn = gr.Button("Generate speech", variant="primary")
+         t2s_btn.click(
+             t2s_handler,
+             inputs=[
+                 t2s_text,
+                 t2s_max_tokens,
+                 t2s_steps,
+                 t2s_block,
+                 t2s_temperature,
+                 t2s_cfg,
+                 t2s_gender,
+                 t2s_emotion,
+                 t2s_speed,
+                 t2s_pitch,
+             ],
+             outputs=[t2s_audio, t2s_status],
+         )
+
+     with gr.Tab("Speech → Speech (S2S)"):
+         s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
+         s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
+         s2s_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             s2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
+             s2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
+             s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
+             s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
+             s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
+         s2s_btn = gr.Button("Generate reply speech", variant="primary")
+         s2s_btn.click(
+             s2s_handler,
+             inputs=[
+                 s2s_audio_in,
+                 s2s_max_tokens,
+                 s2s_steps,
+                 s2s_block,
+                 s2s_temperature,
+                 s2s_cfg,
+             ],
+             outputs=[s2s_audio_out, s2s_status],
+         )
+
+     with gr.Tab("Speech → Text (S2T)"):
+         s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
+         s2t_text_out = gr.Textbox(label="Transcription", lines=4)
+         s2t_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps")
+             s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
+             s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens")
+             s2t_remasking = gr.Dropdown(
+                 ["low_confidence", "random"],
+                 value="low_confidence",
+                 label="Remasking strategy",
+             )
+         s2t_btn = gr.Button("Transcribe", variant="primary")
+         s2t_btn.click(
+             s2t_handler,
+             inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking],
+             outputs=[s2t_text_out, s2t_status],
+         )
+
+     with gr.Tab("Video → Text (V2T)"):
+         v2t_video_in = gr.Video(
+             label="Upload or record video",
+             height=256,
+             sources=["upload", "webcam"],
+         )
+         v2t_text_out = gr.Textbox(label="Caption / answer", lines=4)
+         v2t_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
+             v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
+             v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
+         v2t_btn = gr.Button("Generate caption", variant="primary")
+         v2t_btn.click(
+             v2t_handler,
+             inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens],
+             outputs=[v2t_text_out, v2t_status],
+         )
+
+     with gr.Tab("Video → Speech (V2S)"):
+         v2s_video_in = gr.Video(
+             label="Upload or record video",
+             height=256,
+             sources=["upload", "webcam"],
+         )
+         v2s_prompt = gr.Textbox(
+             label="Optional instruction",
+             placeholder="(Optional) e.g., 'Describe this scene in spoken form.'",
+         )
+         v2s_audio_out = gr.Audio(type="numpy", label="Generated speech")
+         v2s_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             v2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
+             v2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
+             v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
+             v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+             v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
+         v2s_btn = gr.Button("Generate speech from video", variant="primary")
+         v2s_btn.click(
+             v2s_handler,
+             inputs=[
+                 v2s_video_in,
+                 v2s_prompt,
+                 v2s_max_tokens,
+                 v2s_steps,
+                 v2s_block,
+                 v2s_temperature,
+                 v2s_cfg,
+             ],
+             outputs=[v2s_audio_out, v2s_status],
+         )
+
+     with gr.Tab("Image → Speech (I2S)"):
+         i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
+         i2s_prompt = gr.Textbox(
+             label="Optional question",
+             placeholder="(Optional) e.g., 'Describe this image aloud.'",
+         )
+         i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
+         i2s_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
+             i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
+             i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
+             i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+             i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
+         i2s_btn = gr.Button("Generate spoken description", variant="primary")
+         i2s_btn.click(
+             i2s_handler,
+             inputs=[
+                 i2s_image_in,
+                 i2s_prompt,
+                 i2s_max_tokens,
+                 i2s_steps,
+                 i2s_block,
+                 i2s_temperature,
+                 i2s_cfg,
+             ],
+             outputs=[i2s_audio_out, i2s_status],
+         )
+
+     with gr.Tab("Text Chat"):
+         chat_in = gr.Textbox(
+             label="Message",
+             lines=4,
+             placeholder="Ask anything. The model will reply in text.",
+         )
+         chat_out = gr.Textbox(label="Assistant reply", lines=6)
+         chat_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
+             chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
+             chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
+             chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
+         chat_btn = gr.Button("Send", variant="primary")
+         chat_btn.click(
+             chat_handler,
+             inputs=[
+                 chat_in,
+                 chat_max_tokens,
+                 chat_steps,
+                 chat_block,
+                 chat_temperature_slider,
+             ],
+             outputs=[chat_out, chat_status],
+         )
+
+     with gr.Tab("MMU (2 images → text)"):
+         mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
+         mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
+         mmu_question = gr.Textbox(
+             label="Question",
+             lines=3,
+             placeholder="Ask about the relationship or differences between the two images.",
+         )
+         mmu_answer = gr.Textbox(label="Answer", lines=6)
+         mmu_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens")
+             mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
+             mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
+             mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
+         mmu_btn = gr.Button("Answer about the two images", variant="primary")
+         mmu_btn.click(
+             mmu_handler,
+             inputs=[
+                 mmu_img_a,
+                 mmu_img_b,
+                 mmu_question,
+                 mmu_max_tokens,
+                 mmu_steps,
+                 mmu_block,
+                 mmu_temperature,
+             ],
+             outputs=[mmu_answer, mmu_status],
+         )
+
+     with gr.Tab("Text → Image (T2I)"):
+         t2i_prompt = gr.Textbox(
+             label="Prompt",
+             lines=4,
+             placeholder="Describe the image you want to generate...",
+         )
+         t2i_image_out = gr.Image(label="Generated image")
+         t2i_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
+             t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+             t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+         t2i_btn = gr.Button("Generate image", variant="primary")
+         t2i_btn.click(
+             t2i_handler,
+             inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
+             outputs=[t2i_image_out, t2i_status],
+         )
+
+     with gr.Tab("Image Editing (I2I)"):
+         i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
+         i2i_instr = gr.Textbox(
+             label="Editing instruction",
+             lines=4,
+             placeholder="Describe how you want to edit the image...",
+         )
+         i2i_image_out = gr.Image(label="Edited image")
+         i2i_status = gr.Textbox(label="Status", interactive=False)
+         with gr.Accordion("Advanced settings", open=False):
+             i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
+             i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+             i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+         i2i_btn = gr.Button("Apply edit", variant="primary")
+         i2i_btn.click(
+             i2i_handler,
+             inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
+             outputs=[i2i_image_out, i2i_status],
+         )


  if __name__ == "__main__":
+     demo.launch()
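
The rewritten entrypoint is configured entirely through the environment variables listed in its module docstring. As a minimal sketch only (not part of the commit, assuming the default public repo IDs and a checkout with app.py at the repository root), launching it locally with those overrides could look like this:

    import os
    import subprocess
    import sys

    # Values shown are the documented defaults; override only what differs in your setup.
    env = dict(os.environ)
    env.setdefault("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
    env.setdefault("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
    env.setdefault("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
    env.setdefault("TRAIN_CONFIG_PATH", "MMaDA/inference/demo/demo.yaml")
    env.setdefault("DEVICE", "cuda")
    # env["HF_TOKEN"] = "hf_..."  # only needed if the model/dataset repos are private

    # Run app.py in a child process with the chosen environment;
    # demo.launch() is called by its __main__ guard.
    subprocess.run([sys.executable, "app.py"], env=env, check=True)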