jake committed
Commit df89a6a · 1 Parent(s): 05fc139
Files changed (1):
  1. app.py +239 -65
app.py CHANGED
@@ -168,6 +168,86 @@ def download_checkpoint() -> Path:
     return aliased


+# ---------------------------
+# Assets & examples from HF dataset
+# ---------------------------
+
+ASSET_ROOT = download_assets()
+DEMO_ROOT = ASSET_ROOT / "demo"
+
+LOGO_PATH = DEMO_ROOT / "logo.png"
+T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt"
+CHAT_TEXT_PATH = DEMO_ROOT / "chat" / "text.txt"
+T2I_TEXT_PATH = DEMO_ROOT / "t2i" / "text.txt"
+
+
+def _load_text_examples(path: Path):
+    if not path.exists():
+        return []
+    try:
+        lines = [
+            line.strip()
+            for line in path.read_text(encoding="utf-8").splitlines()
+            if line.strip()
+        ]
+    except Exception:
+        return []
+    return [[line] for line in lines]
+
+
+def _load_media_examples(subdir: str, suffixes):
+    d = DEMO_ROOT / subdir
+    if not d.exists():
+        return []
+    examples = []
+    for p in sorted(d.iterdir()):
+        if p.is_file() and p.suffix.lower() in suffixes:
+            examples.append([str(p)])
+    return examples
+
+
+# Text-based examples
+T2S_EXAMPLES = _load_text_examples(T2S_TEXT_PATH)
+CHAT_EXAMPLES = _load_text_examples(CHAT_TEXT_PATH)
+T2I_EXAMPLES = _load_text_examples(T2I_TEXT_PATH)
+
+# Audio / video / image examples
+_AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg"}
+_VIDEO_SUFFIXES = {".mp4", ".mov", ".avi", ".webm"}
+_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
+
+S2T_EXAMPLES = _load_media_examples("s2t", _AUDIO_SUFFIXES)
+V2T_EXAMPLES = _load_media_examples("v2t", _VIDEO_SUFFIXES)
+S2S_EXAMPLES = _load_media_examples("s2s", _AUDIO_SUFFIXES)
+if not S2S_EXAMPLES and S2T_EXAMPLES:
+    S2S_EXAMPLES = S2T_EXAMPLES[: min(4, len(S2T_EXAMPLES))]
+
+V2S_EXAMPLES = _load_media_examples("v2s", _VIDEO_SUFFIXES)
+if not V2S_EXAMPLES and V2T_EXAMPLES:
+    V2S_EXAMPLES = V2T_EXAMPLES[: min(4, len(V2T_EXAMPLES))]
+
+I2S_EXAMPLES = _load_media_examples("i2s", _IMAGE_SUFFIXES)
+
+# MMU: 2 images + question
+MMU_DIR = DEMO_ROOT / "mmu"
+MMU_EXAMPLES = []
+if MMU_DIR.exists():
+    mmu_imgs = [
+        p for p in sorted(MMU_DIR.iterdir())
+        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
+    ]
+    if len(mmu_imgs) >= 2:
+        MMU_EXAMPLES = [[
+            str(mmu_imgs[0]),
+            str(mmu_imgs[1]),
+            "What are the differences between the two images?"
+        ]]
+
+# If there are no i2s examples but an mmu example exists, reuse its first image
+if not I2S_EXAMPLES and MMU_EXAMPLES:
+    I2S_EXAMPLES = [[MMU_EXAMPLES[0][0]]]
+
+
 # ---------------------------
 # Global OmadaDemo instance
 # ---------------------------
@@ -180,9 +260,8 @@ def get_app() -> OmadaDemo:
     if APP is not None:
         return APP

-    # Download everything once
+    # Download ckpt + style centroids once
     ckpt_dir = download_checkpoint()
-    asset_root = download_assets()
     style_root = download_style()

     # Wire style centroids to expected locations
@@ -440,11 +519,21 @@ with gr.Blocks(
     theme=theme,
     js=FORCE_LIGHT_MODE_JS,
 ) as demo:
+    # Logo (if present)
+    if LOGO_PATH.exists():
+        gr.Image(
+            value=str(LOGO_PATH),
+            show_label=False,
+            height=140,
+            interactive=False,
+        )
+
     gr.Markdown(
         "## Omni-modal Diffusion Foundation Model\n"
         "### AIDAS Lab @ SNU"
     )

+    # ---------- T2S ----------
     with gr.Tab("Text → Speech (T2S)"):
         with gr.Row():
             t2s_text = gr.Textbox(
@@ -484,6 +573,15 @@ with gr.Blocks(
             outputs=[t2s_audio, t2s_status],
         )

+        if T2S_EXAMPLES:
+            gr.Markdown("**Sample prompts**")
+            gr.Examples(
+                examples=T2S_EXAMPLES,
+                inputs=[t2s_text],
+                examples_per_page=4,
+            )
+
+    # ---------- S2S ----------
     with gr.Tab("Speech → Speech (S2S)"):
         s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
         s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
@@ -508,6 +606,15 @@ with gr.Blocks(
             outputs=[s2s_audio_out, s2s_status],
         )

+        if S2S_EXAMPLES:
+            gr.Markdown("**Sample S2S clips**")
+            gr.Examples(
+                examples=S2S_EXAMPLES,
+                inputs=[s2s_audio_in],
+                examples_per_page=4,
+            )
+
+    # ---------- S2T ----------
     with gr.Tab("Speech → Text (S2T)"):
         s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
         s2t_text_out = gr.Textbox(label="Transcription", lines=4)
@@ -528,6 +635,15 @@ with gr.Blocks(
             outputs=[s2t_text_out, s2t_status],
         )

+        if S2T_EXAMPLES:
+            gr.Markdown("**Sample S2T clips**")
+            gr.Examples(
+                examples=S2T_EXAMPLES,
+                inputs=[s2t_audio_in],
+                examples_per_page=4,
+            )
+
+    # ---------- V2T ----------
     with gr.Tab("Video → Text (V2T)"):
         v2t_video_in = gr.Video(
             label="Upload or record video",
@@ -547,6 +663,15 @@ with gr.Blocks(
             outputs=[v2t_text_out, v2t_status],
         )

+        if V2T_EXAMPLES:
+            gr.Markdown("**Sample videos**")
+            gr.Examples(
+                examples=V2T_EXAMPLES,
+                inputs=[v2t_video_in],
+                examples_per_page=4,
+            )
+
+    # ---------- V2S ----------
     with gr.Tab("Video → Speech (V2S)"):
         v2s_video_in = gr.Video(
             label="Upload or record video",
@@ -580,35 +705,64 @@ with gr.Blocks(
             outputs=[v2s_audio_out, v2s_status],
         )

-    with gr.Tab("Image → Speech (I2S)"):
-        i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
-        i2s_prompt = gr.Textbox(
-            label="Optional question",
-            placeholder="(Optional) e.g., 'Describe this image aloud.'",
+        if V2S_EXAMPLES:
+            gr.Markdown("**Sample videos**")
+            gr.Examples(
+                examples=V2S_EXAMPLES,
+                inputs=[v2s_video_in],
+                examples_per_page=4,
+            )
+
+    # ---------- T2I ----------
+    with gr.Tab("Text → Image (T2I)"):
+        t2i_prompt = gr.Textbox(
+            label="Prompt",
+            lines=4,
+            placeholder="Describe the image you want to generate...",
         )
-        i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
-        i2s_status = gr.Textbox(label="Status", interactive=False)
+        t2i_image_out = gr.Image(label="Generated image")
+        t2i_status = gr.Textbox(label="Status", interactive=False)
         with gr.Accordion("Advanced settings", open=False):
-            i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
-            i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
-            i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
-            i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
-            i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
-        i2s_btn = gr.Button("Generate spoken description", variant="primary")
-        i2s_btn.click(
-            i2s_handler,
-            inputs=[
-                i2s_image_in,
-                i2s_prompt,
-                i2s_max_tokens,
-                i2s_steps,
-                i2s_block,
-                i2s_temperature,
-                i2s_cfg,
-            ],
-            outputs=[i2s_audio_out, i2s_status],
+            t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
+            t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+            t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+        t2i_btn = gr.Button("Generate image", variant="primary")
+        t2i_btn.click(
+            t2i_handler,
+            inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
+            outputs=[t2i_image_out, t2i_status],
+        )
+
+        if T2I_EXAMPLES:
+            gr.Markdown("**Sample prompts**")
+            gr.Examples(
+                examples=T2I_EXAMPLES,
+                inputs=[t2i_prompt],
+                examples_per_page=4,
+            )
+
+    # ---------- I2I ----------
+    with gr.Tab("Image Editing (I2I)"):
+        i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
+        i2i_instr = gr.Textbox(
+            label="Editing instruction",
+            lines=4,
+            placeholder="Describe how you want to edit the image...",
+        )
+        i2i_image_out = gr.Image(label="Edited image")
+        i2i_status = gr.Textbox(label="Status", interactive=False)
+        with gr.Accordion("Advanced settings", open=False):
+            i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
+            i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+            i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+        i2i_btn = gr.Button("Apply edit", variant="primary")
+        i2i_btn.click(
+            i2i_handler,
+            inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
+            outputs=[i2i_image_out, i2i_status],
         )

+    # ---------- Chat ----------
     with gr.Tab("Text Chat"):
         chat_in = gr.Textbox(
             label="Message",
@@ -635,6 +789,55 @@ with gr.Blocks(
             outputs=[chat_out, chat_status],
         )

+        if CHAT_EXAMPLES:
+            gr.Markdown("**Sample prompts**")
+            gr.Examples(
+                examples=CHAT_EXAMPLES,
+                inputs=[chat_in],
+                examples_per_page=4,
+            )
+
+
+    # ---------- I2S ----------
+    with gr.Tab("Image → Speech (I2S)"):
+        i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
+        i2s_prompt = gr.Textbox(
+            label="Optional question",
+            placeholder="(Optional) e.g., 'Describe this image aloud.'",
+        )
+        i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
+        i2s_status = gr.Textbox(label="Status", interactive=False)
+        with gr.Accordion("Advanced settings", open=False):
+            i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
+            i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
+            i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
+            i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+            i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
+        i2s_btn = gr.Button("Generate spoken description", variant="primary")
+        i2s_btn.click(
+            i2s_handler,
+            inputs=[
+                i2s_image_in,
+                i2s_prompt,
+                i2s_max_tokens,
+                i2s_steps,
+                i2s_block,
+                i2s_temperature,
+                i2s_cfg,
+            ],
+            outputs=[i2s_audio_out, i2s_status],
+        )
+
+        if I2S_EXAMPLES:
+            gr.Markdown("**Sample images**")
+            gr.Examples(
+                examples=I2S_EXAMPLES,
+                inputs=[i2s_image_in],
+                examples_per_page=4,
+            )
+
+
+    # ---------- MMU ----------
     with gr.Tab("MMU (2 images → text)"):
         mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
         mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
@@ -665,45 +868,16 @@ with gr.Blocks(
             outputs=[mmu_answer, mmu_status],
         )

-    with gr.Tab("Text → Image (T2I)"):
-        t2i_prompt = gr.Textbox(
-            label="Prompt",
-            lines=4,
-            placeholder="Describe the image you want to generate...",
-        )
-        t2i_image_out = gr.Image(label="Generated image")
-        t2i_status = gr.Textbox(label="Status", interactive=False)
-        with gr.Accordion("Advanced settings", open=False):
-            t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
-            t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
-            t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
-        t2i_btn = gr.Button("Generate image", variant="primary")
-        t2i_btn.click(
-            t2i_handler,
-            inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
-            outputs=[t2i_image_out, t2i_status],
-        )
-
-    with gr.Tab("Image Editing (I2I)"):
-        i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
-        i2i_instr = gr.Textbox(
-            label="Editing instruction",
-            lines=4,
-            placeholder="Describe how you want to edit the image...",
-        )
-        i2i_image_out = gr.Image(label="Edited image")
-        i2i_status = gr.Textbox(label="Status", interactive=False)
-        with gr.Accordion("Advanced settings", open=False):
-            i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
-            i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
-            i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
-        i2i_btn = gr.Button("Apply edit", variant="primary")
-        i2i_btn.click(
-            i2i_handler,
-            inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
-            outputs=[i2i_image_out, i2i_status],
-        )
+        if MMU_EXAMPLES:
+            gr.Markdown("**Sample MMU example**")
+            gr.Examples(
+                examples=MMU_EXAMPLES,
+                inputs=[mmu_img_a, mmu_img_b, mmu_question],
+                examples_per_page=1,
+            )

+# I2I examples are omitted for now, since they lack a clear text/image asset layout
+# (if needed, split assets into demo/i2i_prompt.txt + demo/i2i_images/ and wire them up)

 if __name__ == "__main__":
     demo.launch()
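
The closing comment leaves I2I example wiring as a follow-up. A minimal sketch of that wiring, assuming the layout the comment suggests (demo/i2i_prompt.txt with one instruction per line, demo/i2i_images/ with reference images); I2I_TEXT_PATH and _load_i2i_examples are hypothetical names, not part of this commit, and reuse the _load_text_examples / _load_media_examples helpers the commit adds:

# Hypothetical follow-up, not in this commit: pair demo/i2i_images/ with
# demo/i2i_prompt.txt as the comment above suggests.
I2I_TEXT_PATH = DEMO_ROOT / "i2i_prompt.txt"  # assumed file, one instruction per line

def _load_i2i_examples():
    # Reuse this commit's helpers; zip() truncates to the shorter list,
    # so unmatched images or instructions are simply dropped.
    prompts = [row[0] for row in _load_text_examples(I2I_TEXT_PATH)]
    images = [row[0] for row in _load_media_examples("i2i_images", _IMAGE_SUFFIXES)]
    return [[img, txt] for img, txt in zip(images, prompts)]

I2I_EXAMPLES = _load_i2i_examples()

# Then, inside the I2I tab, mirror the other tabs:
#     if I2I_EXAMPLES:
#         gr.Markdown("**Sample edits**")
#         gr.Examples(
#             examples=I2I_EXAMPLES,
#             inputs=[i2i_image_in, i2i_instr],
#             examples_per_page=4,
#         )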