| """reFlow Interactive Interpretability Demo — Gradio Blocks app.""" |
|
|
| import gradio as gr |
| from model_loader import get_model |
| from experiments import ( |
| exp_semantic_galaxy, |
| exp_semantic_algebra, |
| exp_typo_resilience, |
| exp_sparsity_profile, |
| exp_layer_evolution, |
| exp_causal_ablation, |
| exp_concept_inception, |
| exp_generate, |
| exp_basis_geometry, |
| exp_recipe_neighbors, |
| exp_task_crystallization, |
| ) |
| from i18n import t |
|
|
|
|
def safe_call(fn, *args, **kwargs):
    """Run ``fn`` and turn any raised ``Exception`` into an error string.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn`` unchanged.
        **kwargs: Keyword arguments forwarded to ``fn`` unchanged.

    Returns:
        Whatever ``fn`` returns when it succeeds. If ``fn`` raises an
        ``Exception`` subclass, a string of the form
        ``"Error: <message>"`` followed by a blank line and the formatted
        traceback. Exceptions that do not derive from ``Exception``
        (e.g. ``KeyboardInterrupt``) are not intercepted.
    """
    try:
        outcome = fn(*args, **kwargs)
    except Exception as exc:
        # Imported lazily: only the failure path needs it.
        import traceback

        details = traceback.format_exc()
        return "Error: " + str(exc) + "\n\n" + details
    else:
        return outcome
|
|
|
|
def _plot_only(fn):
    """Adapt an experiment for an event whose only output is a ``gr.Plot``.

    ``safe_call`` reports a failure by returning a plain string, which a
    plot component cannot display. The returned handler re-raises that
    string as ``gr.Error`` so Gradio shows the message to the user.
    Assumes a successful run never returns a bare ``str`` — TODO confirm
    against the experiment implementations.
    """
    def handler(*args):
        result = safe_call(fn, *args)
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    return handler


def _plot_and_text(fn):
    """Adapt an experiment for an event with ``[gr.Plot, gr.Textbox]`` outputs.

    On failure ``safe_call`` yields a single string, which is one value
    too few for a two-output event. The returned handler pads it to
    ``(None, message)``: the plot is cleared and the text box carries the
    error. Assumes a successful run never returns a bare ``str`` — TODO
    confirm against the experiment implementations.
    """
    def handler(*args):
        result = safe_call(fn, *args)
        if isinstance(result, str):
            return None, result
        return result

    return handler


def build_lang_ui(lang):
    """Build the experiment tabs for one interface language.

    Must be called inside an active ``gr.Blocks`` context. Every label and
    heading is looked up through ``t(key, lang)``.

    Args:
        lang: Language code handed to ``t``; ``build_ui`` calls this with
            ``"en"`` and ``"zh"``.

    Returns:
        None. Components and event listeners are registered on the
        enclosing Blocks as a side effect.
    """
    # NOTE(review): the tabs are assumed to be nested inside this Column —
    # confirm against the intended layout.
    with gr.Column():
        gr.Markdown(f"# {t('title', lang)}\n\n{t('intro', lang)}")

        # --- Text generation -------------------------------------------
        with gr.Tab(t("tab_generate", lang)):
            gr.Markdown(t("gen_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    gen_prompt = gr.Textbox(label=t("gen_prompt", lang), value="Once upon a time", lines=3)
                    gen_samples = gr.Slider(1, 8, 1, step=1, label=t("gen_samples", lang))
                    gen_max_tokens = gr.Slider(1, 512, 16, step=1, label=t("gen_max_tokens", lang))
                    gen_temperature = gr.Slider(0.01, 2.0, 0.8, step=0.01, label=t("gen_temperature", lang))
                    gen_top_k = gr.Slider(0, 500, 50, step=1, label=t("gen_top_k", lang))
                    gen_rep_penalty = gr.Slider(1.0, 2.0, 1.1, step=0.05, label=t("gen_rep_penalty", lang))
                    btn_gen = gr.Button(t("btn_generate", lang), variant="primary")
                with gr.Column(scale=2):
                    out_gen = gr.Textbox(label=t("gen_output", lang), lines=20, interactive=False)
            # Text output: safe_call's error string is shown in the box as-is.
            btn_gen.click(
                lambda *a: safe_call(exp_generate, *a),
                inputs=[gen_prompt, gen_samples, gen_max_tokens, gen_temperature, gen_top_k, gen_rep_penalty],
                outputs=out_gen,
            )

        # --- Semantic galaxy -------------------------------------------
        with gr.Tab(t("tab_galaxy", lang)):
            gr.Markdown(t("galaxy_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    ck_countries = gr.Checkbox(value=True, label=t("galaxy_countries", lang))
                    ck_animals = gr.Checkbox(value=True, label=t("galaxy_animals", lang))
                    ck_numbers = gr.Checkbox(value=True, label=t("galaxy_numbers", lang))
                    ck_colors = gr.Checkbox(value=True, label=t("galaxy_colors", lang))
                    ck_emotions = gr.Checkbox(value=True, label=t("galaxy_emotions", lang))
                    custom_words = gr.Textbox(label=t("galaxy_custom", lang), placeholder=t("galaxy_custom_ph", lang))
                    btn_galaxy = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_galaxy = gr.Plot()
            btn_galaxy.click(
                _plot_only(exp_semantic_galaxy),
                inputs=[ck_countries, ck_animals, ck_numbers, ck_colors, ck_emotions, custom_words],
                outputs=out_galaxy,
            )

        # --- Semantic algebra ------------------------------------------
        with gr.Tab(t("tab_algebra", lang)):
            gr.Markdown(t("algebra_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    pos_words = gr.Textbox(label=t("algebra_pos", lang), value="Paris, China")
                    neg_words = gr.Textbox(label=t("algebra_neg", lang), value="France")
                    btn_algebra = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_algebra = gr.Textbox(label=t("algebra_results", lang), lines=20, interactive=False)
            btn_algebra.click(
                lambda p, n: safe_call(exp_semantic_algebra, p, n),
                inputs=[pos_words, neg_words],
                outputs=out_algebra,
            )

        # --- Task crystallization --------------------------------------
        with gr.Tab(t("tab_crystallization", lang)):
            gr.Markdown(t("crystallization_title", lang))
            # Each preset maps to (prompt, target, max alpha, start layer),
            # in the same order as the four inputs below.
            presets = {
                "Capital of France → London": ("The capital of France is", "London", 50.0, 0),
                "Cat sat on → moon": ("The cat sat on the", "moon", 50.0, 0),
                "Sky color → red": ("The color of the clear sky is", "red", 50.0, 0),
                "Open door → car": ("To open the locked door, you need a", "car", 50.0, 0),
            }
            with gr.Row():
                with gr.Column(scale=1):
                    preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label=t("crystallization_preset", lang), value=None)
                    cryst_prompt = gr.Textbox(label=t("crystallization_prompt", lang), value="The capital of France is")
                    cryst_target = gr.Textbox(label=t("crystallization_target", lang), value="London")
                    cryst_max_alpha = gr.Slider(10, 100, 50, step=5, label=t("crystallization_max_alpha", lang))
                    cryst_start_layer = gr.Slider(0, 35, 0, step=1, label=t("crystallization_start_layer", lang))
                    btn_cryst = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_cryst_plot = gr.Plot()
                    out_cryst_text = gr.Textbox(label=t("crystallization_info", lang), lines=5, interactive=False)

            def load_preset(preset_name):
                """Return the four input values for a preset, or no-op updates."""
                if preset_name and preset_name in presets:
                    p, tgt, alpha, layer = presets[preset_name]
                    return p, tgt, alpha, layer
                # Nothing (or an unknown name) selected: leave the inputs untouched.
                return gr.update(), gr.update(), gr.update(), gr.update()

            preset_dropdown.change(
                fn=load_preset,
                inputs=[preset_dropdown],
                outputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer],
            )
            btn_cryst.click(
                _plot_and_text(exp_task_crystallization),
                inputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer],
                outputs=[out_cryst_plot, out_cryst_text],
            )

        # --- Typo resilience -------------------------------------------
        with gr.Tab(t("tab_typo", lang)):
            gr.Markdown(t("typo_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    sent_normal = gr.Textbox(label=t("typo_normal", lang), value="The scientist is very intelligent")
                    sent_typo = gr.Textbox(label=t("typo_misspelled", lang), value="The scientsit is vary intellgent")
                    sent_diff = gr.Textbox(label=t("typo_unrelated", lang), value="The dog runs in the park")
                    btn_typo = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_typo = gr.Plot()
            btn_typo.click(
                _plot_only(exp_typo_resilience),
                inputs=[sent_normal, sent_typo, sent_diff],
                outputs=out_typo,
            )

        # --- Sparsity profile ------------------------------------------
        with gr.Tab(t("tab_sparsity", lang)):
            gr.Markdown(t("sparsity_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    word_inspect = gr.Textbox(label=t("sparsity_word", lang), placeholder="e.g. cat")
                    btn_sparse = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_sparse_plot = gr.Plot()
                    out_sparse_text = gr.Textbox(label=t("sparsity_stats", lang), lines=10, interactive=False)
            btn_sparse.click(
                _plot_and_text(exp_sparsity_profile),
                inputs=[word_inspect],
                outputs=[out_sparse_plot, out_sparse_text],
            )

        # --- Layer evolution -------------------------------------------
        with gr.Tab(t("tab_evolution", lang)):
            gr.Markdown(t("evolution_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_evo = gr.Textbox(label=t("evolution_prompt", lang), value="The capital of France is")
                    btn_evo = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_evo = gr.Plot()
            btn_evo.click(_plot_only(exp_layer_evolution), inputs=[prompt_evo], outputs=out_evo)

        # --- Causal ablation -------------------------------------------
        with gr.Tab(t("tab_ablation", lang)):
            gr.Markdown(t("ablation_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_abl = gr.Textbox(label=t("ablation_prompt", lang), value="The capital of France is")
                    btn_abl = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_abl = gr.Plot()
            btn_abl.click(_plot_only(exp_causal_ablation), inputs=[prompt_abl], outputs=out_abl)

        # --- Concept inception -----------------------------------------
        with gr.Tab(t("tab_inception", lang)):
            gr.Markdown(t("inception_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_inc = gr.Textbox(label=t("inception_prompt", lang), value="The capital of France is")
                    target_inc = gr.Textbox(label=t("inception_target", lang), value="London")
                    alpha_inc = gr.Slider(10, 500, 200, step=10, label=t("inception_alpha", lang))
                    btn_inc = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_inc_plot = gr.Plot()
                    out_inc_text = gr.Textbox(label=t("inception_info", lang), lines=5, interactive=False)
            btn_inc.click(
                _plot_and_text(exp_concept_inception),
                inputs=[prompt_inc, target_inc, alpha_inc],
                outputs=[out_inc_plot, out_inc_text],
            )

        # --- Basis geometry (no inputs) --------------------------------
        with gr.Tab(t("tab_basis", lang)):
            gr.Markdown(t("basis_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    btn_basis = gr.Button(t("btn_run", lang), variant="primary")
                    out_basis_text = gr.Textbox(label=t("basis_stats", lang), lines=5, interactive=False)
                with gr.Column(scale=2):
                    out_basis_plot = gr.Plot()
            btn_basis.click(
                _plot_and_text(exp_basis_geometry),
                inputs=[],
                outputs=[out_basis_plot, out_basis_text],
            )

        # --- Recipe neighbors ------------------------------------------
        with gr.Tab(t("tab_neighbors", lang)):
            gr.Markdown(t("neighbors_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    nn_word = gr.Textbox(label=t("neighbors_words", lang), value="king, cat, red")
                    nn_top_n = gr.Slider(5, 50, 10, step=1, label=t("neighbors_top_n", lang))
                    btn_nn = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_nn = gr.Textbox(label=t("neighbors_results", lang), lines=25, interactive=False)
            btn_nn.click(
                lambda w, n: safe_call(exp_recipe_neighbors, w, n),
                inputs=[nn_word, nn_top_n],
                outputs=out_nn,
            )
|
|
|
|
def build_ui():
    """Assemble the top-level Blocks app with one tab per language.

    Returns:
        The ``gr.Blocks`` instance, titled "reFlow Interpretability",
        holding an "English" tab and a "中文" tab, each populated by
        ``build_lang_ui`` with the matching language code.
    """
    # (tab label, language code) in display order.
    language_tabs = (("English", "en"), ("中文", "zh"))

    with gr.Blocks(title="reFlow Interpretability") as demo:
        with gr.Tabs():
            for tab_label, lang_code in language_tabs:
                with gr.Tab(tab_label):
                    build_lang_ui(lang_code)
    return demo
|
|
|
|
if __name__ == "__main__":
    # Script entry point: call get_model() before the UI is built
    # (presumably so the model is loaded up front — confirm in model_loader),
    # then build the interface and serve it.
    print("Loading model at startup ...")
    get_model()

    demo = build_ui()
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)
|
|