jaeikkim commited on
Commit
7b822c3
·
1 Parent(s): afc1fa2

Dynin-Omni

Browse files
.gitignore CHANGED
@@ -1,3 +1,6 @@
1
  __pycache__/
2
  *.pyc
3
  MMaDA/inference/demo/ti2ti/
 
 
 
 
1
  __pycache__/
2
  *.pyc
3
  MMaDA/inference/demo/ti2ti/
4
+ _asset_cache/
5
+ _preview_cache/
6
+ _style_cache/
EMOVA_speech_tokenizer/emova_speech_tokenizer/speech_tokenization/condition_style_centroid ADDED
@@ -0,0 +1 @@
 
 
1
+ /dataset/omada/AIDAS-Omni-Modal-Diffusion/_style_cache
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED
@@ -25,6 +25,7 @@ import io
25
  import os
26
  import math
27
  import random
 
28
  import sys
29
  import tempfile
30
  import wave
@@ -207,6 +208,14 @@ html, body, body.dark, html.dark {
207
  box-shadow: none;
208
  border: 1px solid var(--omada-border);
209
  background: #ffffff;
 
 
 
 
 
 
 
 
210
  }
211
  .omada-controls {
212
  gap: 16px !important;
@@ -724,7 +733,7 @@ import cv2
724
  import gradio as gr
725
  import numpy as np
726
  import torch
727
- from omegaconf import DictConfig, OmegaConf
728
  from PIL import Image
729
 
730
  from inference.common import (
@@ -745,6 +754,12 @@ def _cfg_get(cfg, key, default=None):
745
 
746
  if cfg is None:
747
  return default
 
 
 
 
 
 
748
  if isinstance(cfg, dict):
749
  return cfg.get(key, default)
750
  try:
@@ -897,6 +912,14 @@ class OmadaDemo:
897
  )
898
  )
899
  self.max_text_len = int(getattr(self.train_cfg.dataset.preprocessing, "max_seq_length", 1024))
 
 
 
 
 
 
 
 
900
 
901
  model_seq_len = getattr(self.model.config, "num_vq_tokens", None)
902
  if model_seq_len is None:
@@ -913,7 +936,55 @@ class OmadaDemo:
913
  self.noise_type = _cfg_get(training_cfg, "noise_type", "mask")
914
  self.predict_all_tokens = bool(_cfg_get(training_cfg, "predict_all_tokens", False))
915
  self.t2i_default_timesteps = int(_cfg_get(training_cfg, "generation_timesteps", 20))
916
- self.i2i_default_timesteps = int(_cfg_get(training_cfg, "i2i_eval_timesteps", 24))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
917
 
918
  self.audio_condition_default = "gender-female_emotion-neutral_speed-normal_pitch-normal"
919
  style_map = getattr(getattr(self.vq_audio, "config", None), "u2s_style2idx", None)
@@ -942,6 +1013,13 @@ class OmadaDemo:
942
  speed_choice: str,
943
  pitch_choice: str,
944
  ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
 
 
 
 
 
 
 
945
 
946
  if text is None or not text.strip():
947
  return None, "Please provide text to synthesize."
@@ -1013,6 +1091,110 @@ class OmadaDemo:
1013
  status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
1014
  return audio, status
1015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1016
  # ------------------------------------------------------------------
1017
  # Speech-to-Speech
1018
  # ------------------------------------------------------------------
@@ -1134,10 +1316,20 @@ class OmadaDemo:
1134
  max_new_tokens: int,
1135
  remasking: str,
1136
  ) -> Tuple[str, str]:
 
 
 
 
 
 
1137
 
1138
  if not audio_path:
1139
  return "", "Please upload an audio file first."
1140
 
 
 
 
 
1141
  tokens = self.vq_audio.encode(audio_path).to(self.device)
1142
  offset = self.text_vocab_size + self.speech_codebook
1143
  tokens = tokens + offset
@@ -1175,13 +1367,93 @@ class OmadaDemo:
1175
  remasking=str(remasking),
1176
  )
1177
 
1178
- decoded = self.uni_prompting.text_tokenizer.batch_decode(
1179
- output_ids[:, input_ids.shape[1]:],
1180
- skip_special_tokens=True,
1181
- )[0]
1182
-
 
 
1183
  return decoded.strip(), "Transcription generated successfully."
1184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
  # ------------------------------------------------------------------
1186
  # Video-to-Text
1187
  # ------------------------------------------------------------------
@@ -1192,6 +1464,11 @@ class OmadaDemo:
1192
  block_length: int,
1193
  max_new_tokens: int,
1194
  ) -> Tuple[str, str]:
 
 
 
 
 
1195
 
1196
  resolved_path, converted = self._prepare_video_path(video_path)
1197
  if not resolved_path:
@@ -1241,14 +1518,80 @@ class OmadaDemo:
1241
  raw_all = self.uni_prompting.text_tokenizer.decode(output_ids[0], skip_special_tokens=False)
1242
  print("[V2T] RAW ALL:", repr(raw_all))
1243
 
1244
- decoded = self.uni_prompting.text_tokenizer.batch_decode(
1245
- output_ids[:, input_ids.shape[1]:],
1246
- skip_special_tokens=True,
1247
- )[0]
1248
  print("[V2T] DECODED SLICE:", repr(decoded))
1249
-
 
 
1250
  return decoded.strip(), "Video caption generated successfully."
1251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1252
  # ------------------------------------------------------------------
1253
  # Text-to-Image
1254
  # ------------------------------------------------------------------
@@ -1259,6 +1602,11 @@ class OmadaDemo:
1259
  temperature: float,
1260
  guidance_scale: float,
1261
  ) -> Tuple[Optional[Image.Image], str]:
 
 
 
 
 
1262
  if not prompt or not prompt.strip():
1263
  return None, "Please provide a text prompt."
1264
 
@@ -1307,6 +1655,65 @@ class OmadaDemo:
1307
  image = self._decode_image_tokens(gen_tokens[0])
1308
  return image, "Image generated from text prompt."
1309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1310
  # ------------------------------------------------------------------
1311
  # Image-to-Image Editing
1312
  # ------------------------------------------------------------------
@@ -1318,6 +1725,11 @@ class OmadaDemo:
1318
  temperature: float,
1319
  guidance_scale: float,
1320
  ) -> Tuple[Optional[Image.Image], str]:
 
 
 
 
 
1321
  if source_image is None:
1322
  return None, "Please upload a reference image."
1323
  if not instruction or not instruction.strip():
@@ -1378,6 +1790,80 @@ class OmadaDemo:
1378
  image = self._decode_image_tokens(gen_tokens[0])
1379
  return image, "Edited image generated."
1380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1381
  # ------------------------------------------------------------------
1382
  # Video-to-Speech
1383
  # ------------------------------------------------------------------
@@ -1550,6 +2036,12 @@ class OmadaDemo:
1550
  block_length: int,
1551
  temperature: float,
1552
  ) -> Tuple[str, str]:
 
 
 
 
 
 
1553
  content = (message or "").strip()
1554
  if not content:
1555
  return "", "Type a message to start chatting."
@@ -1592,10 +2084,10 @@ class OmadaDemo:
1592
  else:
1593
  output_ids, step_snapshots = output_result, []
1594
 
1595
- decoded = tokenizer.batch_decode(
1596
- output_ids[:, input_ids.shape[1]:],
1597
- skip_special_tokens=True,
1598
- )[0]
1599
  return decoded.strip(), "Assistant reply generated."
1600
 
1601
  def run_chat_stream(
@@ -1609,6 +2101,12 @@ class OmadaDemo:
1609
  max_tokens_per_step: int = 0,
1610
  update_every: int = 25,
1611
  ):
 
 
 
 
 
 
1612
  content = (message or "").strip()
1613
  if not content:
1614
  yield "", "Type a message to start chatting.", True
@@ -1655,12 +2153,12 @@ class OmadaDemo:
1655
  if len(step_snapshots) > max_step_snapshots:
1656
  step_snapshots = step_snapshots[-max_step_snapshots:]
1657
  step_counter += 1
 
 
1658
  if update_every > 1 and step_counter % update_every != 0:
1659
  continue
1660
- decoded = tokenizer.batch_decode(
1661
- snapshot[:, prompt_len:],
1662
- skip_special_tokens=True,
1663
- )[0].strip()
1664
  steps_html = self._render_diffusion_steps(
1665
  step_snapshots,
1666
  max_tokens_per_step=max_tokens_per_step,
@@ -1671,10 +2169,11 @@ class OmadaDemo:
1671
  yield "", "Assistant reply generated.", True
1672
  return
1673
 
1674
- decoded = tokenizer.batch_decode(
1675
- latest_ids[:, input_ids.shape[1]:],
1676
- skip_special_tokens=True,
1677
- )[0].strip()
 
1678
  step_snapshots = [latest_ids[0, input_ids.shape[1]:].detach().cpu()]
1679
  steps_html = self._render_diffusion_steps(
1680
  step_snapshots,
@@ -1682,6 +2181,36 @@ class OmadaDemo:
1682
  )
1683
  yield self._format_chat_output(decoded, steps_html), "Assistant reply generated.", True
1684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1685
  # ------------------------------------------------------------------
1686
  # General MMU (N Images → Text)
1687
  # ------------------------------------------------------------------
@@ -1694,6 +2223,12 @@ class OmadaDemo:
1694
  block_length: int,
1695
  temperature: float,
1696
  ) -> Tuple[str, str]:
 
 
 
 
 
 
1697
  """
1698
  MMU demo now consumes exactly one image. If callers pass a list (for
1699
  backwards compatibility), we keep only the first valid image.
@@ -1762,49 +2297,78 @@ class OmadaDemo:
1762
  )
1763
 
1764
  def _format_chat_output(self, text: str, steps_html: str = "") -> str:
1765
- """Wrap <think> blocks in a collapsible section for chat UI."""
1766
- safe_text = text or ""
1767
- start_tag = "<think>"
1768
- end_tag = "</think>"
1769
- out = []
1770
- idx = 0
1771
- injected_steps = False
1772
- while True:
1773
- start = safe_text.find(start_tag, idx)
1774
- if start == -1:
1775
- tail = safe_text[idx:]
1776
- if tail:
1777
- out.append(html.escape(tail).replace("\n", "<br>"))
1778
- break
1779
- prefix = safe_text[idx:start]
 
 
 
 
 
 
 
1780
  if prefix:
1781
- out.append(html.escape(prefix).replace("\n", "<br>"))
1782
- end = safe_text.find(end_tag, start + len(start_tag))
1783
- if end == -1:
1784
- out.append(html.escape(safe_text[start:]).replace("\n", "<br>"))
1785
- break
1786
- think_body = safe_text[start + len(start_tag):end].strip()
1787
- think_block = html.escape(think_body).replace("\n", "<br>")
1788
- if steps_html and not injected_steps:
1789
- think_block = f"{think_block}{steps_html}"
1790
- injected_steps = True
1791
- out.append(
1792
- "\n<details><summary>Show think</summary>\n\n"
1793
- f"{think_block}\n"
1794
- "</details>\n"
1795
- )
1796
- idx = end + len(end_tag)
1797
- if steps_html and not injected_steps:
1798
- out.append(
1799
- "\n<details><summary>Show think</summary>\n\n"
1800
- f"{steps_html}\n"
1801
- "</details>\n"
1802
- )
1803
- body = "".join(out).strip()
1804
  if not body:
1805
  return ""
1806
  return f"<div class='omada-response-block'>{body}</div>"
1807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1808
  def _render_diffusion_steps(
1809
  self,
1810
  step_snapshots: List[torch.Tensor],
@@ -2096,6 +2660,10 @@ class OmadaDemo:
2096
  mmu_input_ids = mmu_input_ids.to(self.device)
2097
  prompt_masks = prompt_masks.to(self.device)
2098
 
 
 
 
 
2099
  answer_tokens = int((prompt_masks == 0).sum(dim=1).max().item())
2100
  default_budget = max(1, answer_tokens) if answer_tokens > 0 else min(self.max_text_len, 256)
2101
  gen_tokens = int(max_new_tokens or default_budget)
@@ -2108,14 +2676,7 @@ class OmadaDemo:
2108
  )
2109
  temperature = float(temperature if temperature is not None else 0.7)
2110
 
2111
- if gen_tokens > 0:
2112
- mask_block = torch.full(
2113
- (mmu_input_ids.size(0), gen_tokens),
2114
- self.mask_token_id,
2115
- dtype=torch.long,
2116
- device=self.device,
2117
- )
2118
- mmu_input_ids = torch.cat([mmu_input_ids, mask_block], dim=1)
2119
 
2120
  with torch.no_grad():
2121
  output_ids = self.model.mmu_generate(
@@ -2128,14 +2689,57 @@ class OmadaDemo:
2128
  mask_id=self.mask_token_id,
2129
  )
2130
 
2131
- decoded = self.uni_prompting.text_tokenizer.batch_decode(
2132
- output_ids[:, mmu_input_ids.shape[1]:],
2133
- skip_special_tokens=True,
2134
- )[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
2135
  if not decoded:
2136
  return "", "MMU response was empty."
2137
  return decoded, "Image understanding succeeded."
2138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2139
  def _generate_text_tokens(
2140
  self,
2141
  prompt_ids: torch.Tensor,
@@ -2213,6 +2817,13 @@ class OmadaDemo:
2213
 
2214
  transfer_index = torch.zeros_like(work, dtype=torch.bool)
2215
  for b in range(batch_size):
 
 
 
 
 
 
 
2216
  k = int(num_transfer_tokens[b, inner_step].item())
2217
  if k <= 0:
2218
  continue
@@ -2223,6 +2834,15 @@ class OmadaDemo:
2223
  if return_steps and batch_size > 0:
2224
  step_snapshots.append(work[0, prompt_len:].detach().cpu())
2225
 
 
 
 
 
 
 
 
 
 
2226
  if return_steps:
2227
  return work, step_snapshots
2228
  return work
@@ -2303,6 +2923,13 @@ class OmadaDemo:
2303
 
2304
  transfer_index = torch.zeros_like(work, dtype=torch.bool)
2305
  for b in range(batch_size):
 
 
 
 
 
 
 
2306
  k = int(num_transfer_tokens[b, inner_step].item())
2307
  if k <= 0:
2308
  continue
@@ -2312,6 +2939,14 @@ class OmadaDemo:
2312
  work[transfer_index] = x0[transfer_index]
2313
  yield work.clone(), prompt_len
2314
 
 
 
 
 
 
 
 
 
2315
  def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]):
2316
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
2317
  with gr.Blocks(title="AIDAS Lab @ SNU", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS) as demo:
@@ -2750,15 +3385,23 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
2750
  response = ""
2751
 
2752
  if mode == "Text":
2753
- reply, status = app.run_chat(
 
 
 
 
 
2754
  message,
2755
  chat_max_tokens,
2756
  chat_steps,
2757
  chat_block,
2758
  chat_temperature,
2759
- )
2760
- response = _render_text_message(status, reply)
2761
- display_user_raw = message or "[Text request]"
 
 
 
2762
  elif mode == "Text → Speech":
2763
  if not message:
2764
  status = "Please type some text for speech synthesis."
 
25
  import os
26
  import math
27
  import random
28
+ import re
29
  import sys
30
  import tempfile
31
  import wave
 
208
  box-shadow: none;
209
  border: 1px solid var(--omada-border);
210
  background: #ffffff;
211
+ overflow-y: auto !important;
212
+ }
213
+ .omada-chat-column .gradio-chatbot .wrap,
214
+ .omada-chat-column .gradio-chatbot .message-wrap {
215
+ overflow-y: auto !important;
216
+ }
217
+ .omada-chat-column .gradio-chatbot .message {
218
+ overflow-wrap: anywhere;
219
  }
220
  .omada-controls {
221
  gap: 16px !important;
 
733
  import gradio as gr
734
  import numpy as np
735
  import torch
736
+ from omegaconf import DictConfig, ListConfig, OmegaConf
737
  from PIL import Image
738
 
739
  from inference.common import (
 
754
 
755
  if cfg is None:
756
  return default
757
+ if isinstance(cfg, (list, tuple, ListConfig)):
758
+ for item in cfg:
759
+ value = _cfg_get(item, key, None)
760
+ if value is not None:
761
+ return value
762
+ return default
763
  if isinstance(cfg, dict):
764
  return cfg.get(key, default)
765
  try:
 
912
  )
913
  )
914
  self.max_text_len = int(getattr(self.train_cfg.dataset.preprocessing, "max_seq_length", 1024))
915
+ self.max_seq_mmu = int(
916
+ getattr(
917
+ self.train_cfg.dataset.preprocessing,
918
+ "max_seq_length_mmu",
919
+ self.max_text_len,
920
+ )
921
+ )
922
+ self.chat_mask_surface_token = "<mdm_mask>"
923
 
924
  model_seq_len = getattr(self.model.config, "num_vq_tokens", None)
925
  if model_seq_len is None:
 
936
  self.noise_type = _cfg_get(training_cfg, "noise_type", "mask")
937
  self.predict_all_tokens = bool(_cfg_get(training_cfg, "predict_all_tokens", False))
938
  self.t2i_default_timesteps = int(_cfg_get(training_cfg, "generation_timesteps", 20))
939
+ # Align i2i defaults with eval (use generation_timesteps unless explicitly set).
940
+ self.i2i_default_timesteps = int(_cfg_get(training_cfg, "generation_timesteps", 20))
941
+
942
+ # Force demo to use eval-matched defaults unless explicitly disabled.
943
+ self.force_eval_settings = str(os.getenv("FORCE_EVAL_SETTINGS", "1")).lower() not in {"0", "false", "no"}
944
+ self.eval_defaults = {
945
+ "t2i": {
946
+ "timesteps": 16,
947
+ "guidance_scale": 2.5,
948
+ "temperature": 0.0,
949
+ },
950
+ "i2i": {
951
+ "timesteps": 64,
952
+ "guidance_scale": 2.5,
953
+ "temperature": 0.0,
954
+ },
955
+ # Match defaults used in inference scripts for eval parity.
956
+ "t2s": {
957
+ "steps": 128,
958
+ "block_length": 128,
959
+ "max_new_tokens": int(self.max_audio_len_short),
960
+ "temperature": 0.0,
961
+ "guidance_scale": float(_cfg_get(training_cfg, "guidance_scale", 3.5)),
962
+ },
963
+ "s2t": {
964
+ "steps": 128,
965
+ "block_length": 16,
966
+ "max_new_tokens": 128,
967
+ "remasking": "low_confidence",
968
+ },
969
+ "v2t": {
970
+ "steps": 256,
971
+ "block_length": 16,
972
+ "max_new_tokens": 256,
973
+ },
974
+ # LLM eval uses gen_length=steps=block_length=16
975
+ "chat": {
976
+ "steps": 512,
977
+ "block_length": 16,
978
+ "max_new_tokens": 512,
979
+ "temperature": 0.0,
980
+ },
981
+ "mmu": {
982
+ "steps": 128,
983
+ "block_length": 16,
984
+ "max_new_tokens": 128,
985
+ "temperature": 0.0,
986
+ },
987
+ }
988
 
989
  self.audio_condition_default = "gender-female_emotion-neutral_speed-normal_pitch-normal"
990
  style_map = getattr(getattr(self.vq_audio, "config", None), "u2s_style2idx", None)
 
1013
  speed_choice: str,
1014
  pitch_choice: str,
1015
  ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
1016
+ if self.force_eval_settings:
1017
+ d = self.eval_defaults["t2s"]
1018
+ max_new_tokens = int(d["max_new_tokens"])
1019
+ steps = int(d["steps"])
1020
+ block_length = int(d["block_length"])
1021
+ temperature = float(d["temperature"])
1022
+ cfg_scale = float(d["guidance_scale"])
1023
 
1024
  if text is None or not text.strip():
1025
  return None, "Please provide text to synthesize."
 
1091
  status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
1092
  return audio, status
1093
 
1094
+ def run_t2s_stream(
1095
+ self,
1096
+ text: str,
1097
+ max_new_tokens: int,
1098
+ steps: int,
1099
+ block_length: int,
1100
+ temperature: float,
1101
+ cfg_scale: float,
1102
+ gender_choice: str,
1103
+ emotion_choice: str,
1104
+ speed_choice: str,
1105
+ pitch_choice: str,
1106
+ update_every: Optional[int] = None,
1107
+ ):
1108
+ if self.force_eval_settings:
1109
+ d = self.eval_defaults["t2s"]
1110
+ max_new_tokens = int(d["max_new_tokens"])
1111
+ steps = int(d["steps"])
1112
+ block_length = int(d["block_length"])
1113
+ temperature = float(d["temperature"])
1114
+ cfg_scale = float(d["guidance_scale"])
1115
+
1116
+ if text is None or not text.strip():
1117
+ yield None, "Please provide text to synthesize."
1118
+ return
1119
+
1120
+ speech_len, steps, block_length = self._prepare_block_schedule(
1121
+ max_new_tokens,
1122
+ steps,
1123
+ block_length,
1124
+ )
1125
+
1126
+ gender = self._resolve_choice(gender_choice, self.genders)
1127
+ emotion = self._resolve_choice(emotion_choice, self.emotions)
1128
+ speed = self._resolve_choice(speed_choice, self.speeds)
1129
+ pitch = self._resolve_choice(pitch_choice, self.pitches)
1130
+
1131
+ text = text.strip().upper()
1132
+ prompt = (
1133
+ "<|start_header_id|>user<|end_header_id|>\n"
1134
+ f"{random.choice(T2S_INSTRUCTION)}\n{text}"
1135
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
1136
+ )
1137
+
1138
+ audio_tokens = torch.full(
1139
+ (1, speech_len),
1140
+ fill_value=self.mask_token_id,
1141
+ dtype=torch.long,
1142
+ device=self.device,
1143
+ )
1144
+
1145
+ input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen")
1146
+ input_ids = input_ids.to(self.device)
1147
+ attention_mask = attention_mask.to(self.device)
1148
+
1149
+ condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}"
1150
+
1151
+ last_audio = None
1152
+ accumulated = None
1153
+ prev_len = 0
1154
+ for rel_list, step_status in self.model.t2s_generate_mmu_like_stream(
1155
+ input_ids=input_ids,
1156
+ max_new_tokens=int(speech_len),
1157
+ steps=int(steps),
1158
+ block_length=int(block_length),
1159
+ temperature=float(temperature),
1160
+ cfg_scale=float(cfg_scale),
1161
+ mask_token_id=self.mask_token_id,
1162
+ attention_mask=attention_mask,
1163
+ uni_prompting=self.uni_prompting,
1164
+ codebook_size=self.codebook_size,
1165
+ update_every=update_every,
1166
+ ):
1167
+ if not rel_list:
1168
+ continue
1169
+ rel = rel_list[0]
1170
+ if isinstance(rel, torch.Tensor):
1171
+ rel_ids = rel.detach().cpu().tolist()
1172
+ else:
1173
+ rel_ids = list(rel)
1174
+ if not rel_ids:
1175
+ continue
1176
+ if prev_len >= len(rel_ids):
1177
+ continue
1178
+ new_ids = rel_ids[prev_len:]
1179
+ prev_len = len(rel_ids)
1180
+ speech_units = "".join(f"<|speech_{sid}|>" for sid in new_ids)
1181
+ wav = self.vq_audio.decode(
1182
+ speech_units,
1183
+ condition=condition,
1184
+ output_wav_file=os.path.join("/tmp", "omada_t2s_stream.wav"),
1185
+ )
1186
+ chunk = wav.astype(np.float32)
1187
+ if accumulated is None:
1188
+ accumulated = chunk
1189
+ else:
1190
+ accumulated = np.concatenate([accumulated, chunk], axis=0)
1191
+ audio = (self.sample_rate, accumulated)
1192
+ last_audio = audio
1193
+ yield audio, f"{step_status} ({gender}/{emotion}/{speed}/{pitch})"
1194
+
1195
+ if last_audio is not None:
1196
+ yield last_audio, f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
1197
+
1198
  # ------------------------------------------------------------------
1199
  # Speech-to-Speech
1200
  # ------------------------------------------------------------------
 
1316
  max_new_tokens: int,
1317
  remasking: str,
1318
  ) -> Tuple[str, str]:
1319
+ if self.force_eval_settings:
1320
+ d = self.eval_defaults["s2t"]
1321
+ steps = int(d["steps"])
1322
+ block_length = int(d["block_length"])
1323
+ max_new_tokens = int(d["max_new_tokens"])
1324
+ remasking = str(d["remasking"])
1325
 
1326
  if not audio_path:
1327
  return "", "Please upload an audio file first."
1328
 
1329
+ remasking = str(remasking).lower()
1330
+ if remasking == "full":
1331
+ remasking = "low_confidence"
1332
+
1333
  tokens = self.vq_audio.encode(audio_path).to(self.device)
1334
  offset = self.text_vocab_size + self.speech_codebook
1335
  tokens = tokens + offset
 
1367
  remasking=str(remasking),
1368
  )
1369
 
1370
+ decoded = self._decode_chat_tokens(
1371
+ output_ids[0, input_ids.shape[1]:],
1372
+ self.uni_prompting.text_tokenizer,
1373
+ ).strip()
1374
+ decoded = self._postprocess_chat_text(decoded)
1375
+ decoded = self._strip_trailing_masks(decoded)
1376
+ decoded = self._remove_mask_artifacts(decoded)
1377
  return decoded.strip(), "Transcription generated successfully."
1378
 
1379
+ def run_s2t_stream(
1380
+ self,
1381
+ audio_path: Optional[str],
1382
+ steps: int,
1383
+ block_length: int,
1384
+ max_new_tokens: int,
1385
+ remasking: str,
1386
+ update_every: int = 32,
1387
+ ):
1388
+ if self.force_eval_settings:
1389
+ d = self.eval_defaults["s2t"]
1390
+ steps = int(d["steps"])
1391
+ block_length = int(d["block_length"])
1392
+ max_new_tokens = int(d["max_new_tokens"])
1393
+ remasking = str(d["remasking"])
1394
+
1395
+ if not audio_path:
1396
+ yield "", "Please upload an audio file first."
1397
+ return
1398
+
1399
+ remasking = str(remasking).lower()
1400
+ if remasking == "full":
1401
+ remasking = "low_confidence"
1402
+
1403
+ tokens = self.vq_audio.encode(audio_path).to(self.device)
1404
+ offset = self.text_vocab_size + self.speech_codebook
1405
+ tokens = tokens + offset
1406
+
1407
+ spt = self.uni_prompting.sptids_dict
1408
+ audio_block = torch.cat(
1409
+ [
1410
+ spt['<|s2t|>'].to(self.device).unsqueeze(0),
1411
+ spt['<|soa|>'].to(self.device).unsqueeze(0),
1412
+ tokens.to(self.device),
1413
+ spt['<|eoa|>'].to(self.device).unsqueeze(0),
1414
+ ],
1415
+ dim=1,
1416
+ )
1417
+
1418
+ prompt_text = random.choice(S2T_INSTRUCTION)
1419
+ chat_prompt = (
1420
+ "<|start_header_id|>user<|end_header_id|>\n"
1421
+ f"{prompt_text}"
1422
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
1423
+ )
1424
+ prompt_tensor = self.uni_prompting.text_tokenizer(
1425
+ chat_prompt,
1426
+ return_tensors="pt",
1427
+ ).input_ids.to(self.device)
1428
+
1429
+ input_ids = torch.cat([audio_block, prompt_tensor], dim=1)
1430
+
1431
+ step_counter = 0
1432
+ latest_decoded = ""
1433
+ for snapshot, prompt_len in self._generate_text_tokens_stream(
1434
+ input_ids,
1435
+ max_new_tokens=int(max_new_tokens),
1436
+ steps=int(steps),
1437
+ block_length=int(block_length),
1438
+ temperature=1.0,
1439
+ cfg_scale=0.0,
1440
+ attention_mask=None,
1441
+ remasking=remasking,
1442
+ ):
1443
+ step_counter += 1
1444
+ if update_every > 1 and step_counter % update_every != 0:
1445
+ continue
1446
+ decoded = self._decode_chat_tokens(
1447
+ snapshot[0, prompt_len:],
1448
+ self.uni_prompting.text_tokenizer,
1449
+ ).strip()
1450
+ decoded = self._postprocess_chat_text(decoded)
1451
+ latest_decoded = decoded
1452
+ yield decoded, "Generating..."
1453
+
1454
+ finalized = self._remove_mask_artifacts(self._strip_trailing_masks(latest_decoded))
1455
+ yield finalized.strip(), "Transcription generated successfully."
1456
+
1457
  # ------------------------------------------------------------------
1458
  # Video-to-Text
1459
  # ------------------------------------------------------------------
 
1464
  block_length: int,
1465
  max_new_tokens: int,
1466
  ) -> Tuple[str, str]:
1467
+ if self.force_eval_settings:
1468
+ d = self.eval_defaults["v2t"]
1469
+ steps = int(d["steps"])
1470
+ block_length = int(d["block_length"])
1471
+ max_new_tokens = int(d["max_new_tokens"])
1472
 
1473
  resolved_path, converted = self._prepare_video_path(video_path)
1474
  if not resolved_path:
 
1518
  raw_all = self.uni_prompting.text_tokenizer.decode(output_ids[0], skip_special_tokens=False)
1519
  print("[V2T] RAW ALL:", repr(raw_all))
1520
 
1521
+ decoded = self._decode_chat_tokens(
1522
+ output_ids[0, input_ids.shape[1]:],
1523
+ self.uni_prompting.text_tokenizer,
1524
+ )
1525
  print("[V2T] DECODED SLICE:", repr(decoded))
1526
+ decoded = self._postprocess_chat_text(decoded)
1527
+ decoded = self._strip_trailing_masks(decoded)
1528
+ decoded = self._remove_mask_artifacts(decoded)
1529
  return decoded.strip(), "Video caption generated successfully."
1530
 
1531
+ def run_v2t_stream(
1532
+ self,
1533
+ video_path: Any,
1534
+ steps: int,
1535
+ block_length: int,
1536
+ max_new_tokens: int,
1537
+ update_every: int = 32,
1538
+ ):
1539
+ if self.force_eval_settings:
1540
+ d = self.eval_defaults["v2t"]
1541
+ steps = int(d["steps"])
1542
+ block_length = int(d["block_length"])
1543
+ max_new_tokens = int(d["max_new_tokens"])
1544
+
1545
+ resolved_path, converted = self._prepare_video_path(video_path)
1546
+ if not resolved_path:
1547
+ yield "", "Please upload or record a video first."
1548
+ return
1549
+
1550
+ try:
1551
+ video_tokens = self._extract_video_tokens(resolved_path, num_frame=self.num_frames_v2t)
1552
+ except Exception as exc:
1553
+ yield "", f"Failed to process video: {exc}"
1554
+ return
1555
+
1556
+ prompt_text = random.choice(V2T_INSTRUCTION)
1557
+ prompt = (
1558
+ "<|start_header_id|>user<|end_header_id|>\n"
1559
+ f"{prompt_text}"
1560
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
1561
+ )
1562
+ prompt_tensor = self.uni_prompting.text_tokenizer(
1563
+ prompt,
1564
+ return_tensors="pt",
1565
+ ).input_ids.to(self.device)
1566
+
1567
+ input_ids = torch.cat([video_tokens, prompt_tensor], dim=1)
1568
+
1569
+ step_counter = 0
1570
+ latest_decoded = ""
1571
+ for snapshot, prompt_len in self._generate_text_tokens_stream(
1572
+ input_ids,
1573
+ max_new_tokens=int(max_new_tokens),
1574
+ steps=int(steps),
1575
+ block_length=int(block_length),
1576
+ temperature=1.0,
1577
+ cfg_scale=0.0,
1578
+ attention_mask=None,
1579
+ remasking="low_confidence",
1580
+ ):
1581
+ step_counter += 1
1582
+ if update_every > 1 and step_counter % update_every != 0:
1583
+ continue
1584
+ decoded = self._decode_chat_tokens(
1585
+ snapshot[0, prompt_len:],
1586
+ self.uni_prompting.text_tokenizer,
1587
+ ).strip()
1588
+ decoded = self._postprocess_chat_text(decoded)
1589
+ latest_decoded = decoded
1590
+ yield decoded, "Generating..."
1591
+
1592
+ finalized = self._remove_mask_artifacts(self._strip_trailing_masks(latest_decoded))
1593
+ yield finalized.strip(), "Video caption generated successfully."
1594
+
1595
  # ------------------------------------------------------------------
1596
  # Text-to-Image
1597
  # ------------------------------------------------------------------
 
1602
  temperature: float,
1603
  guidance_scale: float,
1604
  ) -> Tuple[Optional[Image.Image], str]:
1605
+ if self.force_eval_settings:
1606
+ d = self.eval_defaults["t2i"]
1607
+ timesteps = int(d["timesteps"])
1608
+ temperature = float(d["temperature"])
1609
+ guidance_scale = float(d["guidance_scale"])
1610
  if not prompt or not prompt.strip():
1611
  return None, "Please provide a text prompt."
1612
 
 
1655
  image = self._decode_image_tokens(gen_tokens[0])
1656
  return image, "Image generated from text prompt."
1657
 
1658
+ def run_t2i_stream(
1659
+ self,
1660
+ prompt: str,
1661
+ timesteps: int,
1662
+ temperature: float,
1663
+ guidance_scale: float,
1664
+ update_every: int = 2,
1665
+ ):
1666
+ if self.force_eval_settings:
1667
+ d = self.eval_defaults["t2i"]
1668
+ timesteps = int(d["timesteps"])
1669
+ temperature = float(d["temperature"])
1670
+ guidance_scale = float(d["guidance_scale"])
1671
+ if not prompt or not prompt.strip():
1672
+ yield None, "Please provide a text prompt."
1673
+ return
1674
+
1675
+ image_seq_len = 1024
1676
+ image_tokens = torch.full(
1677
+ (1, image_seq_len),
1678
+ self.mask_token_id,
1679
+ dtype=torch.long,
1680
+ device=self.device,
1681
+ )
1682
+ input_ids, attention_mask = self.uni_prompting(([prompt.strip()], image_tokens), "t2i_gen")
1683
+ input_ids = input_ids.to(self.device)
1684
+ attention_mask = attention_mask.to(self.device)
1685
+
1686
+ if guidance_scale > 0:
1687
+ uncond_ids, uncond_mask = self.uni_prompting(([""], image_tokens.clone()), "t2i_gen")
1688
+ uncond_ids = uncond_ids.to(self.device)
1689
+ uncond_mask = uncond_mask.to(self.device)
1690
+ else:
1691
+ uncond_ids = None
1692
+ uncond_mask = None
1693
+
1694
+ step_count = 0
1695
+ for pil_image, status in self.model.t2i_generate_decoding_stepwise(
1696
+ input_ids=input_ids,
1697
+ uncond_input_ids=uncond_ids,
1698
+ attention_mask=attention_mask,
1699
+ uncond_attention_mask=uncond_mask,
1700
+ guidance_scale=float(guidance_scale),
1701
+ temperature=float(temperature),
1702
+ timesteps=int(timesteps),
1703
+ noise_schedule=self.mask_schedule,
1704
+ noise_type=self.noise_type,
1705
+ predict_all_tokens=self.predict_all_tokens,
1706
+ seq_len=image_seq_len,
1707
+ mask_token_id=self.mask_token_id,
1708
+ codebook_size=self.codebook_size,
1709
+ uni_prompting=self.uni_prompting,
1710
+ config=self.train_cfg,
1711
+ vq_model=self.vq_image,
1712
+ ):
1713
+ step_count += 1
1714
+ if update_every <= 1 or step_count % update_every == 0 or step_count == int(timesteps):
1715
+ yield pil_image, status
1716
+
1717
  # ------------------------------------------------------------------
1718
  # Image-to-Image Editing
1719
  # ------------------------------------------------------------------
 
1725
  temperature: float,
1726
  guidance_scale: float,
1727
  ) -> Tuple[Optional[Image.Image], str]:
1728
+ if self.force_eval_settings:
1729
+ d = self.eval_defaults["i2i"]
1730
+ timesteps = int(d["timesteps"])
1731
+ temperature = float(d["temperature"])
1732
+ guidance_scale = float(d["guidance_scale"])
1733
  if source_image is None:
1734
  return None, "Please upload a reference image."
1735
  if not instruction or not instruction.strip():
 
1790
  image = self._decode_image_tokens(gen_tokens[0])
1791
  return image, "Edited image generated."
1792
 
1793
    def run_i2i_stream(
        self,
        instruction: str,
        source_image: Optional[Image.Image],
        timesteps: int,
        temperature: float,
        guidance_scale: float,
        update_every: int = 2,
    ):
        """Stream image-to-image editing, yielding (image, status) per decode step.

        Yields ``(PIL.Image | None, str)`` tuples: intermediate decoded images
        while diffusion runs, or ``(None, message)`` on input/validation errors.

        Args:
            instruction: Editing instruction text; must be non-empty.
            source_image: Reference image to edit; must not be None.
            timesteps: Number of diffusion decoding steps.
            temperature: Sampling temperature forwarded to the decoder.
            guidance_scale: CFG scale; > 0 enables an unconditional branch.
            update_every: Yield only every N-th step (plus the final step)
                to limit UI refreshes.
        """
        # When locked to evaluation settings, the caller-provided knobs are
        # overridden by the curated defaults for the i2i task.
        if self.force_eval_settings:
            d = self.eval_defaults["i2i"]
            timesteps = int(d["timesteps"])
            temperature = float(d["temperature"])
            guidance_scale = float(d["guidance_scale"])
        if source_image is None:
            yield None, "Please upload a reference image."
            return
        if not instruction or not instruction.strip():
            yield None, "Provide editing instructions for the image."
            return

        try:
            input_tokens = self._prepare_image_tokens(source_image, resolution=self.image_resolution)
        except Exception as exc:
            yield None, f"Failed to encode input image: {exc}"
            return

        # The output region starts fully masked; the model fills it in
        # progressively over `timesteps` denoising steps.
        seq_len = int(input_tokens.shape[-1])
        output_placeholder = torch.full(
            (1, seq_len),
            self.mask_token_id,
            dtype=torch.long,
            device=self.device,
        )

        input_ids, attention_mask = self.uni_prompting(
            ([instruction.strip()], input_tokens, output_placeholder),
            "i2i_gen",
        )
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Unconditional branch (empty instruction, same source tokens) is only
        # built when classifier-free guidance is actually requested.
        uncond_ids = None
        uncond_attn = None
        if guidance_scale > 0:
            uncond_ids, uncond_attn = self.uni_prompting(
                ([""], input_tokens.clone(), torch.full_like(output_placeholder, self.mask_token_id)),
                "i2i_gen",
            )
            uncond_ids = uncond_ids.to(self.device)
            uncond_attn = uncond_attn.to(self.device)

        step_count = 0
        for pil_image, status in self.model.i2i_generate_decoding_stepwise(
            input_ids=input_ids,
            uncond_input_ids=uncond_ids,
            attention_mask=attention_mask,
            uncond_attention_mask=uncond_attn,
            temperature=float(temperature),
            timesteps=int(timesteps),
            guidance_scale=float(guidance_scale),
            noise_schedule=self.mask_schedule,
            noise_type=self.noise_type,
            seq_len=seq_len,
            mask_token_id=self.mask_token_id,
            codebook_size=self.codebook_size,
            uni_prompting=self.uni_prompting,
            config=self.train_cfg,
            vq_model=self.vq_image,
        ):
            step_count += 1
            # Throttle UI updates: emit every `update_every` steps and always
            # emit the final step so the finished image is shown.
            if update_every <= 1 or step_count % update_every == 0 or step_count == int(timesteps):
                yield pil_image, status
1866
+
1867
  # ------------------------------------------------------------------
1868
  # Video-to-Speech
1869
  # ------------------------------------------------------------------
 
2036
  block_length: int,
2037
  temperature: float,
2038
  ) -> Tuple[str, str]:
2039
+ if self.force_eval_settings:
2040
+ d = self.eval_defaults["chat"]
2041
+ max_new_tokens = int(d["max_new_tokens"])
2042
+ steps = int(d["steps"])
2043
+ block_length = int(d["block_length"])
2044
+ temperature = float(d["temperature"])
2045
  content = (message or "").strip()
2046
  if not content:
2047
  return "", "Type a message to start chatting."
 
2084
  else:
2085
  output_ids, step_snapshots = output_result, []
2086
 
2087
+ decoded = self._decode_chat_tokens(output_ids[0, input_ids.shape[1]:], tokenizer)
2088
+ decoded = self._postprocess_chat_text(decoded)
2089
+ decoded = self._strip_trailing_masks(decoded)
2090
+ decoded = self._remove_mask_artifacts(decoded)
2091
  return decoded.strip(), "Assistant reply generated."
2092
 
2093
  def run_chat_stream(
 
2101
  max_tokens_per_step: int = 0,
2102
  update_every: int = 25,
2103
  ):
2104
+ if self.force_eval_settings:
2105
+ d = self.eval_defaults["chat"]
2106
+ max_new_tokens = int(d["max_new_tokens"])
2107
+ steps = int(d["steps"])
2108
+ block_length = int(d["block_length"])
2109
+ temperature = float(d["temperature"])
2110
  content = (message or "").strip()
2111
  if not content:
2112
  yield "", "Type a message to start chatting.", True
 
2153
  if len(step_snapshots) > max_step_snapshots:
2154
  step_snapshots = step_snapshots[-max_step_snapshots:]
2155
  step_counter += 1
2156
+ raw_decoded = self._decode_chat_tokens(snapshot[0, prompt_len:], tokenizer)
2157
+ print(f"[CHAT_STREAM][step={step_counter}] raw_decoded={raw_decoded!r}", flush=True)
2158
  if update_every > 1 and step_counter % update_every != 0:
2159
  continue
2160
+ decoded = raw_decoded.strip()
2161
+ decoded = self._postprocess_chat_text(decoded)
 
 
2162
  steps_html = self._render_diffusion_steps(
2163
  step_snapshots,
2164
  max_tokens_per_step=max_tokens_per_step,
 
2169
  yield "", "Assistant reply generated.", True
2170
  return
2171
 
2172
+ decoded = self._decode_chat_tokens(latest_ids[0, input_ids.shape[1]:], tokenizer).strip()
2173
+ print(f"[CHAT_STREAM][final] raw_decoded={decoded!r}", flush=True)
2174
+ decoded = self._postprocess_chat_text(decoded)
2175
+ decoded = self._strip_trailing_masks(decoded)
2176
+ decoded = self._remove_mask_artifacts(decoded)
2177
  step_snapshots = [latest_ids[0, input_ids.shape[1]:].detach().cpu()]
2178
  steps_html = self._render_diffusion_steps(
2179
  step_snapshots,
 
2181
  )
2182
  yield self._format_chat_output(decoded, steps_html), "Assistant reply generated.", True
2183
 
2184
+ def _decode_chat_tokens(self, token_ids: torch.Tensor, tokenizer) -> str:
2185
+ """Decode chat tokens while preserving mask placeholders for UI."""
2186
+ ids = token_ids.detach().cpu().tolist()
2187
+ pieces = []
2188
+ run_ids = []
2189
+
2190
+ def _flush_run():
2191
+ nonlocal run_ids
2192
+ if not run_ids:
2193
+ return
2194
+ try:
2195
+ decoded_run = tokenizer.decode(
2196
+ run_ids,
2197
+ skip_special_tokens=False,
2198
+ clean_up_tokenization_spaces=False,
2199
+ )
2200
+ except Exception:
2201
+ decoded_run = ""
2202
+ pieces.append(decoded_run if decoded_run is not None else "")
2203
+ run_ids = []
2204
+
2205
+ for tid in ids:
2206
+ if int(tid) == int(self.mask_token_id):
2207
+ _flush_run()
2208
+ pieces.append(self.chat_mask_surface_token)
2209
+ else:
2210
+ run_ids.append(int(tid))
2211
+ _flush_run()
2212
+ return "".join(pieces)
2213
+
2214
  # ------------------------------------------------------------------
2215
  # General MMU (N Images → Text)
2216
  # ------------------------------------------------------------------
 
2223
  block_length: int,
2224
  temperature: float,
2225
  ) -> Tuple[str, str]:
2226
+ if self.force_eval_settings:
2227
+ d = self.eval_defaults["mmu"]
2228
+ max_new_tokens = int(d["max_new_tokens"])
2229
+ steps = int(d["steps"])
2230
+ block_length = int(d["block_length"])
2231
+ temperature = float(d["temperature"])
2232
  """
2233
  MMU demo now consumes exactly one image. If callers pass a list (for
2234
  backwards compatibility), we keep only the first valid image.
 
2297
  )
2298
 
2299
  def _format_chat_output(self, text: str, steps_html: str = "") -> str:
2300
+ """Render chat text inline; only mask tokens are shown as pills."""
2301
+ safe_text = (text or "").strip()
2302
+ if not safe_text:
2303
+ return ""
2304
+
2305
+ def _fmt_tokens(segment: str) -> str:
2306
+ mask_pat = r"(<MDM_MASK>|<\|?MDM_MASK[^>\s]*\|?>|\[MASK\]|<MASK>|MASK_TOKEN|<\|?MASK[^>\s]*\|?>)"
2307
+ pieces = re.split(mask_pat, segment, flags=re.IGNORECASE)
2308
+ out = []
2309
+ for p in pieces:
2310
+ if not p:
2311
+ continue
2312
+ if re.fullmatch(mask_pat, p, flags=re.IGNORECASE):
2313
+ out.append("<span class='omada-token omada-token-mask'>MASK</span>")
2314
+ else:
2315
+ out.append(html.escape(p).replace("\n", "<br>"))
2316
+ return "".join(out)
2317
+
2318
+ parts = []
2319
+ cursor = 0
2320
+ for m in re.finditer(r"<think>(.*?)</think>", safe_text, flags=re.DOTALL | re.IGNORECASE):
2321
+ prefix = safe_text[cursor:m.start()]
2322
  if prefix:
2323
+ parts.append(_fmt_tokens(prefix))
2324
+ think_body = m.group(1) or ""
2325
+ parts.append(f"<div class='omada-response-block'><b>Think:</b><br>{_fmt_tokens(think_body)}</div>")
2326
+ cursor = m.end()
2327
+ tail = safe_text[cursor:]
2328
+ if tail:
2329
+ parts.append(_fmt_tokens(tail))
2330
+ body = "".join(parts).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2331
  if not body:
2332
  return ""
2333
  return f"<div class='omada-response-block'>{body}</div>"
2334
 
2335
+ def _postprocess_chat_text(self, text: str) -> str:
2336
+ """Remove special/system tokens while keeping think content."""
2337
+ if not text:
2338
+ return ""
2339
+ cleaned = text
2340
+ # Normalize common malformed boundaries seen in streamed decode.
2341
+ cleaned = cleaned.replace("</thinkboxed", "</think>boxed")
2342
+ cleaned = cleaned.replace("<thinkboxed", "<think>boxed")
2343
+ # Keep think tags/content; only strip protocol-level special tokens.
2344
+ # Strip special tokens like <|eot_id|>, <|start_header_id|>, etc.
2345
+ cleaned = re.sub(r"<\|[^>]*\|>", "", cleaned)
2346
+ # Also remove truncated special tokens without the trailing ">".
2347
+ cleaned = re.sub(r"<\|[^\n]*\|", "", cleaned)
2348
+ cleaned = cleaned.replace("<|endoftext|>", "")
2349
+ cleaned = cleaned.replace("<|endoftext|", "")
2350
+ return cleaned.strip()
2351
+
2352
+ def _strip_trailing_masks(self, text: str) -> str:
2353
+ if not text:
2354
+ return ""
2355
+ mask_tail = (
2356
+ r"(?:\s*(?:\(|\[)?(?:<MDM_MASK>|<\|?MDM_MASK[^>\s]*\|?>|\[MASK\]|<MASK>|MASK_TOKEN)"
2357
+ r"(?:\)|\])?)+\s*$"
2358
+ )
2359
+ return re.sub(mask_tail, "", text, flags=re.IGNORECASE).rstrip()
2360
+
2361
+ def _remove_mask_artifacts(self, text: str) -> str:
2362
+ if not text:
2363
+ return ""
2364
+ mask_pat = r"(<MDM_MASK>|<\|?MDM_MASK[^>\s]*\|?>|\[MASK\]|<MASK>|MASK_TOKEN|<\|?MASK[^>\s]*\|?>)"
2365
+ cleaned = re.sub(mask_pat, " ", text, flags=re.IGNORECASE)
2366
+ # Some tokenizers may emit literal MASK text instead of the special token.
2367
+ cleaned = re.sub(r"MASK", " ", cleaned)
2368
+ cleaned = re.sub(r"\s+([,.;:!?])", r"\1", cleaned)
2369
+ cleaned = re.sub(r"\s{2,}", " ", cleaned).strip()
2370
+ return cleaned
2371
+
2372
  def _render_diffusion_steps(
2373
  self,
2374
  step_snapshots: List[torch.Tensor],
 
2660
  mmu_input_ids = mmu_input_ids.to(self.device)
2661
  prompt_masks = prompt_masks.to(self.device)
2662
 
2663
+ prompt_len = int(prompt_masks.sum(dim=1).max().item())
2664
+ if prompt_len > 0:
2665
+ mmu_input_ids = mmu_input_ids[:, :prompt_len]
2666
+
2667
  answer_tokens = int((prompt_masks == 0).sum(dim=1).max().item())
2668
  default_budget = max(1, answer_tokens) if answer_tokens > 0 else min(self.max_text_len, 256)
2669
  gen_tokens = int(max_new_tokens or default_budget)
 
2676
  )
2677
  temperature = float(temperature if temperature is not None else 0.7)
2678
 
2679
+ input_prompt_len = mmu_input_ids.shape[1]
 
 
 
 
 
 
 
2680
 
2681
  with torch.no_grad():
2682
  output_ids = self.model.mmu_generate(
 
2689
  mask_id=self.mask_token_id,
2690
  )
2691
 
2692
+ gen_slice = output_ids[0, input_prompt_len:]
2693
+ if gen_slice.numel() == 0:
2694
+ # Some checkpoints may return only generated ids (without prepended prompt).
2695
+ gen_slice = output_ids[0]
2696
+ decoded = self._decode_chat_tokens(
2697
+ gen_slice,
2698
+ self.uni_prompting.text_tokenizer,
2699
+ ).strip()
2700
+ print(
2701
+ f"[MMU] input_prompt_len={input_prompt_len} output_len={int(output_ids.shape[1])} "
2702
+ f"gen_len={int(gen_slice.numel())} first_ids={gen_slice[:16].detach().cpu().tolist()}",
2703
+ flush=True,
2704
+ )
2705
+ print(f"[MMU] raw_decoded={decoded!r}", flush=True)
2706
+ decoded = self._postprocess_chat_text(decoded)
2707
+ decoded = self._strip_trailing_masks(decoded)
2708
+ decoded = self._remove_mask_artifacts(decoded)
2709
  if not decoded:
2710
  return "", "MMU response was empty."
2711
  return decoded, "Image understanding succeeded."
2712
 
2713
    def _finalize_generation_masks(
        self,
        work: torch.Tensor,
        prompt_len: int,
        attention_bias: Optional[torch.Tensor] = None,
        cfg_scale: float = 0.0,
    ) -> torch.Tensor:
        """Force-fill any residual masks after scheduled diffusion steps.

        Runs one extra forward pass and replaces every remaining mask token
        (anywhere in `work`) with the argmax prediction. Returns `work`
        unchanged if it is empty or no masks remain past the prompt.

        Args:
            work: Token sequence buffer, shape (batch, seq_len).
            prompt_len: Length of the prompt prefix; only the generated
                region is checked for leftover masks.
            attention_bias: Optional precomputed attention bias; used only
                on the non-CFG path.
            cfg_scale: Classifier-free guidance scale; > 0 enables a
                doubled-batch conditional/unconditional pass.
        """
        if work.numel() == 0:
            return work
        # Fast path: nothing left to fill beyond the prompt.
        if not (work[:, prompt_len:] == self.mask_token_id).any():
            return work

        with torch.no_grad():
            if cfg_scale > 0.0:
                # Build the unconditional branch by masking every currently
                # known (non-mask) token, then blend cond/uncond logits.
                # NOTE(review): attention_bias is not applied on this path —
                # presumably intentional for the doubled batch; confirm.
                prompt_index = work != self.mask_token_id
                unconditional = work.clone()
                unconditional[prompt_index] = self.mask_token_id
                model_input = torch.cat([work, unconditional], dim=0)
                logits = self.model(model_input).logits
                cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
            else:
                logits = self.model(work, attention_bias=attention_bias).logits

            # Greedy fill: take argmax everywhere, but only write it into
            # positions that are still masked.
            greedy = torch.argmax(logits, dim=-1)
            mask_idx = work == self.mask_token_id
            work = torch.where(mask_idx, greedy, work)
        return work
2742
+
2743
  def _generate_text_tokens(
2744
  self,
2745
  prompt_ids: torch.Tensor,
 
2817
 
2818
  transfer_index = torch.zeros_like(work, dtype=torch.bool)
2819
  for b in range(batch_size):
2820
+ block_mask_now = torch.where(work[b, block_slice] == self.mask_token_id)[0]
2821
+ if inner_step == inner_steps - 1:
2822
+ # Guarantee: no masks remain in this block after its last step.
2823
+ if block_mask_now.numel() > 0:
2824
+ transfer_index[b, prompt_len + block_idx * block_length + block_mask_now] = True
2825
+ continue
2826
+
2827
  k = int(num_transfer_tokens[b, inner_step].item())
2828
  if k <= 0:
2829
  continue
 
2834
  if return_steps and batch_size > 0:
2835
  step_snapshots.append(work[0, prompt_len:].detach().cpu())
2836
 
2837
+ work = self._finalize_generation_masks(
2838
+ work,
2839
+ prompt_len=prompt_len,
2840
+ attention_bias=attention_bias,
2841
+ cfg_scale=cfg_scale,
2842
+ )
2843
+ if return_steps and batch_size > 0:
2844
+ step_snapshots.append(work[0, prompt_len:].detach().cpu())
2845
+
2846
  if return_steps:
2847
  return work, step_snapshots
2848
  return work
 
2923
 
2924
  transfer_index = torch.zeros_like(work, dtype=torch.bool)
2925
  for b in range(batch_size):
2926
+ block_mask_now = torch.where(work[b, block_slice] == self.mask_token_id)[0]
2927
+ if inner_step == inner_steps - 1:
2928
+ # Guarantee: no masks remain in this block after its last step.
2929
+ if block_mask_now.numel() > 0:
2930
+ transfer_index[b, prompt_len + block_idx * block_length + block_mask_now] = True
2931
+ continue
2932
+
2933
  k = int(num_transfer_tokens[b, inner_step].item())
2934
  if k <= 0:
2935
  continue
 
2939
  work[transfer_index] = x0[transfer_index]
2940
  yield work.clone(), prompt_len
2941
 
2942
+ work = self._finalize_generation_masks(
2943
+ work,
2944
+ prompt_len=prompt_len,
2945
+ attention_bias=attention_bias,
2946
+ cfg_scale=cfg_scale,
2947
+ )
2948
+ yield work.clone(), prompt_len
2949
+
2950
  def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]):
2951
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
2952
  with gr.Blocks(title="AIDAS Lab @ SNU", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS) as demo:
 
3385
  response = ""
3386
 
3387
  if mode == "Text":
3388
+ display_user_raw = message or "[Text request]"
3389
+ display_user = _format_user_message(display_user_raw)
3390
+ history = history + [(display_user, _render_text_message("Generating...", ""))]
3391
+ yield history, ""
3392
+
3393
+ for reply_html, status, done in app.run_chat_stream(
3394
  message,
3395
  chat_max_tokens,
3396
  chat_steps,
3397
  chat_block,
3398
  chat_temperature,
3399
+ update_every=32,
3400
+ ):
3401
+ response = _render_text_message(status, reply_html)
3402
+ history[-1] = (display_user, response)
3403
+ yield history, ""
3404
+ return
3405
  elif mode == "Text → Speech":
3406
  if not message:
3407
  status = "Please type some text for speech synthesis."
MMaDA/models/modeling_omada.py CHANGED
@@ -597,6 +597,196 @@ class OMadaModelLM(LLaDAModelLM):
597
 
598
  return final_outputs
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  @torch.no_grad()
601
  def t2s_fixed_generate(
602
  self,
@@ -2116,8 +2306,8 @@ class OMadaModelLM(LLaDAModelLM):
2116
  uncond_input_ids = torch.cat(
2117
  [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
2118
  model_input = torch.cat([input_ids, uncond_input_ids])
2119
- attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
2120
- attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
2121
  logits = self(model_input, attention_bias=attention_bias).logits
2122
  # print(f"logits.shape: {logits.shape}")
2123
  cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
@@ -2178,6 +2368,101 @@ class OMadaModelLM(LLaDAModelLM):
2178
 
2179
 
2180
  return sampled_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2181
 
2182
 
2183
  AutoConfig.register("omada", OMadaConfig)
 
597
 
598
  return final_outputs
599
 
600
    @torch.no_grad()
    def t2s_generate_mmu_like_stream(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: Optional[int] = None,
        steps: int = 256,
        block_length: int = 128,
        temperature: float = 0.0,
        cfg_scale: float = 0.0,
        mask_token_id: int = 126336,
        attention_mask: Optional[torch.LongTensor] = None,
        uni_prompting=None,
        codebook_size: Optional[int] = None,
        audio_codebook_size: int = 4096,
        update_every: Optional[int] = None,
    ):
        """
        Stream speech token generation. Yields intermediate token lists.

        Block-wise mask-diffusion decoding over the masked speech region of
        `input_ids`. Each yield is a tuple ``(per-batch lists of relative VQ
        token ids, "Step i/N" status string)``. Relative ids are offsets into
        the audio codebook; `audio_codebook_size` and `audio_codebook_size+1`
        stand in for <|eoa|> and EOS respectively, and output is truncated at
        the first <|eoa|>.

        Raises:
            ValueError: if `uni_prompting` is missing, `block_length` <= 0,
                no mask tokens are present, or batch items have differing
                numbers of masked speech tokens.
        """
        if uni_prompting is None:
            raise ValueError("uni_prompting must be provided")
        if block_length <= 0:
            raise ValueError("block_length must be positive")

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # The speech region is defined by the masked positions of batch item 0;
        # all items must have identical mask counts (checked below).
        mask_positions_full = (input_ids == mask_token_id)
        if not mask_positions_full.any():
            raise ValueError("No mask tokens detected for T2S generation")

        mask_cols = torch.where(mask_positions_full[0])[0]
        speech_region_start = mask_cols[0].item()
        speech_region_len = mask_cols.numel()

        mask_counts = mask_positions_full.sum(dim=1)
        if not torch.all(mask_counts == mask_counts[0]):
            raise ValueError("All batch items must contain the same number of masked speech tokens for MMU-like generation")

        if max_new_tokens is None:
            max_new_tokens = speech_region_len
        else:
            max_new_tokens = min(max_new_tokens, speech_region_len)

        # Split the region into blocks; `steps` is divided evenly across them.
        block_length = max(1, min(block_length, max_new_tokens))
        num_blocks = math.ceil(max_new_tokens / block_length)
        inner_steps = max(1, steps // num_blocks)

        # Speech VQ ids live after the text vocab and the image codebook.
        codebook_base = codebook_size if codebook_size is not None else getattr(self.config, "codebook_size", 8192)
        speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_base
        speech_vocab_end = speech_vocab_start + audio_codebook_size

        eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item()
        eos_token_id = uni_prompting.text_tokenizer.eos_token_id
        # Relative ids for end-of-audio / end-of-sequence appended after the
        # audio codebook range in the restricted sampling vocabulary.
        vq_code_relative_eoa_id = audio_codebook_size
        vq_code_relative_eos_id = audio_codebook_size + 1

        work = input_ids.clone()

        attention_bias = None
        if attention_mask is not None:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)

        speech_indices = mask_cols[:max_new_tokens]
        total_steps = num_blocks * inner_steps
        global_step = 0

        def _extract_relative_tokens(work_tensor: torch.Tensor):
            # Convert the absolute-vocab speech region back to relative VQ ids,
            # truncating at the first <|eoa|>/EOS and dropping unfilled masks.
            audio_slice = slice(speech_region_start, speech_region_start + speech_region_len)
            audio_region = work_tensor[:, audio_slice]
            final_outputs = []
            for seq in audio_region:
                mask_tensor = seq.new_full(seq.shape, mask_token_id)
                rel_eoa = seq.new_full(seq.shape, vq_code_relative_eoa_id)
                rel_eos = seq.new_full(seq.shape, vq_code_relative_eos_id)
                relative = torch.where(
                    seq == mask_token_id,
                    mask_tensor,
                    torch.where(
                        seq == eoa_token_id,
                        rel_eoa,
                        torch.where(
                            seq == eos_token_id,
                            rel_eos,
                            seq - speech_vocab_start
                        )
                    )
                )
                eoa_positions = (relative >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0]
                if eoa_positions.numel() > 0:
                    relative = relative[:eoa_positions[0]]
                final_outputs.append(relative[relative != mask_token_id])
            return final_outputs

        for block_idx in range(num_blocks):
            block_start = block_idx * block_length
            block_end = min(block_start + block_length, max_new_tokens)
            curr_indices = speech_indices[block_start:block_end]
            if curr_indices.numel() == 0:
                continue

            block_mask = mask_positions_full[:, curr_indices]
            num_transfer_tokens = get_num_transfer_tokens(block_mask, inner_steps)

            for inner_step in range(inner_steps):
                if cfg_scale > 0.0:
                    # Classifier-free guidance: unconditional branch masks the
                    # whole speech region, then logits are blended.
                    un_cond = work.clone()
                    un_cond[:, speech_indices] = mask_token_id
                    stacked = torch.cat([work, un_cond], dim=0)
                    if attention_bias is not None:
                        att_bias = torch.cat([attention_bias, attention_bias], dim=0)
                    else:
                        att_bias = None
                    logits = self(stacked, attention_bias=att_bias).logits
                    cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                    logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
                else:
                    logits = self(work, attention_bias=attention_bias).logits

                # Restrict sampling to speech VQ ids plus <|eoa|> and EOS.
                logits_block = logits.index_select(1, curr_indices.to(device))
                logits_vq = logits_block[:, :, speech_vocab_start:speech_vocab_end]
                logits_eoa = logits_block[:, :, eoa_token_id:eoa_token_id + 1]
                logits_eos = logits_block[:, :, eos_token_id:eos_token_id + 1]

                combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1)
                if temperature > 0.0:
                    combined_logits = combined_logits / max(temperature, 1e-5)
                probs = F.softmax(combined_logits, dim=-1)

                sampled = torch.multinomial(
                    probs.view(-1, probs.size(-1)), 1
                ).view(batch_size, curr_indices.numel())

                selected_probs = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

                # Map relative sample ids back into the absolute vocabulary.
                eos_tensor = sampled.new_full(sampled.shape, eos_token_id)
                eoa_tensor = sampled.new_full(sampled.shape, eoa_token_id)
                sampled_absolute = torch.where(
                    sampled == vq_code_relative_eos_id,
                    eos_tensor,
                    torch.where(
                        sampled == vq_code_relative_eoa_id,
                        eoa_tensor,
                        sampled + speech_vocab_start
                    )
                )

                current_block_vals = work.index_select(1, curr_indices)
                mask_current = current_block_vals == mask_token_id

                # Confidence drives which positions are committed this step;
                # already-filled positions can never be re-selected (-inf).
                confidence = torch.where(
                    mask_current,
                    selected_probs,
                    torch.full_like(selected_probs, float('-inf'))
                )

                finalize = torch.zeros_like(mask_current, dtype=torch.bool)
                for b in range(batch_size):
                    available = mask_current[b].sum().item()
                    if available == 0:
                        continue
                    transfer = min(int(num_transfer_tokens[b, inner_step].item()), available)
                    if transfer <= 0:
                        continue
                    _, idxs = torch.topk(confidence[b], k=transfer, largest=True)
                    finalize[b, idxs] = True

                # Commit the top-confidence samples; everything else stays masked.
                mask_fill = sampled_absolute.new_full(sampled_absolute.shape, mask_token_id)
                updates = torch.where(finalize, sampled_absolute, mask_fill)
                new_block = torch.where(mask_current, updates, current_block_vals)

                work[:, curr_indices] = new_block
                mask_positions_full[:, curr_indices] = new_block == mask_token_id

                global_step += 1
                # Yield either on an explicit cadence (`update_every`) or at
                # each block's last inner step / the overall last step.
                should_yield = False
                if update_every is not None and update_every > 0:
                    if global_step % update_every == 0 or global_step == total_steps:
                        should_yield = True
                else:
                    if inner_step == inner_steps - 1 or global_step == total_steps:
                        should_yield = True
                if should_yield:
                    yield _extract_relative_tokens(work), f"Step {global_step}/{total_steps}"

            # Early exit once the current block is fully decoded.
            if not mask_positions_full[:, curr_indices].any():
                break

        return
789
+
790
  @torch.no_grad()
791
  def t2s_fixed_generate(
792
  self,
 
2306
  uncond_input_ids = torch.cat(
2307
  [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
2308
  model_input = torch.cat([input_ids, uncond_input_ids])
2309
+ all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
2310
+ attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
2311
  logits = self(model_input, attention_bias=attention_bias).logits
2312
  # print(f"logits.shape: {logits.shape}")
2313
  cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
 
2368
 
2369
 
2370
  return sampled_ids
2371
+
2372
    @torch.no_grad()
    def i2i_generate_decoding_stepwise(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        uncond_attention_mask=None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        config=None,
        seq_len=1024,
        mask_token_id=126336,
        resolution=512,
        codebook_size=8192,
        vq_model=None,
        **kwargs,
    ):
        """
        Stepwise i2i decoding that yields intermediate images per step.

        MaskGit-style iterative decoding over the last `seq_len` VQ token
        positions of `input_ids`. After each step the current token estimate
        is decoded through `vq_model` and yielded as a PIL image together with
        a "Step i/N" status string. Requires `uni_prompting` in **kwargs.

        Raises:
            ValueError: if `vq_model` is None.
        """
        if vq_model is None:
            raise ValueError("vq_model is required for stepwise decoding.")

        # NOTE(review): mask_count is computed but never used.
        mask_count = (input_ids == mask_token_id).sum().item()
        num_vq_tokens = seq_len
        num_new_special_tokens = 0
        uni_prompting = kwargs.get("uni_prompting", None)
        # Shift image tokens from the absolute vocabulary into [0, codebook);
        # masked positions keep the mask id as a sentinel.
        input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(
            input_ids_minus_lm_vocab_size == mask_token_id,
            mask_token_id,
            input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens,
        )

        if uncond_input_ids is not None:
            # The unconditional prompt prefix stays fixed; the VQ suffix is
            # refreshed from the conditional sequence each step.
            uncond_prefix = uncond_input_ids[:, :resolution + 1]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                # Classifier-free guidance: run cond and uncond in one batch.
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
                attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
                logits = self(model_input, attention_bias=attention_bias).logits
                cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
                # Keep only the image-token positions and the image codebook
                # slice of the vocabulary.
                logits = logits[:, -(num_vq_tokens + 1):-1,
                                len(uni_prompting.text_tokenizer) + num_new_special_tokens:
                                len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
            else:
                attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
                logits = self(input_ids, attention_bias=attention_bias).logits
                logits = logits[:, -(num_vq_tokens + 1):-1,
                                len(uni_prompting.text_tokenizer) + num_new_special_tokens:
                                len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]

            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Only still-masked positions take new samples; known tokens stay.
            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

            # Decode the current estimate to a preview image (clamp guards
            # the mask sentinel, which lies outside the codebook range).
            current_image_vq_indices = torch.clamp(sampled_ids.clone(), 0, codebook_size - 1)
            current_image = vq_model.decode_code(current_image_vq_indices)
            images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
            images *= 255.0
            images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            pil_images = Image.fromarray(images[0])
            yield pil_images, f"Step {step + 1}/{timesteps}"

            # Re-mask the least confident tokens per the noise schedule.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1)
            # Known tokens get max confidence so they are never re-masked.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )
            # NOTE(review): reassigning `temperature` compounds the decay across
            # steps (each step multiplies the already-decayed value) — confirm
            # against the reference sampler whether this is intended.
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
            input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(
                masking,
                mask_token_id,
                sampled_ids + len(uni_prompting.text_tokenizer) + num_new_special_tokens,
            )
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids
2466
 
2467
 
2468
  AutoConfig.register("omada", OMadaConfig)
MMaDA/models/speech_tokenization/condition_style_centroid ADDED
@@ -0,0 +1 @@
 
 
1
+ /dataset/omada/AIDAS-Omni-Modal-Diffusion/_style_cache
app.py CHANGED
@@ -11,8 +11,18 @@ import os
11
  import sys
12
  import subprocess
13
  import importlib
 
 
 
 
 
 
 
 
14
  from pathlib import Path
15
  from typing import List
 
 
16
 
17
  import gradio as gr
18
  import spaces
@@ -23,6 +33,9 @@ from packaging.version import parse as parse_version
23
  # ---------------------------
24
 
25
  PROJECT_ROOT = Path(__file__).resolve().parent
 
 
 
26
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
27
  if str(MMADA_ROOT) not in sys.path:
28
  sys.path.insert(0, str(MMADA_ROOT))
@@ -135,7 +148,7 @@ def download_checkpoint() -> Path:
135
  raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}")
136
  return override_path
137
 
138
- repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
139
  revision = os.getenv("MODEL_REVISION", "main")
140
  token = os.getenv("HF_TOKEN")
141
  cache_dir = PROJECT_ROOT / "_ckpt_cache"
@@ -226,6 +239,323 @@ CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
226
  T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
227
  I2I_EXAMPLES = _load_i2i_examples()
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # audio / video / image examples
230
  S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
231
  V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
@@ -277,15 +607,27 @@ def get_app() -> OmadaDemo:
277
 
278
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
279
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
 
280
  train_config = os.getenv("TRAIN_CONFIG_PATH")
281
  if not train_config:
282
- train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
 
 
 
283
 
284
  device = os.getenv("DEVICE", "cuda")
285
  APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
286
  return APP
287
 
288
 
 
 
 
 
 
 
 
 
289
  # ---------------------------
290
  # ZeroGPU-wrapped handlers
291
  # ---------------------------
@@ -310,37 +652,40 @@ def t2s_handler(text, max_tokens, steps, block_len, temperature, cfg_scale, gend
310
  @spaces.GPU
311
  def s2t_handler(audio_path, steps, block_len, max_tokens, remasking):
312
  app = get_app()
313
- text, status = app.run_s2t(
314
  audio_path=audio_path,
315
  steps=int(steps),
316
  block_length=int(block_len),
317
  max_new_tokens=int(max_tokens),
318
  remasking=str(remasking),
319
- )
320
- return text, status
 
321
 
322
  @spaces.GPU
323
  def v2t_handler(video, steps, block_len, max_tokens):
324
  app = get_app()
325
- text, status = app.run_v2t(
326
  video_path=video,
327
  steps=int(steps),
328
  block_length=int(block_len),
329
  max_new_tokens=int(max_tokens),
330
- )
331
- return text, status
 
332
 
333
  @spaces.GPU
334
  def chat_handler(message, max_tokens, steps, block_len, temperature):
335
  app = get_app()
336
- text, status = app.run_chat(
337
  message=message,
338
  max_new_tokens=int(max_tokens),
339
  steps=int(steps),
340
  block_length=int(block_len),
341
  temperature=float(temperature),
342
- )
343
- return text, status
 
344
 
345
  @spaces.GPU
346
  def mmu_handler(image, question, max_tokens, steps, block_len, temperature):
@@ -358,25 +703,27 @@ def mmu_handler(image, question, max_tokens, steps, block_len, temperature):
358
  @spaces.GPU
359
  def t2i_handler(prompt, timesteps, temperature, guidance):
360
  app = get_app()
361
- image, status = app.run_t2i(
362
  prompt=prompt,
363
  timesteps=int(timesteps),
364
  temperature=float(temperature),
365
  guidance_scale=float(guidance),
366
- )
367
- return image, status
 
368
 
369
  @spaces.GPU
370
  def i2i_handler(instruction, image, timesteps, temperature, guidance):
371
  app = get_app()
372
- image_out, status = app.run_i2i(
373
  instruction=instruction,
374
  source_image=image,
375
  timesteps=int(timesteps),
376
  temperature=float(temperature),
377
  guidance_scale=float(guidance),
378
- )
379
- return image_out, status
 
380
 
381
 
382
  # ---------------------------
@@ -384,255 +731,898 @@ def i2i_handler(instruction, image, timesteps, temperature, guidance):
384
  # ---------------------------
385
 
386
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  with gr.Blocks(
389
  title="AIDAS Lab @ SNU - Omni-modal Diffusion",
390
- css=CUSTOM_CSS,
391
  theme=theme,
392
  js=FORCE_LIGHT_MODE_JS,
393
  ) as demo:
394
- with gr.Row():
395
- if LOGO_PATH.exists():
396
- gr.Image(
397
- value=str(LOGO_PATH),
398
- show_label=False,
399
- height=80,
400
- interactive=False,
401
- )
402
- gr.Markdown(
403
- "## Omni-modal Diffusion Foundation Model\n"
404
- "### AIDAS Lab @ SNU"
405
- )
 
406
 
407
- # ---- T2S ----
408
- with gr.Tab("Text → Speech (T2S)"):
409
- with gr.Row():
410
- t2s_text = gr.Textbox(
411
- label="Input text",
412
- lines=4,
413
- placeholder="Type the speech you want to synthesize...",
414
- )
415
- t2s_audio = gr.Audio(label="Generated speech", type="numpy")
416
- t2s_status = gr.Textbox(label="Status", interactive=False)
417
- with gr.Accordion("Advanced settings", open=False):
418
- t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length")
419
- t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps")
420
- t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
421
- t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
422
- t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale")
423
- with gr.Row():
424
- t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender")
425
- t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion")
426
- with gr.Row():
427
- t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
428
- t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
429
- if T2S_EXAMPLES:
430
- with gr.Accordion("Sample prompts", open=False):
431
- gr.Examples(
432
- examples=T2S_EXAMPLES,
433
- inputs=[t2s_text],
434
- examples_per_page=6,
435
- )
436
- t2s_btn = gr.Button("Generate speech", variant="primary")
437
- t2s_btn.click(
438
- t2s_handler,
439
- inputs=[
440
- t2s_text,
441
- t2s_max_tokens,
442
- t2s_steps,
443
- t2s_block,
444
- t2s_temperature,
445
- t2s_cfg,
446
- t2s_gender,
447
- t2s_emotion,
448
- t2s_speed,
449
- t2s_pitch,
450
- ],
451
- outputs=[t2s_audio, t2s_status],
452
  )
453
 
454
- # ---- S2T ----
455
- with gr.Tab("Speech Text (S2T)"):
456
- s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
457
- s2t_text_out = gr.Textbox(label="Transcription", lines=4)
458
- s2t_status = gr.Textbox(label="Status", interactive=False)
459
- with gr.Accordion("Advanced settings", open=False):
460
- s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps")
461
- s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
462
- s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens")
463
- s2t_remasking = gr.Dropdown(
464
- ["low_confidence", "random"],
465
- value="low_confidence",
466
- label="Remasking strategy",
467
- )
468
- if S2T_EXAMPLES:
469
- with gr.Accordion("Sample clips", open=False):
470
- gr.Examples(
471
- examples=S2T_EXAMPLES,
472
- inputs=[s2t_audio_in],
473
- examples_per_page=4,
474
- )
475
- s2t_btn = gr.Button("Transcribe", variant="primary")
476
- s2t_btn.click(
477
- s2t_handler,
478
- inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking],
479
- outputs=[s2t_text_out, s2t_status],
480
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
- # ---- V2T ----
483
- with gr.Tab("Video Text (V2T)"):
484
- v2t_video_in = gr.Video(
485
- label="Upload or record video",
486
- height=256,
487
- sources=["upload", "webcam"],
488
- )
489
- v2t_text_out = gr.Textbox(label="Caption / answer", lines=4)
490
- v2t_status = gr.Textbox(label="Status", interactive=False)
491
- with gr.Accordion("Advanced settings", open=False):
492
- v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
493
- v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
494
- v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
495
- if V2T_EXAMPLES:
496
- with gr.Accordion("Sample videos", open=False):
497
- gr.Examples(
498
- examples=V2T_EXAMPLES,
499
- inputs=[v2t_video_in],
500
- examples_per_page=4,
501
- )
502
- v2t_btn = gr.Button("Generate caption", variant="primary")
503
- v2t_btn.click(
504
- v2t_handler,
505
- inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens],
506
- outputs=[v2t_text_out, v2t_status],
507
  )
508
 
509
- # ---- T2I ----
510
- with gr.Tab("Text Image (T2I)"):
511
- t2i_prompt = gr.Textbox(
512
- label="Prompt",
513
- lines=4,
514
- placeholder="Describe the image you want to generate...",
515
- )
516
- t2i_image_out = gr.Image(label="Generated image")
517
- t2i_status = gr.Textbox(label="Status", interactive=False)
518
- with gr.Accordion("Advanced settings", open=False):
519
- t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
520
- t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
521
- t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
522
- if T2I_EXAMPLES:
523
- with gr.Accordion("Sample prompts", open=False):
524
- gr.Examples(
525
- examples=T2I_EXAMPLES,
526
- inputs=[t2i_prompt],
527
- examples_per_page=6,
528
- )
529
- t2i_btn = gr.Button("Generate image", variant="primary")
530
- t2i_btn.click(
531
- t2i_handler,
532
- inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
533
- outputs=[t2i_image_out, t2i_status],
534
- )
535
 
536
- # ---- I2I ----
537
- with gr.Tab("Image Editing (I2I)"):
538
- i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
539
- i2i_instr = gr.Textbox(
540
- label="Editing instruction",
541
- lines=4,
542
- placeholder="Describe how you want to edit the image...",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  )
544
- i2i_image_out = gr.Image(label="Edited image")
545
- i2i_status = gr.Textbox(label="Status", interactive=False)
546
- with gr.Accordion("Advanced settings", open=False):
547
- i2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
548
- i2i_temperature = gr.Slider(0.0, 2.0, value=0.3, step=0.05, label="Sampling temperature")
549
- i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
550
-
551
- if I2I_EXAMPLES:
552
- with gr.Accordion("Sample edits", open=False):
553
- gr.Examples(
554
- examples=I2I_EXAMPLES,
555
- inputs=[i2i_image_in, i2i_instr],
556
- examples_per_page=4,
557
- )
558
- i2i_btn = gr.Button("Apply edit", variant="primary")
559
- i2i_btn.click(
560
- i2i_handler,
561
- inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
562
- outputs=[i2i_image_out, i2i_status],
563
  )
564
 
565
- # ---- Chat ----
566
- with gr.Tab("Text Chat"):
567
- chat_in = gr.Textbox(
568
- label="Message",
569
- lines=4,
570
- placeholder="Ask anything. The model will reply in text.",
 
 
571
  )
572
- chat_out = gr.Textbox(label="Assistant reply", lines=6)
573
- chat_status = gr.Textbox(label="Status", interactive=False)
574
- with gr.Accordion("Advanced settings", open=False):
575
- chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
576
- chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
577
- chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
578
- chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
579
- if CHAT_EXAMPLES:
580
- with gr.Accordion("Sample prompts", open=False):
581
- gr.Examples(
582
- examples=CHAT_EXAMPLES,
583
- inputs=[chat_in],
584
- examples_per_page=6,
585
- )
586
- chat_btn = gr.Button("Send", variant="primary")
587
- chat_btn.click(
588
- chat_handler,
589
- inputs=[
590
- chat_in,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  chat_max_tokens,
592
  chat_steps,
593
  chat_block,
594
- chat_temperature_slider,
595
- ],
596
- outputs=[chat_out, chat_status],
597
- )
598
-
599
- # ---- MMU ----
600
- with gr.Tab("MMU (Image → Text)"):
601
- mmu_img = gr.Image(type="pil", label="Input image", sources=["upload"])
602
- mmu_question = gr.Textbox(
603
- label="Question",
604
- lines=3,
605
- placeholder="Ask about the scene, objects, or context of the image.",
606
- )
607
- mmu_answer = gr.Textbox(label="Answer", lines=6)
608
- mmu_status = gr.Textbox(label="Status", interactive=False)
609
- with gr.Accordion("Advanced settings", open=False):
610
- mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens")
611
- mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
612
- mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
613
- mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
614
- if MMU_EXAMPLES:
615
- with gr.Accordion("Sample MMU prompts", open=False):
616
- gr.Examples(
617
- examples=MMU_EXAMPLES,
618
- inputs=[mmu_img, mmu_question],
619
- examples_per_page=1,
620
- )
621
- mmu_btn = gr.Button("Answer about the image", variant="primary")
622
- mmu_btn.click(
623
- mmu_handler,
624
- inputs=[
625
- mmu_img,
626
- mmu_question,
627
- mmu_max_tokens,
628
- mmu_steps,
629
- mmu_block,
630
- mmu_temperature,
631
- ],
632
- outputs=[mmu_answer, mmu_status],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  )
634
 
635
 
636
 
637
  if __name__ == "__main__":
638
- demo.launch()
 
11
  import sys
12
  import subprocess
13
  import importlib
14
+ import base64
15
+ import html
16
+ import io
17
+ import re
18
+ import wave
19
+ import tempfile
20
+ import shutil
21
+ from urllib.parse import quote
22
  from pathlib import Path
23
  from typing import List
24
+ import numpy as np
25
+ from PIL import Image
26
 
27
  import gradio as gr
28
  import spaces
 
33
  # ---------------------------
34
 
35
  PROJECT_ROOT = Path(__file__).resolve().parent
36
+ os.environ.setdefault("FORCE_EVAL_SETTINGS", "0")
37
+ PREVIEW_DIR = PROJECT_ROOT / "_preview_cache"
38
+ PREVIEW_DIR.mkdir(parents=True, exist_ok=True)
39
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
40
  if str(MMADA_ROOT) not in sys.path:
41
  sys.path.insert(0, str(MMADA_ROOT))
 
148
  raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}")
149
  return override_path
150
 
151
+ repo_id = os.getenv("MODEL_REPO_ID", "snu-aidas/Dynin-Omni")
152
  revision = os.getenv("MODEL_REVISION", "main")
153
  token = os.getenv("HF_TOKEN")
154
  cache_dir = PROJECT_ROOT / "_ckpt_cache"
 
239
  T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
240
  I2I_EXAMPLES = _load_i2i_examples()
241
 
242
+
243
+ def _render_response(status: str, body_html: str = "") -> str:
244
+ safe_status = html.escape(status or "")
245
+ parts = []
246
+ if safe_status:
247
+ parts.append(f"<p class='omada-response-status'>{safe_status}</p>")
248
+ if body_html:
249
+ parts.append(body_html)
250
+ content = "".join(parts)
251
+ return f"<div class='omada-response-container'>{content}</div>"
252
+
253
+
254
+ def _render_text_message(status: str, content: str) -> str:
255
+ content = (content or "").strip()
256
+ if not content:
257
+ return _render_response(status)
258
+ safe_content = _format_tokenized_text(content)
259
+ body = f"<div class='omada-response-block'>{safe_content}</div>"
260
+ return _render_response(status, body)
261
+
262
+
263
+ def _is_mask_like_token(token: str) -> bool:
264
+ t = token.strip()
265
+ if not t:
266
+ return False
267
+ upper = t.upper()
268
+ return (
269
+ upper in {"[MASK]", "<MASK>", "<|MASK|>", "<MASK_TOKEN>", "<|MASK_TOKEN|>"}
270
+ or upper in {"<MDM_MASK>", "MDM_MASK", "<|MDM_MASK|>"}
271
+ or "MASK" in upper
272
+ )
273
+
274
+
275
+ def _is_special_token(token: str) -> bool:
276
+ t = token.strip()
277
+ return bool(t) and t.startswith("<|") and t.endswith("|>")
278
+
279
+
280
+ def _format_tokenized_text(text: str) -> str:
281
+ if not text:
282
+ return ""
283
+ # Handle both complete and partially-streamed mask tokens.
284
+ mask_pat = r"(<[^>\n]*MASK[^>\n]*>?|\[MASK\]|MASK_TOKEN)"
285
+ chunks = re.split(mask_pat, text, flags=re.IGNORECASE)
286
+ out = []
287
+ for chunk in chunks:
288
+ if not chunk:
289
+ continue
290
+ if re.fullmatch(mask_pat, chunk, flags=re.IGNORECASE) or _is_mask_like_token(chunk):
291
+ out.append("<span class='omada-token-pill omada-token-mask'>MASK</span>")
292
+ continue
293
+ if chunk.isspace():
294
+ out.append(chunk.replace("\n", "<br>"))
295
+ continue
296
+ safe = html.escape(chunk)
297
+ if _is_special_token(chunk):
298
+ out.append(f"<span class='omada-token-pill omada-token-special'>{safe}</span>")
299
+ else:
300
+ out.append(safe)
301
+ return "".join(out).replace("\n", "<br>")
302
+
303
+
304
+ def _render_audio_message(status: str, audio):
305
+ if not audio:
306
+ return _render_response(status)
307
+
308
+ sample_rate, data = audio
309
+ if data is None:
310
+ return _render_response(status)
311
+
312
+ waveform = np.asarray(data, dtype=np.float32)
313
+ if waveform.size == 0:
314
+ return _render_response(status)
315
+
316
+ if waveform.ndim == 1:
317
+ waveform = waveform[:, None]
318
+
319
+ channels = waveform.shape[1]
320
+ clipped = np.clip(waveform, -1.0, 1.0)
321
+ pcm16 = (clipped * 32767.0).astype(np.int16)
322
+
323
+ buffer = io.BytesIO()
324
+ with wave.open(buffer, "wb") as wav_writer:
325
+ wav_writer.setnchannels(channels)
326
+ wav_writer.setsampwidth(2)
327
+ wav_writer.setframerate(int(sample_rate))
328
+ wav_writer.writeframes(pcm16.tobytes())
329
+
330
+ encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
331
+ audio_tag = (
332
+ "<div class='omada-audio-block'>"
333
+ "<audio controls preload='auto' playsinline>"
334
+ f"<source src='data:audio/wav;base64,{encoded}' type='audio/wav' /></audio>"
335
+ "</div>"
336
+ )
337
+ return _render_response(status, audio_tag)
338
+
339
+
340
+ def _render_image_message(status: str, image: Image.Image):
341
+ if image is None:
342
+ return _render_response(status)
343
+
344
+ buffer = io.BytesIO()
345
+ try:
346
+ image.save(buffer, format="PNG")
347
+ except Exception:
348
+ return _render_response(status)
349
+
350
+ encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
351
+ image_html = (
352
+ "<div class='omada-response-block'>"
353
+ "<img src='data:image/png;base64,"
354
+ f"{encoded}"
355
+ "' alt='Generated image' style='max-width:100%;border-radius:12px;' />"
356
+ "</div>"
357
+ )
358
+ return _render_response(status, image_html)
359
+
360
+
361
+ def _render_user_message(mode: str, message: str, image_in, audio_in, video_in, defer_video: bool = False) -> str:
362
+ def _cache_media_copy(src_path: str) -> str:
363
+ path = str(src_path or "")
364
+ if not path or not os.path.exists(path):
365
+ return path
366
+ try:
367
+ suffix = Path(path).suffix or ""
368
+ fd, dst = tempfile.mkstemp(prefix="omada_media_", suffix=suffix, dir=str(PREVIEW_DIR))
369
+ os.close(fd)
370
+ shutil.copy2(path, dst)
371
+ return dst
372
+ except Exception:
373
+ return path
374
+
375
+ def _to_browser_mp4(video_path: str) -> str:
376
+ path = str(video_path or "")
377
+ if not path:
378
+ return path
379
+ try:
380
+ fd, out_path = tempfile.mkstemp(prefix="omada_preview_", suffix=".mp4", dir=str(PREVIEW_DIR))
381
+ os.close(fd)
382
+ cmd = [
383
+ "ffmpeg",
384
+ "-y",
385
+ "-i",
386
+ path,
387
+ "-an",
388
+ "-c:v",
389
+ "libx264",
390
+ "-pix_fmt",
391
+ "yuv420p",
392
+ "-movflags",
393
+ "+faststart",
394
+ out_path,
395
+ ]
396
+ proc = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
397
+ if proc.returncode == 0 and os.path.exists(out_path):
398
+ return out_path
399
+ if os.path.exists(out_path):
400
+ os.remove(out_path)
401
+ except Exception:
402
+ pass
403
+ return path
404
+
405
+ def _video_data_uri(video_path: str, mime: str, max_bytes: int = 25 * 1024 * 1024) -> str:
406
+ try:
407
+ size = os.path.getsize(video_path)
408
+ if size <= 0 or size > max_bytes:
409
+ return ""
410
+ with open(video_path, "rb") as f:
411
+ encoded = base64.b64encode(f.read()).decode("ascii")
412
+ return f"data:{mime};base64,{encoded}"
413
+ except Exception:
414
+ return ""
415
+
416
+ def _video_poster_data_uri(video_path: str) -> str:
417
+ try:
418
+ import cv2 # type: ignore
419
+
420
+ cap = cv2.VideoCapture(video_path)
421
+ ok, frame = cap.read()
422
+ cap.release()
423
+ if not ok or frame is None:
424
+ return ""
425
+ ok, buf = cv2.imencode(".jpg", frame)
426
+ if not ok:
427
+ return ""
428
+ encoded = base64.b64encode(buf.tobytes()).decode("ascii")
429
+ return f"data:image/jpeg;base64,{encoded}"
430
+ except Exception:
431
+ return ""
432
+
433
+ parts = []
434
+ text = (message or "").strip()
435
+ if image_in is not None:
436
+ try:
437
+ if isinstance(image_in, Image.Image):
438
+ buffer = io.BytesIO()
439
+ image_in.save(buffer, format="PNG")
440
+ encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
441
+ parts.append(
442
+ "<div class='omada-user-media'>"
443
+ f"<img src='data:image/png;base64,{encoded}' alt='Input image' />"
444
+ "</div>"
445
+ )
446
+ elif isinstance(image_in, str) and image_in:
447
+ try:
448
+ with Image.open(image_in).convert("RGB") as pil_img:
449
+ buf = io.BytesIO()
450
+ pil_img.save(buf, format="PNG")
451
+ encoded = base64.b64encode(buf.getvalue()).decode("ascii")
452
+ parts.append(
453
+ "<div class='omada-user-media'>"
454
+ f"<img src='data:image/png;base64,{encoded}' alt='Input image' />"
455
+ "</div>"
456
+ )
457
+ except Exception:
458
+ image_path = _cache_media_copy(image_in)
459
+ parts.append(
460
+ "<div class='omada-user-media'>"
461
+ f"<img src='/file={quote(image_path)}' alt='Input image' />"
462
+ "</div>"
463
+ )
464
+ except Exception:
465
+ pass
466
+
467
+ if mode == "MMU (Video → Text)" and video_in:
468
+ if defer_video:
469
+ parts.append("<div class='omada-user-media'><div class='omada-video-loading'>Video loading...</div></div>")
470
+ if text:
471
+ parts.append(f"<div>{html.escape(text)}</div>")
472
+ return "".join(parts)
473
+ video_path = None
474
+ if isinstance(video_in, str):
475
+ video_path = video_in
476
+ elif isinstance(video_in, dict):
477
+ video_path = video_in.get("path") or video_in.get("name")
478
+ if video_path:
479
+ cached_original = _cache_media_copy(video_path)
480
+ preview_path = _to_browser_mp4(cached_original)
481
+ poster = _video_poster_data_uri(cached_original)
482
+ poster_attr = f" poster='{poster}'" if poster else ""
483
+ source_path = str(preview_path or cached_original)
484
+ fallback_path = str(cached_original)
485
+ def _video_mime(path: str) -> str:
486
+ ext = os.path.splitext(path.lower())[1]
487
+ return {
488
+ ".mp4": "video/mp4",
489
+ ".webm": "video/webm",
490
+ ".mov": "video/quicktime",
491
+ ".m4v": "video/mp4",
492
+ ".avi": "video/x-msvideo",
493
+ ".mkv": "video/x-matroska",
494
+ }.get(ext, "video/mp4")
495
+ parts.append(
496
+ "<div class='omada-user-media'>"
497
+ f"<video class='omada-user-video' controls playsinline preload='metadata'{poster_attr}>"
498
+ f"<source src='{(_video_data_uri(source_path, _video_mime(source_path)) or f'/file={quote(source_path)}')}' type='{_video_mime(source_path)}' />"
499
+ f"<source src='/file={quote(fallback_path)}' type='{_video_mime(fallback_path)}' />"
500
+ f"<a href='/file={quote(fallback_path)}' target='_blank' rel='noopener'>Open video</a>"
501
+ "</video>"
502
+ "</div>"
503
+ )
504
+
505
+ if audio_in is not None:
506
+ audio_path = ""
507
+ if isinstance(audio_in, str):
508
+ audio_path = audio_in
509
+ elif isinstance(audio_in, dict):
510
+ audio_path = audio_in.get("path") or audio_in.get("name") or ""
511
+ elif isinstance(audio_in, (tuple, list)) and len(audio_in) == 2:
512
+ try:
513
+ sample_rate, data = audio_in
514
+ waveform = np.asarray(data, dtype=np.float32)
515
+ if waveform.ndim == 1:
516
+ waveform = waveform[:, None]
517
+ waveform = np.clip(waveform, -1.0, 1.0)
518
+ pcm16 = (waveform * 32767.0).astype(np.int16)
519
+ fd, temp_audio = tempfile.mkstemp(prefix="omada_user_audio_", suffix=".wav", dir=str(PREVIEW_DIR))
520
+ os.close(fd)
521
+ with wave.open(temp_audio, "wb") as wav_writer:
522
+ wav_writer.setnchannels(pcm16.shape[1])
523
+ wav_writer.setsampwidth(2)
524
+ wav_writer.setframerate(int(sample_rate))
525
+ wav_writer.writeframes(pcm16.tobytes())
526
+ audio_path = temp_audio
527
+ except Exception:
528
+ audio_path = ""
529
+ if audio_path:
530
+ ext = os.path.splitext(audio_path.lower())[1]
531
+ mime = {
532
+ ".wav": "audio/wav",
533
+ ".mp3": "audio/mpeg",
534
+ ".flac": "audio/flac",
535
+ ".ogg": "audio/ogg",
536
+ ".m4a": "audio/mp4",
537
+ }.get(ext, "audio/wav")
538
+ src = ""
539
+ try:
540
+ with open(audio_path, "rb") as f:
541
+ encoded_audio = base64.b64encode(f.read()).decode("ascii")
542
+ src = f"data:{mime};base64,{encoded_audio}"
543
+ except Exception:
544
+ audio_path = _cache_media_copy(audio_path)
545
+ src = f"/file={quote(audio_path)}"
546
+ parts.append(
547
+ "<div class='omada-user-media'>"
548
+ f"<audio controls preload='metadata'><source src='{src}' type='{mime}' /></audio>"
549
+ f"<div><a href='{src}' target='_blank' rel='noopener'>Open audio</a></div>"
550
+ "</div>"
551
+ )
552
+
553
+ if text:
554
+ parts.append(f"<div>{html.escape(text)}</div>")
555
+ if not parts:
556
+ parts.append(f"<div>[{html.escape(mode)}]</div>")
557
+ return "".join(parts)
558
+
559
  # audio / video / image examples
560
  S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
561
  V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
 
607
 
608
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
609
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
610
+ eval_cfg = Path("/dataset/omada/OMaDA/MMaDA/configs/omada_instruction_tuning2.yaml")
611
  train_config = os.getenv("TRAIN_CONFIG_PATH")
612
  if not train_config:
613
+ if eval_cfg.exists():
614
+ train_config = str(eval_cfg)
615
+ else:
616
+ train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
617
 
618
  device = os.getenv("DEVICE", "cuda")
619
  APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
620
  return APP
621
 
622
 
623
+ def warmup_model_status() -> str:
624
+ try:
625
+ get_app()
626
+ return "Model status: Loaded. Inference is ready."
627
+ except Exception as exc:
628
+ return f"Model status: Load failed ({exc})."
629
+
630
+
631
  # ---------------------------
632
  # ZeroGPU-wrapped handlers
633
  # ---------------------------
 
652
  @spaces.GPU
653
  def s2t_handler(audio_path, steps, block_len, max_tokens, remasking):
654
  app = get_app()
655
+ for text, status in app.run_s2t_stream(
656
  audio_path=audio_path,
657
  steps=int(steps),
658
  block_length=int(block_len),
659
  max_new_tokens=int(max_tokens),
660
  remasking=str(remasking),
661
+ update_every=32,
662
+ ):
663
+ yield text, status
664
 
665
  @spaces.GPU
666
  def v2t_handler(video, steps, block_len, max_tokens):
667
  app = get_app()
668
+ for text, status in app.run_v2t_stream(
669
  video_path=video,
670
  steps=int(steps),
671
  block_length=int(block_len),
672
  max_new_tokens=int(max_tokens),
673
+ update_every=32,
674
+ ):
675
+ yield text, status
676
 
677
  @spaces.GPU
678
  def chat_handler(message, max_tokens, steps, block_len, temperature):
679
  app = get_app()
680
+ for reply_html, status, done in app.run_chat_stream(
681
  message=message,
682
  max_new_tokens=int(max_tokens),
683
  steps=int(steps),
684
  block_length=int(block_len),
685
  temperature=float(temperature),
686
+ update_every=32,
687
+ ):
688
+ yield reply_html, status
689
 
690
  @spaces.GPU
691
  def mmu_handler(image, question, max_tokens, steps, block_len, temperature):
 
703
  @spaces.GPU
704
  def t2i_handler(prompt, timesteps, temperature, guidance):
705
  app = get_app()
706
+ for image, status in app.run_t2i_stream(
707
  prompt=prompt,
708
  timesteps=int(timesteps),
709
  temperature=float(temperature),
710
  guidance_scale=float(guidance),
711
+ update_every=2,
712
+ ):
713
+ yield image, status
714
 
715
  @spaces.GPU
716
  def i2i_handler(instruction, image, timesteps, temperature, guidance):
717
  app = get_app()
718
+ for image_out, status in app.run_i2i_stream(
719
  instruction=instruction,
720
  source_image=image,
721
  timesteps=int(timesteps),
722
  temperature=float(temperature),
723
  guidance_scale=float(guidance),
724
+ update_every=2,
725
+ ):
726
+ yield image_out, status
727
 
728
 
729
  # ---------------------------
 
731
  # ---------------------------
732
 
733
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
734
+ EXTRA_CSS = """
735
+ html, body, .gradio-container {
736
+ background: var(--omada-surface) !important;
737
+ color: var(--omada-text-primary) !important;
738
+ }
739
+ .omada-shell {
740
+ min-height: 0;
741
+ display: flex;
742
+ flex-direction: column;
743
+ padding-bottom: 6px;
744
+ }
745
+ .omada-sample-row {
746
+ gap: 10px !important;
747
+ justify-content: center !important;
748
+ margin-bottom: 6px;
749
+ }
750
+ .omada-sample-row .gradio-button {
751
+ max-width: 280px !important;
752
+ }
753
+ .omada-hero {
754
+ text-align: center;
755
+ margin: 40px 0 24px 0;
756
+ }
757
+ .omada-hero h2 {
758
+ font-size: 2.2rem;
759
+ margin: 0;
760
+ color: var(--omada-dark-text);
761
+ }
762
+ .omada-hero p {
763
+ margin: 10px 0 0 0;
764
+ color: var(--omada-dark-muted);
765
+ }
766
+ .omada-input-row {
767
+ gap: 6px !important;
768
+ align-items: center !important;
769
+ display: flex !important;
770
+ flex-direction: row !important;
771
+ justify-content: center !important;
772
+ position: relative !important;
773
+ inset: auto !important;
774
+ top: auto !important;
775
+ right: auto !important;
776
+ bottom: auto !important;
777
+ left: auto !important;
778
+ transform: none !important;
779
+ background: var(--omada-surface-alt);
780
+ padding: 6px 14px;
781
+ border-radius: 999px;
782
+ z-index: 5;
783
+ width: min(980px, calc(100vw - 24px));
784
+ margin: 4px auto 8px;
785
+ box-shadow: 0 8px 24px rgba(0,0,0,0.08);
786
+ box-sizing: border-box;
787
+ }
788
+ .omada-input-row > * {
789
+ min-width: 0 !important;
790
+ margin: 0 !important;
791
+ align-self: center !important;
792
+ background: transparent !important;
793
+ box-shadow: none !important;
794
+ border: none !important;
795
+ }
796
+ .omada-input-row .gradio-textbox textarea {
797
+ background: var(--omada-surface) !important;
798
+ color: var(--omada-text-primary) !important;
799
+ border-radius: 999px !important;
800
+ border: 1px solid var(--omada-border) !important;
801
+ padding: 6px 10px !important;
802
+ min-height: 36px !important;
803
+ }
804
+ .omada-plus-btn button,
805
+ .omada-send-btn button {
806
+ border-radius: 999px !important;
807
+ width: 36px !important;
808
+ min-width: 36px !important;
809
+ height: 36px !important;
810
+ background: var(--omada-surface) !important;
811
+ color: var(--omada-text-primary) !important;
812
+ border: 1px solid var(--omada-border) !important;
813
+ padding: 0 !important;
814
+ font-size: 1.2rem !important;
815
+ line-height: 1 !important;
816
+ }
817
+ .omada-plus-btn,
818
+ .omada-send-btn {
819
+ flex: 0 0 36px !important;
820
+ display: flex !important;
821
+ align-items: center !important;
822
+ justify-content: center !important;
823
+ }
824
+ .omada-auto {
825
+ width: 110px !important;
826
+ flex: 0 0 110px !important;
827
+ display: flex !important;
828
+ align-items: center !important;
829
+ }
830
+ .omada-auto select {
831
+ height: 36px !important;
832
+ min-height: 36px !important;
833
+ font-size: 0.95rem !important;
834
+ padding: 0 12px !important;
835
+ background: var(--omada-surface) !important;
836
+ border: 1px solid var(--omada-border) !important;
837
+ color: var(--omada-text-primary) !important;
838
+ border-radius: 999px !important;
839
+ appearance: none !important;
840
+ -webkit-appearance: none !important;
841
+ -moz-appearance: none !important;
842
+ background-image: none !important;
843
+ }
844
+ .omada-auto svg,
845
+ .omada-auto .wrap > svg,
846
+ .omada-auto .dropdown-arrow {
847
+ display: none !important;
848
+ }
849
+ .omada-plus-btn button,
850
+ .omada-send-btn button {
851
+ flex: 0 0 auto !important;
852
+ }
853
+ .omada-input-row .gradio-textbox {
854
+ width: 100% !important;
855
+ flex: 1 1 auto !important;
856
+ min-width: 0 !important;
857
+ opacity: 1 !important;
858
+ pointer-events: auto !important;
859
+ background: transparent !important;
860
+ border: none !important;
861
+ box-shadow: none !important;
862
+ }
863
+ .omada-input-row .gradio-textbox > div,
864
+ .omada-input-row .gradio-dropdown,
865
+ .omada-input-row .gradio-dropdown > div,
866
+ .omada-plus-btn,
867
+ .omada-send-btn,
868
+ .omada-auto {
869
+ background: transparent !important;
870
+ border: none !important;
871
+ box-shadow: none !important;
872
+ }
873
+ .omada-send-btn {
874
+ margin-left: -2px !important;
875
+ }
876
+ .omada-input-row .gradio-textbox textarea {
877
+ width: 100% !important;
878
+ display: block !important;
879
+ pointer-events: auto !important;
880
+ opacity: 1 !important;
881
+ cursor: text !important;
882
+ }
883
+ .omada-panel-backdrop {
884
+ display: none !important;
885
+ }
886
+ .omada-panel {
887
+ position: relative !important;
888
+ top: auto !important;
889
+ left: auto !important;
890
+ transform: none !important;
891
+ max-height: none !important;
892
+ overflow: visible !important;
893
+ width: min(980px, calc(100vw - 24px));
894
+ margin: 0 auto 14px auto;
895
+ box-shadow: 0 20px 60px rgba(0,0,0,0.12);
896
+ z-index: 9999;
897
+ pointer-events: auto !important;
898
+ isolation: isolate;
899
+ }
900
+ .omada-controls-safe {
901
+ width: min(980px, calc(100vw - 24px));
902
+ margin: 0 auto 6px auto;
903
+ }
904
+ .omada-panel * {
905
+ pointer-events: auto;
906
+ }
907
+ .omada-panel input,
908
+ .omada-panel select,
909
+ .omada-panel textarea,
910
+ .omada-panel button,
911
+ .omada-panel .gradio-slider,
912
+ .omada-panel .gradio-slider * {
913
+ pointer-events: auto !important;
914
+ }
915
+ .omada-panel .gradio-radio,
916
+ .omada-panel .gradio-radio label,
917
+ .omada-panel .gradio-radio input {
918
+ pointer-events: auto !important;
919
+ cursor: pointer !important;
920
+ }
921
+ .omada-panel .gradio-radio {
922
+ position: relative !important;
923
+ z-index: 300 !important;
924
+ }
925
+ .omada-panel .gradio-slider,
926
+ .omada-panel .gradio-slider .wrap,
927
+ .omada-panel .gradio-slider .wrap-inner,
928
+ .omada-panel .gradio-slider input[type="range"],
929
+ .omada-panel .gradio-slider input[type="number"],
930
+ .omada-panel .gradio-dropdown,
931
+ .omada-panel .gradio-dropdown select,
932
+ .omada-panel .gradio-textbox textarea {
933
+ pointer-events: auto !important;
934
+ position: relative !important;
935
+ z-index: 400 !important;
936
+ }
937
+ .omada-panel .gradio-slider input[type="range"] {
938
+ touch-action: pan-x !important;
939
+ }
940
+ .omada-panel .gradio-dropdown,
941
+ .omada-panel .gradio-dropdown .wrap {
942
+ z-index: 1000 !important;
943
+ }
944
+ .gradio-dropdown .options,
945
+ .gradio-dropdown .wrap .options {
946
+ z-index: 2000 !important;
947
+ }
948
+ .gradio-container .input-status,
949
+ .gradio-container .status,
950
+ .gradio-container .status-dot,
951
+ .gradio-container .status-indicator,
952
+ .gradio-container .label-wrap .status,
953
+ .gradio-container .label-wrap .status-dot {
954
+ display: none !important;
955
+ }
956
+ .omada-chatbot {
957
+ background: transparent !important;
958
+ border: none !important;
959
+ }
960
+ .gradio-chatbot .message {
961
+ border-radius: 18px !important;
962
+ }
963
+ .gradio-chatbot .message.user {
964
+ margin-left: auto !important;
965
+ background: #2e3037 !important;
966
+ color: var(--omada-text-primary) !important;
967
+ pointer-events: auto !important;
968
+ }
969
+ .gradio-chatbot .message.bot {
970
+ margin-right: auto !important;
971
+ background: #22242a !important;
972
+ color: var(--omada-text-primary) !important;
973
+ pointer-events: auto !important;
974
+ }
975
+ .gradio-chatbot .message.user *,
976
+ .gradio-chatbot .message.bot * {
977
+ pointer-events: auto !important;
978
+ }
979
+ .omada-panel {
980
+ background: var(--omada-dark-panel);
981
+ border: 1px solid var(--omada-dark-border);
982
+ border-radius: 16px;
983
+ padding: 16px;
984
+ }
985
+ .omada-chip button {
986
+ border-radius: 999px !important;
987
+ background: linear-gradient(160deg, rgba(255,255,255,0.62), rgba(255,255,255,0.36)) !important;
988
+ color: #22324a !important;
989
+ border: 1px solid rgba(255,255,255,0.72) !important;
990
+ font-size: 0.68rem !important;
991
+ line-height: 1.2 !important;
992
+ padding: 6px 10px !important;
993
+ backdrop-filter: blur(14px) saturate(165%);
994
+ -webkit-backdrop-filter: blur(14px) saturate(165%);
995
+ box-shadow: 0 8px 20px rgba(36, 56, 92, 0.16) !important;
996
+ }
997
+ .omada-sample-row .gradio-button,
998
+ .omada-sample-row .gradio-button > div,
999
+ .omada-sample-row .gradio-button > button {
1000
+ background: transparent !important;
1001
+ }
1002
+ .omada-chip button:hover {
1003
+ transform: translateY(-1px);
1004
+ background: linear-gradient(160deg, rgba(255,255,255,0.74), rgba(255,255,255,0.44)) !important;
1005
+ }
1006
+ .omada-video-loading {
1007
+ width: 360px;
1008
+ max-width: min(42vw, 360px);
1009
+ min-height: 64px;
1010
+ border-radius: 12px;
1011
+ border: 1px solid var(--omada-glass-border);
1012
+ background: rgba(255,255,255,0.35);
1013
+ display: flex;
1014
+ align-items: center;
1015
+ justify-content: center;
1016
+ font-size: 0.9rem;
1017
+ color: #304463;
1018
+ backdrop-filter: blur(10px) saturate(150%);
1019
+ -webkit-backdrop-filter: blur(10px) saturate(150%);
1020
+ }
1021
+ .omada-user-media {
1022
+ margin-bottom: 6px;
1023
+ }
1024
+ .omada-user-media img,
1025
+ .omada-user-media video {
1026
+ max-width: 240px;
1027
+ width: 240px;
1028
+ max-height: 180px;
1029
+ object-fit: contain;
1030
+ border-radius: 10px;
1031
+ border: 1px solid var(--omada-border);
1032
+ display: block;
1033
+ }
1034
+ .omada-user-media .omada-user-video {
1035
+ width: 360px;
1036
+ max-width: min(42vw, 360px);
1037
+ max-height: 240px;
1038
+ }
1039
+ .omada-user-media audio {
1040
+ width: 360px;
1041
+ max-width: min(42vw, 360px);
1042
+ display: block;
1043
+ }
1044
+ .omada-response-status {
1045
+ color: var(--omada-dark-muted) !important;
1046
+ }
1047
+ .omada-token-pill {
1048
+ display: inline-block;
1049
+ padding: 1px 8px;
1050
+ margin: 1px 2px;
1051
+ border-radius: 999px;
1052
+ border: 1px solid var(--omada-border);
1053
+ font-size: 0.82em;
1054
+ line-height: 1.6;
1055
+ vertical-align: baseline;
1056
+ background: #f7f8fa;
1057
+ }
1058
+ .omada-token-mask {
1059
+ border-color: #8da2c6;
1060
+ background: #eef3ff;
1061
+ color: #1f3d7a;
1062
+ font-weight: 600;
1063
+ }
1064
+ .omada-token-special {
1065
+ border-color: #c5ccd8;
1066
+ background: #f3f4f7;
1067
+ color: #4b5563;
1068
+ }
1069
+ /* Apple-like glass look */
1070
+ :root {
1071
+ --omada-glass-bg: rgba(255, 255, 255, 0.22);
1072
+ --omada-glass-strong: rgba(255, 255, 255, 0.32);
1073
+ --omada-glass-border: rgba(255, 255, 255, 0.58);
1074
+ --omada-glass-shadow: 0 20px 52px rgba(31, 38, 70, 0.14);
1075
+ }
1076
+ html, body, .gradio-container {
1077
+ background:
1078
+ radial-gradient(1200px 500px at 10% -10%, rgba(255,255,255,0.80), rgba(255,255,255,0.30) 45%, rgba(245,247,251,0.92) 100%),
1079
+ linear-gradient(135deg, #edf1f7 0%, #e7ecf3 45%, #eff3f8 100%) !important;
1080
+ }
1081
+ .omada-input-row,
1082
+ .omada-controls-safe,
1083
+ .omada-panel,
1084
+ .gradio-chatbot .message,
1085
+ .omada-chip button,
1086
+ .omada-input-row .gradio-textbox textarea,
1087
+ .omada-plus-btn button,
1088
+ .omada-send-btn button,
1089
+ .omada-auto select {
1090
+ background: var(--omada-glass-bg) !important;
1091
+ border: 1px solid var(--omada-glass-border) !important;
1092
+ box-shadow: var(--omada-glass-shadow) !important;
1093
+ backdrop-filter: blur(22px) saturate(175%);
1094
+ -webkit-backdrop-filter: blur(22px) saturate(175%);
1095
+ }
1096
+ .omada-controls-safe {
1097
+ padding: 14px 16px !important;
1098
+ border-radius: 28px !important;
1099
+ margin: 10px auto 10px auto !important;
1100
+ }
1101
+ .omada-controls-safe > div {
1102
+ padding: 10px 12px !important;
1103
+ border-radius: 22px !important;
1104
+ }
1105
+ .omada-controls-safe .gradio-button,
1106
+ .omada-controls-safe button,
1107
+ .omada-controls-safe .gradio-dropdown,
1108
+ .omada-controls-safe .gradio-textbox,
1109
+ .omada-controls-safe .gradio-slider {
1110
+ border-radius: 16px !important;
1111
+ }
1112
+ .omada-controls-safe .gradio-button {
1113
+ border: 1px solid var(--omada-glass-border) !important;
1114
+ }
1115
+ .gradio-chatbot .message.user {
1116
+ background: var(--omada-glass-strong) !important;
1117
+ color: #1f2937 !important;
1118
+ }
1119
+ .gradio-chatbot .message.bot {
1120
+ background: rgba(255, 255, 255, 0.50) !important;
1121
+ color: #1f2937 !important;
1122
+ }
1123
+ .omada-chip button {
1124
+ color: #273247 !important;
1125
+ }
1126
+ .omada-panel {
1127
+ border-radius: 28px !important;
1128
+ padding: 20px !important;
1129
+ }
1130
+ .omada-input-row {
1131
+ border-radius: 999px !important;
1132
+ }
1133
+ """
1134
 
1135
# Build the Gradio UI. Styling comes from the two concatenated CSS strings; the
# js snippet (FORCE_LIGHT_MODE_JS) is presumably forcing a light theme — TODO confirm.
with gr.Blocks(
    title="AIDAS Lab @ SNU - Omni-modal Diffusion",
    css=CUSTOM_CSS + EXTRA_CSS,
    theme=theme,
    js=FORCE_LIGHT_MODE_JS,
) as demo:
    # Hidden status line; demo.load triggers model warmup once the page opens
    # and writes the resulting status text into it.
    model_status = gr.Markdown("Model status: Loading model...", visible=False)
    demo.load(warmup_model_status, outputs=[model_status])

    # Task labels rendered as selector buttons in the settings panel. These
    # exact strings are compared against elsewhere (mode dispatch), so do not
    # edit them without updating every comparison site.
    MODE_OPTIONS = [
        "Chat",
        "MMU (Image → Text)",
        "MMU (Video → Text)",
        "Image Generation",
        "Image Editing",
        "ASR",
        "TTS",
    ]
1154
 
1155
+ with gr.Column(elem_classes=["omada-shell"]):
1156
+ chatbox = gr.Chatbot(
1157
+ label=None,
1158
+ height=850,
1159
+ sanitize_html=False,
1160
+ bubble_full_width=False,
1161
+ elem_classes=["omada-chatbot"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1162
  )
1163
 
1164
+ sample_prompts = [ex[0] for ex in CHAT_EXAMPLES[:3]] if CHAT_EXAMPLES else []
1165
+ sample_state = gr.State((sample_prompts + ["", "", ""])[:3])
1166
+ sample_payloads = gr.State(
1167
+ ([{"text": p, "image": None, "audio": None, "video": None} for p in sample_prompts] + [{"text": "", "image": None, "audio": None, "video": None}] * 3)[:3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1168
  )
1169
# Row of up to three "sample prompt" chips. Buttons with an empty label are
# created hidden; _update_mode later swaps their labels/visibility per task.
with gr.Row(elem_classes=["omada-sample-row"]):
    sample_buttons = []
    for i in range(3):
        label = sample_prompts[i] if i < len(sample_prompts) else ""
        sample_buttons.append(gr.Button(label, elem_classes=["omada-chip"], visible=bool(label)))

# Pill-shaped composer row: "+" opens the settings panel, the dropdown toggles
# Auto vs Custom generation settings, "↑" submits the message.
with gr.Row(elem_classes=["omada-input-row"]):
    plus_btn = gr.Button("+", elem_classes=["omada-plus-btn"], scale=1, min_width=36)
    chat_input = gr.Textbox(
        show_label=False,
        placeholder="How can I help you today?",
        lines=1,
        scale=12,
        min_width=0,
    )
    auto_dropdown = gr.Dropdown(
        ["Auto", "Custom"],
        value="Auto",
        show_label=False,
        elem_classes=["omada-auto"],
        scale=2,
        min_width=0,
    )
    send_button = gr.Button("↑", elem_classes=["omada-send-btn"], scale=1, min_width=36)
1193
+
1194
# Collapsible "Task Settings" panel. controls_visible mirrors the panel's
# visibility so handlers can return a consistent (panel, flag) pair.
controls_visible = gr.State(False)
# Backdrop element exists for CSS targeting only; it is hidden by the stylesheet.
backdrop = gr.HTML("<div></div>", visible=False, elem_classes=["omada-panel-backdrop"])
controls_panel = gr.Column(visible=False, elem_classes=["omada-controls-safe"])
with controls_panel:
    gr.Markdown("**Task Settings**")
    # mode_selector is a State (not a visible component); task buttons write into it.
    mode_selector = gr.State("Chat")
    selected_task_label = gr.Markdown("Selected task: `Chat`")
    with gr.Row():
        task_buttons = [gr.Button(option, size="sm") for option in MODE_OPTIONS]
    # Media inputs are created hidden; _update_mode reveals the one(s) the
    # selected task needs.
    media_image = gr.Image(type="pil", label="Image", sources=["upload"], visible=False)
    media_audio = gr.Audio(type="filepath", label="Audio", sources=["microphone", "upload"], visible=False)
    media_video = gr.Video(label="Video", sources=["upload", "webcam"], visible=False)

    # Alias kept for readability at call sites; same component as the dropdown.
    auto_mode = auto_dropdown

    # One advanced-settings column per task; all start hidden and are toggled
    # by _update_advanced when the dropdown is set to "Custom".
    adv_chat = gr.Column(visible=False)
    with adv_chat:
        chat_max_tokens = gr.Slider(2, 512, value=512, step=2, label="Chat max tokens", interactive=True)
        chat_steps = gr.Slider(2, 512, value=512, step=2, label="Chat steps", interactive=True)
        chat_block = gr.Slider(2, 512, value=16, step=2, label="Chat block length", interactive=True)
        chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Chat temperature", interactive=True)

    adv_t2s = gr.Column(visible=False)
    with adv_t2s:
        t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length", interactive=True)
        t2s_steps = gr.Slider(2, 512, value=128, step=2, label="T2S refinement steps", interactive=True)
        t2s_block = gr.Slider(2, 512, value=128, step=2, label="T2S block length", interactive=True)
        t2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="T2S temperature", interactive=True)
        t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="T2S CFG scale", interactive=True)
        t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="T2S gender", interactive=True)
        t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="T2S emotion", interactive=True)
        t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="T2S speed", interactive=True)
        t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="T2S pitch", interactive=True)

    adv_s2t = gr.Column(visible=False)
    with adv_s2t:
        s2t_steps = gr.Slider(2, 512, value=128, step=2, label="S2T steps", interactive=True)
        s2t_block = gr.Slider(2, 512, value=16, step=2, label="S2T block length", interactive=True)
        s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="S2T max tokens", interactive=True)
        s2t_remasking = gr.Dropdown(["low_confidence", "random"], value="low_confidence", label="S2T remasking", interactive=True)

    adv_v2t = gr.Column(visible=False)
    with adv_v2t:
        v2t_steps = gr.Slider(2, 512, value=256, step=2, label="V2T steps", interactive=True)
        v2t_block = gr.Slider(2, 512, value=16, step=2, label="V2T block length", interactive=True)
        v2t_max_tokens = gr.Slider(2, 512, value=256, step=2, label="V2T max tokens", interactive=True)

    adv_t2i = gr.Column(visible=False)
    with adv_t2i:
        t2i_timesteps = gr.Slider(4, 128, value=16, step=2, label="T2I timesteps", interactive=True)
        t2i_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="T2I temperature", interactive=True)
        t2i_guidance = gr.Slider(0.0, 8.0, value=2.5, step=0.1, label="T2I CFG scale", interactive=True)

    adv_i2i = gr.Column(visible=False)
    with adv_i2i:
        i2i_timesteps = gr.Slider(4, 128, value=64, step=2, label="I2I timesteps", interactive=True)
        i2i_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="I2I temperature", interactive=True)
        i2i_guidance = gr.Slider(0.0, 8.0, value=2.5, step=0.1, label="I2I CFG scale", interactive=True)

    adv_mmu = gr.Column(visible=False)
    with adv_mmu:
        mmu_max_tokens = gr.Slider(2, 512, value=128, step=2, label="MMU max tokens", interactive=True)
        mmu_steps = gr.Slider(2, 512, value=128, step=2, label="MMU steps", interactive=True)
        mmu_block = gr.Slider(2, 512, value=16, step=2, label="MMU block length", interactive=True)
        mmu_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="MMU temperature", interactive=True)

    # Applies the panel's selections and closes it (see _save_controls).
    save_btn = gr.Button("Save", variant="primary")
1261
+
1262
def _open_controls(auto_mode, mode):
    """Open the settings panel and sync the advanced sections to the current task."""
    panel_updates = _update_advanced(mode, auto_mode)
    return (gr.update(visible=True), True) + tuple(panel_updates)

plus_btn.click(
    _open_controls,
    inputs=[auto_dropdown, mode_selector],
    outputs=[controls_panel, controls_visible, adv_chat, adv_t2s, adv_s2t, adv_v2t, adv_t2i, adv_i2i, adv_mmu],
)
1271
 
1272
def _update_advanced(mode, auto_mode):
    """Return visibility updates for the seven advanced-settings columns.

    Panels are only ever shown in "Custom" mode, and at most the one matching
    the selected task is visible. The tuple order must match the outputs wiring
    (adv_chat, adv_t2s, adv_s2t, adv_v2t, adv_t2i, adv_i2i, adv_mmu).
    NOTE(review): reconstructed from a diff with elided context; the seven-element
    count matches the seven adv_* outputs at every call site — confirm against the
    original file.
    """
    show = auto_mode == "Custom"
    return (
        gr.update(visible=show and mode == "Chat"),
        gr.update(visible=show and mode == "TTS"),
        gr.update(visible=show and mode == "ASR"),
        gr.update(visible=show and mode == "MMU (Video → Text)"),
        gr.update(visible=show and mode == "Image Generation"),
        gr.update(visible=show and mode == "Image Editing"),
        gr.update(visible=show and mode == "MMU (Image → Text)"),
    )
1283
 
1284
def _handle_custom(auto_mode, current_visible, mode):
    """React to the Auto/Custom toggle: keep the panel open, toggle advanced sections.

    ``current_visible`` is wired in but not consulted; the panel is always shown.
    """
    if auto_mode == "Custom":
        section_updates = _update_advanced(mode, auto_mode)
    else:
        # Auto -> show only task selector, hide advanced panels
        section_updates = tuple(gr.update(visible=False) for _ in range(7))
    return (gr.update(visible=True), True, *section_updates)

auto_dropdown.change(
    _handle_custom,
    inputs=[auto_dropdown, controls_visible, mode_selector],
    outputs=[controls_panel, controls_visible, adv_chat, adv_t2s, adv_s2t, adv_v2t, adv_t2i, adv_i2i, adv_mmu],
)
 
 
 
 
 
 
 
 
 
 
 
 
 
1297
 
1298
def _update_mode(mode):
    """Build the component updates that configure the UI for a given task.

    Returns a 9-tuple, in the order expected by _save_controls' outputs:
    (image visibility, audio visibility, video visibility, textbox placeholder,
    sample label list, sample payload list, three sample-button updates).
    """
    # Which media upload widgets the task needs.
    show_image = mode in {"Image Editing", "MMU (Image → Text)"}
    show_audio = mode in {"ASR"}
    show_video = mode in {"MMU (Video → Text)"}
    placeholders = {
        "Chat": "How can I help you today?",
        "TTS": "Type the speech you want to synthesize...",
        "ASR": "Upload audio, then add notes here...",
        "MMU (Video → Text)": "Upload video, then add notes here...",
        "Image Generation": "Describe the image you want to generate...",
        "Image Editing": "Describe how you want to edit the image...",
        "MMU (Image → Text)": "Ask about the uploaded image...",
    }
    # Per-task sample payloads drawn from the module-level example lists.
    # Each payload carries whatever the task consumes (text and/or a media path).
    payloads = []
    if mode == "Chat":
        payloads = [{"text": ex[0], "image": None, "audio": None, "video": None} for ex in CHAT_EXAMPLES[:3]]
    elif mode == "TTS":
        payloads = [{"text": ex[0], "image": None, "audio": None, "video": None} for ex in T2S_EXAMPLES[:3]]
    elif mode == "Image Generation":
        payloads = [{"text": ex[0], "image": None, "audio": None, "video": None} for ex in T2I_EXAMPLES[:3]]
    elif mode == "Image Editing":
        # I2I examples are (image, instruction) pairs.
        payloads = [{"text": ex[1], "image": ex[0], "audio": None, "video": None} for ex in I2I_EXAMPLES[:3]]
    elif mode == "MMU (Video → Text)":
        payloads = [{"text": "", "image": None, "audio": None, "video": ex[0]} for ex in V2T_EXAMPLES[:3]]
    elif mode == "ASR":
        payloads = [{"text": "", "image": None, "audio": ex[0], "video": None} for ex in S2T_EXAMPLES[:3]]
    elif mode == "MMU (Image → Text)":
        # MMU examples are (image, question) pairs.
        payloads = [{"text": ex[1], "image": ex[0], "audio": None, "video": None} for ex in MMU_EXAMPLES[:3]]
    # Pad to exactly three slots so the three sample buttons always have data.
    payloads = (payloads + [{"text": "", "image": None, "audio": None, "video": None}] * 3)[:3]
    # Chip label: the text prompt, or the media file's basename when text is empty.
    samples = [p.get("text", "") or os.path.basename(p.get("video") or p.get("audio") or p.get("image") or "") for p in payloads]
    return (
        gr.update(visible=show_image),
        gr.update(visible=show_audio),
        gr.update(visible=show_video),
        gr.update(placeholder=placeholders.get(mode, "How can I help you today?")),
        [s for s in samples],
        payloads,
        gr.update(value=samples[0], visible=bool(samples[0])),
        gr.update(value=samples[1], visible=bool(samples[1])),
        gr.update(value=samples[2], visible=bool(samples[2])),
    )
1339
# NOTE(review): the return value is discarded; _update_mode only builds
# gr.update objects, which take effect when returned from an event handler,
# so this call looks like a no-op — confirm intent.
_update_mode("Chat")
def _pick_mode(choice, auto_mode):
    """Record the chosen task, refresh the label and advanced-panel visibility."""
    adv_updates = _update_advanced(choice, auto_mode)
    return choice, f"Selected task: `{choice}`", *adv_updates

for idx, task_btn in enumerate(task_buttons):
    task_btn.click(
        # choice is bound at definition time via a default argument so each
        # button keeps its own option (avoids the late-binding closure pitfall).
        lambda auto_mode, choice=MODE_OPTIONS[idx]: _pick_mode(choice, auto_mode),
        inputs=[auto_dropdown],
        outputs=[mode_selector, selected_task_label, adv_chat, adv_t2s, adv_s2t, adv_v2t, adv_t2i, adv_i2i, adv_mmu],
    )
1350
 
1351
def _save_controls(mode, auto_mode):
    """Apply the panel's selections and close it.

    Returns (panel hidden, visibility flag False) followed by the nine
    _update_mode values and the seven _update_advanced values, matching the
    outputs list of save_btn.click below.
    """
    mode_updates = _update_mode(mode)
    adv_updates = _update_advanced(mode, auto_mode)
    return (
        gr.update(visible=False),
        False,
        *mode_updates,
        *adv_updates,
    )

save_btn.click(
    _save_controls,
    inputs=[mode_selector, auto_dropdown],
    outputs=[
        controls_panel,
        controls_visible,
        media_image,
        media_audio,
        media_video,
        chat_input,
        sample_state,
        sample_payloads,
        *sample_buttons,
        adv_chat,
        adv_t2s,
        adv_s2t,
        adv_v2t,
        adv_t2i,
        adv_i2i,
        adv_mmu,
    ],
)
1383
+
1384
+
1385
+ def _format_user_message(msg: str) -> str:
1386
+ return msg.strip() if msg else " "
1387
+
1388
def _chat_handler(
    history,
    message,
    mode,
    auto_mode,
    image_in,
    audio_in,
    video_in,
    chat_max_tokens,
    chat_steps,
    chat_block,
    chat_temperature,
    t2s_max_tokens,
    t2s_steps,
    t2s_block,
    t2s_temperature,
    t2s_cfg,
    t2s_gender,
    t2s_emotion,
    t2s_speed,
    t2s_pitch,
    s2t_steps,
    s2t_block,
    s2t_max_tokens,
    s2t_remasking,
    v2t_steps,
    v2t_block,
    v2t_max_tokens,
    t2i_timesteps,
    t2i_temperature,
    t2i_guidance,
    i2i_timesteps,
    i2i_temperature,
    i2i_guidance,
    mmu_max_tokens,
    mmu_steps,
    mmu_block,
    mmu_temperature,
):
    """Generator handler for submit: dispatch on `mode`, streaming chat history.

    Yields (history, "") pairs; the empty string clears the input textbox.
    Each branch validates its required inputs, then either streams partial
    results from the app or renders a single final message.
    """
    history = history or []
    message = (message or "").strip()
    # For video MMU, render the user bubble without the (potentially slow to
    # embed) video first, then re-render it with the video on the next yield.
    defer_video = mode == "MMU (Video → Text)" and bool(video_in)
    display_user = _render_user_message(mode, message, image_in, audio_in, video_in, defer_video=defer_video)
    history.append((display_user, _render_text_message("Generating...", "")))
    yield history, ""

    if defer_video:
        display_user = _render_user_message(mode, message, image_in, audio_in, video_in, defer_video=False)
        history[-1] = (display_user, history[-1][1])
        yield history, ""

    app = get_app()
    # Respect UI mode: Auto uses eval-matched defaults, Custom uses UI values.
    app.force_eval_settings = str(auto_mode).strip().lower() == "auto"

    if mode == "Chat":
        # Streaming text chat; `done` flag from the stream is unused here.
        for reply_html, status, done in app.run_chat_stream(
            message,
            chat_max_tokens,
            chat_steps,
            chat_block,
            chat_temperature,
            update_every=64,
        ):
            response = _render_response(status, reply_html)
            history[-1] = (display_user, response)
            yield history, ""
        return

    if mode == "TTS":
        if not message:
            history[-1] = (display_user, _render_text_message("Please type some text.", ""))
            yield history, ""
            return
        # Text-to-speech is non-streaming: one synthesis call, one final bubble.
        audio, status = app.run_t2s(
            message,
            t2s_max_tokens,
            t2s_steps,
            t2s_block,
            t2s_temperature,
            t2s_cfg,
            t2s_gender,
            t2s_emotion,
            t2s_speed,
            t2s_pitch,
        )
        history[-1] = (display_user, _render_audio_message(status, audio))
        yield history, ""
        return

    if mode == "ASR":
        if not audio_in:
            history[-1] = (display_user, _render_text_message("Please upload audio.", ""))
            yield history, ""
            return
        # Speech-to-text, streamed as partial transcripts.
        for text, status in app.run_s2t_stream(
            audio_in,
            s2t_steps,
            s2t_block,
            s2t_max_tokens,
            s2t_remasking,
            update_every=32,
        ):
            history[-1] = (display_user, _render_text_message(status, text))
            yield history, ""
        return

    if mode == "MMU (Video → Text)":
        if not video_in:
            history[-1] = (display_user, _render_text_message("Please upload a video.", ""))
            yield history, ""
            return
        for text, status in app.run_v2t_stream(
            video_in,
            v2t_steps,
            v2t_block,
            v2t_max_tokens,
            update_every=32,
        ):
            history[-1] = (display_user, _render_text_message(status, text))
            yield history, ""
        return

    if mode == "Image Generation":
        if not message:
            history[-1] = (display_user, _render_text_message("Please provide a prompt.", ""))
            yield history, ""
            return
        # Text-to-image, streaming intermediate decodes every 2 timesteps.
        for image, status in app.run_t2i_stream(
            message,
            t2i_timesteps,
            t2i_temperature,
            t2i_guidance,
            update_every=2,
        ):
            history[-1] = (display_user, _render_image_message(status, image))
            yield history, ""
        return

    if mode == "Image Editing":
        if not image_in:
            history[-1] = (display_user, _render_text_message("Please upload an image.", ""))
            yield history, ""
            return
        if not message:
            history[-1] = (display_user, _render_text_message("Please provide an edit instruction.", ""))
            yield history, ""
            return
        for image, status in app.run_i2i_stream(
            message,
            image_in,
            i2i_timesteps,
            i2i_temperature,
            i2i_guidance,
            update_every=2,
        ):
            history[-1] = (display_user, _render_image_message(status, image))
            yield history, ""
        return

    if mode == "MMU (Image → Text)":
        if not image_in:
            history[-1] = (display_user, _render_text_message("Please upload an image.", ""))
            yield history, ""
            return
        # Image understanding is non-streaming: single call, single final bubble.
        reply, status = app.run_mmu(
            images=[image_in],
            message=message,
            max_new_tokens=mmu_max_tokens,
            steps=mmu_steps,
            block_length=mmu_block,
            temperature=mmu_temperature,
        )
        history[-1] = (display_user, _render_text_message(status, reply))
        yield history, ""
        return

    # Fallback for any mode string not covered above.
    history[-1] = (display_user, _render_text_message("Unsupported mode.", ""))
    yield history, ""
1567
+
1568
# Shared input list for both submit paths (Enter in the textbox and the send
# button). The order must match _chat_handler's parameter order exactly.
submit_inputs = [
    chatbox,
    chat_input,
    mode_selector,
    auto_dropdown,
    media_image,
    media_audio,
    media_video,
    chat_max_tokens,
    chat_steps,
    chat_block,
    chat_temperature_slider,
    t2s_max_tokens,
    t2s_steps,
    t2s_block,
    t2s_temperature,
    t2s_cfg,
    t2s_gender,
    t2s_emotion,
    t2s_speed,
    t2s_pitch,
    s2t_steps,
    s2t_block,
    s2t_max_tokens,
    s2t_remasking,
    v2t_steps,
    v2t_block,
    v2t_max_tokens,
    t2i_timesteps,
    t2i_temperature,
    t2i_guidance,
    i2i_timesteps,
    i2i_temperature,
    i2i_guidance,
    mmu_max_tokens,
    mmu_steps,
    mmu_block,
    mmu_temperature,
]
submit_outputs = [chatbox, chat_input]

chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
1611
+
1612
+ def _use_sample(payload_list, idx):
1613
+ if not payload_list or idx >= len(payload_list):
1614
+ return "", None, None, None
1615
+ item = payload_list[idx] or {}
1616
+ return item.get("text", ""), item.get("image"), item.get("audio"), item.get("video")
1617
+
1618
# Clicking a sample chip loads its payload into the composer + media widgets.
for i, btn in enumerate(sample_buttons):
    btn.click(
        # idx is bound at definition time via a default argument so each button
        # keeps its own slot (avoids the late-binding closure pitfall).
        lambda payloads, idx=i: _use_sample(payloads, idx),
        inputs=[sample_payloads],
        outputs=[chat_input, media_image, media_audio, media_video],
    )
1624
 
1625
 
1626
 
1627
if __name__ == "__main__":
    # allowed_paths lets Gradio serve files from the preview cache and /tmp
    # (where generated media is written) to the browser.
    demo.launch(allowed_paths=[str(PREVIEW_DIR), "/tmp"])