YucYux
committed on
Commit
·
60e176e
1
Parent(s):
d08f144
Added support for MMaDA-8B-MixCoT
Browse files
app.py
CHANGED
|
@@ -47,22 +47,23 @@ def get_num_transfer_tokens(mask_index, steps):
|
|
| 47 |
return num_transfer_tokens
|
| 48 |
|
| 49 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 50 |
-
DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-
|
| 51 |
MASK_ID = 126336
|
| 52 |
MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
|
| 53 |
TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
|
| 54 |
uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
|
| 55 |
VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
|
| 56 |
|
| 57 |
-
CURRENT_MODEL_PATH =
|
| 58 |
|
| 59 |
MODEL_CHOICES = [
|
| 60 |
"MMaDA-8B-Base",
|
| 61 |
-
"MMaDA-8B-MixCoT
|
| 62 |
"MMaDA-8B-Max (coming soon)"
|
| 63 |
]
|
| 64 |
MODEL_ACTUAL_PATHS = {
|
| 65 |
-
"MMaDA-8B-Base":
|
|
|
|
| 66 |
}
|
| 67 |
|
| 68 |
def clear_outputs_action():
|
|
@@ -116,19 +117,91 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_st
|
|
| 116 |
# return f"Error loading model '{model_display_name_for_status}': {str(e)}"
|
| 117 |
|
| 118 |
def handle_model_selection_change(selected_model_name_ui):
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
MODEL = None
|
| 122 |
TOKENIZER = None
|
| 123 |
MASK_ID = None
|
| 124 |
CURRENT_MODEL_PATH = None
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
|
| 134 |
def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
|
|
@@ -618,7 +691,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 618 |
model_select_radio = gr.Radio(
|
| 619 |
label="Select Text Generation Model",
|
| 620 |
choices=MODEL_CHOICES,
|
| 621 |
-
value=
|
| 622 |
)
|
| 623 |
model_load_status_box = gr.Textbox(
|
| 624 |
label="Model Load Status",
|
|
@@ -662,17 +735,39 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 662 |
output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 663 |
|
| 664 |
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
|
| 677 |
gr.Markdown("---")
|
| 678 |
gr.Markdown("## Part 2. Multimodal Understanding")
|
|
@@ -681,7 +776,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 681 |
prompt_input_box_mmu = gr.Textbox(
|
| 682 |
label="Enter your prompt:",
|
| 683 |
lines=3,
|
| 684 |
-
value="
|
| 685 |
)
|
| 686 |
think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
|
| 687 |
with gr.Accordion("Generation Parameters", open=True):
|
|
@@ -689,7 +784,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 689 |
gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
|
| 690 |
steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
|
| 691 |
with gr.Row():
|
| 692 |
-
block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=
|
| 693 |
remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
|
| 694 |
with gr.Row():
|
| 695 |
cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
|
|
@@ -715,44 +810,120 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 715 |
gr.Markdown("## Final Generated Text")
|
| 716 |
output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 717 |
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
],
|
| 731 |
-
[
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
]
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 756 |
|
| 757 |
gr.Markdown("---")
|
| 758 |
gr.Markdown("## Part 3. Text-to-Image Generation")
|
|
@@ -823,21 +994,69 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
|
|
| 823 |
inputs=[thinking_mode_mmu],
|
| 824 |
outputs=[thinking_mode_mmu, think_button_mmu]
|
| 825 |
)
|
| 826 |
-
|
| 827 |
|
|
|
|
|
|
|
| 828 |
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 833 |
|
| 834 |
demo.load(
|
| 835 |
-
fn=
|
| 836 |
inputs=None,
|
| 837 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
queue=True
|
| 839 |
)
|
| 840 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
def clear_outputs():
|
| 842 |
return None, None, None # Clear image, visualization, and final text
|
| 843 |
|
|
|
|
| 47 |
return num_transfer_tokens
|
| 48 |
|
| 49 |
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint that is loaded eagerly when this module is imported.
DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT"  # Default

# NOTE(review): presumably the tokenizer's mask-token id -- confirm against the checkpoint.
MASK_ID = 126336

# Language model in bfloat16, moved to DEVICE and switched to eval mode.
MODEL = (
    MMadaModelLM.from_pretrained(
        DEFAULT_MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    .to(DEVICE)
    .eval()
)
TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)

# Prompt builder that shares the tokenizer loaded above.
uni_prompting = UniversalPrompting(
    TOKENIZER,
    max_text_len=512,
    special_tokens=(
        "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>",
        "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
    ),
    ignore_id=-100,
    cond_dropout_prob=0.1,
    use_reserved_token=True,
)

# NOTE(review): this builds a MAGVITv2 instance and then calls from_pretrained on it;
# if from_pretrained is a classmethod the extra instantiation is unnecessary -- confirm.
VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)

# Path of the checkpoint currently held in MODEL / TOKENIZER.
CURRENT_MODEL_PATH = DEFAULT_MODEL_PATH

# Display names offered by the model selector, in UI order.
MODEL_CHOICES = [
    "MMaDA-8B-Base",
    "MMaDA-8B-MixCoT",
    "MMaDA-8B-Max (coming soon)",
]

# Display name -> checkpoint path. Choices without an entry cannot be loaded.
MODEL_ACTUAL_PATHS = {
    "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
    "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT",
}
|
| 68 |
|
| 69 |
def clear_outputs_action():
|
|
|
|
| 117 |
# return f"Error loading model '{model_display_name_for_status}': {str(e)}"
|
| 118 |
|
| 119 |
def handle_model_selection_change(selected_model_name_ui):
    """Switch the active model to the UI selection and compute the UI updates.

    Args:
        selected_model_name_ui: Display name chosen in the model selector.

    Returns:
        An 11-tuple, in this order:
        the status message; visibility updates for the LM example groups
        (base, MixCoT, max); visibility updates for the MMU example groups
        (base, MixCoT, max); the LM thinking-mode flag; the LM thinking-button
        update; the MMU thinking-mode flag; the MMU thinking-button update.

    Thinking mode ends up enabled only when the MixCoT model was selected and
    was loaded successfully. At most one example group is made visible: the
    one matching a successfully loaded model, or the "max" group when the
    unavailable placeholder model is chosen.
    """
    global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH

    mixcot_selected = selected_model_name_ui == "MMaDA-8B-MixCoT"
    # Thinking mode defaults to on for MixCoT; it is switched off below
    # whenever that model cannot actually be used.
    thinking_on = mixcot_selected
    # Key of the example group to reveal; None keeps every group hidden.
    visible_group = None

    if selected_model_name_ui == "MMaDA-8B-Max (coming soon)":
        MODEL = TOKENIZER = MASK_ID = CURRENT_MODEL_PATH = None
        status_msg = f"'{selected_model_name_ui}' is not yet available. Please select another model."
        visible_group = "max"
    else:
        actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
        if not actual_path:
            # No checkpoint path is registered for this choice.
            MODEL = TOKENIZER = MASK_ID = CURRENT_MODEL_PATH = None
            status_msg = f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
            thinking_on = False
        else:
            status_msg = _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
            load_failed = "Error loading model" in status_msg or MODEL is None
            if load_failed:
                thinking_on = False
                if MODEL is None and "Error" not in status_msg:
                    # Add a generic failure notice when the loader reported none.
                    status_msg = f"Failed to properly load model '{selected_model_name_ui}'. {status_msg}"
            elif selected_model_name_ui == "MMaDA-8B-Base":
                visible_group = "base"
            elif mixcot_selected:
                visible_group = "mixcot"

    def _visibility(group):
        # Fresh update object per output component.
        return gr.update(visible=(visible_group == group))

    think_label = "Thinking Mode ✅" if thinking_on else "Thinking Mode ❌"

    return (
        status_msg,
        # LM example groups: base, MixCoT, max.
        _visibility("base"),
        _visibility("mixcot"),
        _visibility("max"),
        # MMU example groups: base, MixCoT, max.
        _visibility("base"),
        _visibility("mixcot"),
        _visibility("max"),
        # Plain booleans feed the gr.State outputs; updates feed the buttons.
        thinking_on,
        gr.update(value=think_label),
        thinking_on,
        gr.update(value=think_label),
    )
|
| 205 |
|
| 206 |
|
| 207 |
def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
|
|
|
|
| 691 |
model_select_radio = gr.Radio(
|
| 692 |
label="Select Text Generation Model",
|
| 693 |
choices=MODEL_CHOICES,
|
| 694 |
+
value="MMaDA-8B-MixCoT"
|
| 695 |
)
|
| 696 |
model_load_status_box = gr.Textbox(
|
| 697 |
label="Model Load Status",
|
|
|
|
| 735 |
output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 736 |
|
| 737 |
|
| 738 |
+
def _build_lm_examples_column(visible):
    """Create one column of text-generation examples and return the column."""
    with gr.Column(visible=visible) as column:
        gr.Examples(
            examples=[
                [
                    "A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?",
                    256, 512, 128, 1, 0, "low_confidence",
                ],
                [
                    "Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?",
                    256, 512, 64, 1, 0, "low_confidence",
                ],
            ],
            # Each example row supplies values in exactly this component order.
            inputs=[
                prompt_input_box_lm,
                steps_slider_lm,
                gen_length_slider_lm,
                block_length_slider_lm,
                temperature_slider_lm,
                cfg_scale_slider_lm,
                remasking_dropdown_lm,
            ],
            outputs=[output_visualization_box_lm, output_final_text_box_lm],
            fn=generate_viz_wrapper_lm,
            cache_examples=False,
        )
    return column


# One example group per selectable model; all three currently share the same
# rows. Only the MixCoT group starts out visible.
examples_lm_base = _build_lm_examples_column(False)
examples_lm_mixcot = _build_lm_examples_column(True)
examples_lm_max = _build_lm_examples_column(False)
|
| 771 |
|
| 772 |
gr.Markdown("---")
|
| 773 |
gr.Markdown("## Part 2. Multimodal Understanding")
|
|
|
|
| 776 |
prompt_input_box_mmu = gr.Textbox(
|
| 777 |
label="Enter your prompt:",
|
| 778 |
lines=3,
|
| 779 |
+
value=""
|
| 780 |
)
|
| 781 |
think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
|
| 782 |
with gr.Accordion("Generation Parameters", open=True):
|
|
|
|
| 784 |
gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
|
| 785 |
steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
|
| 786 |
with gr.Row():
|
| 787 |
+
block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length", info="gen_length must be divisible by this.")
|
| 788 |
remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
|
| 789 |
with gr.Row():
|
| 790 |
cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
|
|
|
|
| 810 |
gr.Markdown("## Final Generated Text")
|
| 811 |
output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 812 |
|
| 813 |
+
def _build_mmu_examples_column(visible, image_prompt_pairs, block_length):
    """Create one column of multimodal-understanding examples and return it.

    Every row uses 256 steps, generation length 512, temperature 1, CFG scale 0
    and "low_confidence" remasking; only the image, prompt and block length vary.
    """
    rows = [
        [image_path, prompt, 256, 512, block_length, 1, 0, "low_confidence"]
        for image_path, prompt in image_prompt_pairs
    ]
    with gr.Column(visible=visible) as column:
        gr.Examples(
            examples=rows,
            # Each example row supplies values in exactly this component order.
            inputs=[
                image_upload_box,
                prompt_input_box_mmu,
                steps_slider_mmu,
                gen_length_slider_mmu,
                block_length_slider_mmu,
                temperature_slider_mmu,
                cfg_scale_slider_mmu,
                remasking_dropdown_mmu,
            ],
            outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
            fn=generate_viz_wrapper,
            cache_examples=False,
        )
    return column


# Base group: image-description prompts, hidden initially.
examples_mmu_base = _build_mmu_examples_column(
    False,
    [
        ("figs/sunflower.jpg", "Please describe this image in detail."),
        ("figs/woman.jpg", "Please describe this image in detail."),
    ],
    128,
)
# MixCoT group: question-style prompts with block length 64, visible initially.
examples_mmu_mixcot = _build_mmu_examples_column(
    True,
    [
        (
            "figs/geo.png",
            "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?",
        ),
        ("figs/bus.jpg", "What are the colors of the bus?"),
    ],
    64,
)
# Max group: same rows as the base group, hidden initially.
examples_mmu_max = _build_mmu_examples_column(
    False,
    [
        ("figs/sunflower.jpg", "Please describe this image in detail."),
        ("figs/woman.jpg", "Please describe this image in detail."),
    ],
    128,
)
|
| 927 |
|
| 928 |
gr.Markdown("---")
|
| 929 |
gr.Markdown("## Part 3. Text-to-Image Generation")
|
|
|
|
| 994 |
inputs=[thinking_mode_mmu],
|
| 995 |
outputs=[thinking_mode_mmu, think_button_mmu]
|
| 996 |
)
|
|
|
|
| 997 |
|
| 998 |
+
def initialize_app_state():
    """Build the initial UI state for the default model (MixCoT).

    Runs the model-selection handler for the default choice and returns that
    choice followed by everything the handler returned: the status message,
    the six example-group visibility updates, and the thinking-mode state and
    button update for the LM and MMU sections.
    """
    default_model_choice = "MMaDA-8B-MixCoT"  # MixCoT is loaded by default
    selection_updates = handle_model_selection_change(default_model_choice)
    # The selector value comes first, then the handler's outputs in their original order.
    return (default_model_choice, *selection_updates)
|
| 1021 |
|
| 1022 |
# When the page loads, run initialize_app_state and route its twelve return
# values to the components below, in order.
demo.load(
    fn=initialize_app_state,
    inputs=None,
    outputs=[
        # Selector value and load status.
        model_select_radio, model_load_status_box,
        # Example groups of the text-generation section.
        examples_lm_base, examples_lm_mixcot, examples_lm_max,
        # Example groups of the multimodal-understanding section.
        examples_mmu_base, examples_mmu_mixcot, examples_mmu_max,
        # Thinking-mode state (gr.State) and its toggle button, LM then MMU.
        thinking_mode_lm, think_button_lm,
        thinking_mode_mmu, think_button_mmu,
    ],
    queue=True,
)
|
| 1041 |
|
| 1042 |
+
# Whenever the selected model changes, run the selection handler and route its
# eleven return values to the components below, in order.
model_select_radio.change(
    fn=handle_model_selection_change,
    inputs=[model_select_radio],
    outputs=[
        # Load status.
        model_load_status_box,
        # Example groups of the text-generation section.
        examples_lm_base, examples_lm_mixcot, examples_lm_max,
        # Example groups of the multimodal-understanding section.
        examples_mmu_base, examples_mmu_mixcot, examples_mmu_max,
        # Thinking-mode state and its toggle button, LM then MMU.
        thinking_mode_lm, think_button_lm,
        thinking_mode_mmu, think_button_mmu,
    ],
)
|
| 1059 |
+
|
| 1060 |
def clear_outputs():
|
| 1061 |
return None, None, None # Clear image, visualization, and final text
|
| 1062 |
|