"""reFlow Interactive Interpretability Demo — Gradio Blocks app.""" import gradio as gr from model_loader import get_model from experiments import ( exp_semantic_galaxy, exp_semantic_algebra, exp_typo_resilience, exp_sparsity_profile, exp_layer_evolution, exp_causal_ablation, exp_concept_inception, exp_generate, exp_basis_geometry, exp_recipe_neighbors, exp_task_crystallization, ) from i18n import t def safe_call(fn, *args, **kwargs): """Wrap experiment functions with error handling.""" try: return fn(*args, **kwargs) except Exception as e: import traceback return f"Error: {e}\n\n{traceback.format_exc()}" def build_lang_ui(lang): """Build UI for a specific language.""" with gr.Column(): gr.Markdown(f"# {t('title', lang)}\n\n{t('intro', lang)}") # Tab 0: Text Generation with gr.Tab(t("tab_generate", lang)): gr.Markdown(t("gen_title", lang)) with gr.Row(): with gr.Column(scale=1): gen_prompt = gr.Textbox(label=t("gen_prompt", lang), value="Once upon a time", lines=3) gen_samples = gr.Slider(1, 8, 1, step=1, label=t("gen_samples", lang)) gen_max_tokens = gr.Slider(1, 512, 16, step=1, label=t("gen_max_tokens", lang)) gen_temperature = gr.Slider(0.01, 2.0, 0.8, step=0.01, label=t("gen_temperature", lang)) gen_top_k = gr.Slider(0, 500, 50, step=1, label=t("gen_top_k", lang)) gen_rep_penalty = gr.Slider(1.0, 2.0, 1.1, step=0.05, label=t("gen_rep_penalty", lang)) btn_gen = gr.Button(t("btn_generate", lang), variant="primary") with gr.Column(scale=2): out_gen = gr.Textbox(label=t("gen_output", lang), lines=20, interactive=False) btn_gen.click(lambda *a: safe_call(exp_generate, *a), inputs=[gen_prompt, gen_samples, gen_max_tokens, gen_temperature, gen_top_k, gen_rep_penalty], outputs=out_gen) # Tab 1: Semantic Galaxy with gr.Tab(t("tab_galaxy", lang)): gr.Markdown(t("galaxy_title", lang)) with gr.Row(): with gr.Column(scale=1): ck_countries = gr.Checkbox(value=True, label=t("galaxy_countries", lang)) ck_animals = gr.Checkbox(value=True, label=t("galaxy_animals", lang)) ck_numbers = gr.Checkbox(value=True, label=t("galaxy_numbers", lang)) ck_colors = gr.Checkbox(value=True, label=t("galaxy_colors", lang)) ck_emotions = gr.Checkbox(value=True, label=t("galaxy_emotions", lang)) custom_words = gr.Textbox(label=t("galaxy_custom", lang), placeholder=t("galaxy_custom_ph", lang)) btn_galaxy = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_galaxy = gr.Plot() btn_galaxy.click(lambda *a: safe_call(exp_semantic_galaxy, *a), inputs=[ck_countries, ck_animals, ck_numbers, ck_colors, ck_emotions, custom_words], outputs=out_galaxy) # Tab 2: Semantic Algebra with gr.Tab(t("tab_algebra", lang)): gr.Markdown(t("algebra_title", lang)) with gr.Row(): with gr.Column(scale=1): pos_words = gr.Textbox(label=t("algebra_pos", lang), value="Paris, China") neg_words = gr.Textbox(label=t("algebra_neg", lang), value="France") btn_algebra = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_algebra = gr.Textbox(label=t("algebra_results", lang), lines=20, interactive=False) btn_algebra.click(lambda p, n: safe_call(exp_semantic_algebra, p, n), inputs=[pos_words, neg_words], outputs=out_algebra) # Tab 3: Task Crystallization with gr.Tab(t("tab_crystallization", lang)): gr.Markdown(t("crystallization_title", lang)) presets = { "Capital of France → London": ("The capital of France is", "London", 50.0, 0), "Cat sat on → moon": ("The cat sat on the", "moon", 50.0, 0), "Sky color → red": ("The color of the clear sky is", "red", 50.0, 0), "Open door → car": ("To open the locked door, you need a", "car", 50.0, 0), } with gr.Row(): with gr.Column(scale=1): preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label=t("crystallization_preset", lang), value=None) cryst_prompt = gr.Textbox(label=t("crystallization_prompt", lang), value="The capital of France is") cryst_target = gr.Textbox(label=t("crystallization_target", lang), value="London") cryst_max_alpha = gr.Slider(10, 100, 50, step=5, label=t("crystallization_max_alpha", lang)) cryst_start_layer = gr.Slider(0, 35, 0, step=1, label=t("crystallization_start_layer", lang)) btn_cryst = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_cryst_plot = gr.Plot() out_cryst_text = gr.Textbox(label=t("crystallization_info", lang), lines=5, interactive=False) def load_preset(preset_name): if preset_name and preset_name in presets: p, tgt, alpha, layer = presets[preset_name] return p, tgt, alpha, layer return gr.update(), gr.update(), gr.update(), gr.update() preset_dropdown.change(fn=load_preset, inputs=[preset_dropdown], outputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer]) btn_cryst.click(lambda p, tg, a, l: safe_call(exp_task_crystallization, p, tg, a, l), inputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer], outputs=[out_cryst_plot, out_cryst_text]) # Tab 4: Typo Resilience with gr.Tab(t("tab_typo", lang)): gr.Markdown(t("typo_title", lang)) with gr.Row(): with gr.Column(scale=1): sent_normal = gr.Textbox(label=t("typo_normal", lang), value="The scientist is very intelligent") sent_typo = gr.Textbox(label=t("typo_misspelled", lang), value="The scientsit is vary intellgent") sent_diff = gr.Textbox(label=t("typo_unrelated", lang), value="The dog runs in the park") btn_typo = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_typo = gr.Plot() btn_typo.click(lambda a, b, c: safe_call(exp_typo_resilience, a, b, c), inputs=[sent_normal, sent_typo, sent_diff], outputs=out_typo) # Tab 5: Signal Sparsity with gr.Tab(t("tab_sparsity", lang)): gr.Markdown(t("sparsity_title", lang)) with gr.Row(): with gr.Column(scale=1): word_inspect = gr.Textbox(label=t("sparsity_word", lang), placeholder="e.g. cat") btn_sparse = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_sparse_plot = gr.Plot() out_sparse_text = gr.Textbox(label=t("sparsity_stats", lang), lines=10, interactive=False) btn_sparse.click(lambda w: safe_call(exp_sparsity_profile, w), inputs=[word_inspect], outputs=[out_sparse_plot, out_sparse_text]) # Tab 6: Layer Evolution with gr.Tab(t("tab_evolution", lang)): gr.Markdown(t("evolution_title", lang)) with gr.Row(): with gr.Column(scale=1): prompt_evo = gr.Textbox(label=t("evolution_prompt", lang), value="The capital of France is") btn_evo = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_evo = gr.Plot() btn_evo.click(lambda p: safe_call(exp_layer_evolution, p), inputs=[prompt_evo], outputs=out_evo) # Tab 7: Causal Ablation with gr.Tab(t("tab_ablation", lang)): gr.Markdown(t("ablation_title", lang)) with gr.Row(): with gr.Column(scale=1): prompt_abl = gr.Textbox(label=t("ablation_prompt", lang), value="The capital of France is") btn_abl = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_abl = gr.Plot() btn_abl.click(lambda p: safe_call(exp_causal_ablation, p), inputs=[prompt_abl], outputs=out_abl) # Tab 8: Concept Inception with gr.Tab(t("tab_inception", lang)): gr.Markdown(t("inception_title", lang)) with gr.Row(): with gr.Column(scale=1): prompt_inc = gr.Textbox(label=t("inception_prompt", lang), value="The capital of France is") target_inc = gr.Textbox(label=t("inception_target", lang), value="London") alpha_inc = gr.Slider(10, 500, 200, step=10, label=t("inception_alpha", lang)) btn_inc = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_inc_plot = gr.Plot() out_inc_text = gr.Textbox(label=t("inception_info", lang), lines=5, interactive=False) btn_inc.click(lambda p, tg, a: safe_call(exp_concept_inception, p, tg, a), inputs=[prompt_inc, target_inc, alpha_inc], outputs=[out_inc_plot, out_inc_text]) # Tab 9: Basis Geometry with gr.Tab(t("tab_basis", lang)): gr.Markdown(t("basis_title", lang)) with gr.Row(): with gr.Column(scale=1): btn_basis = gr.Button(t("btn_run", lang), variant="primary") out_basis_text = gr.Textbox(label=t("basis_stats", lang), lines=5, interactive=False) with gr.Column(scale=2): out_basis_plot = gr.Plot() btn_basis.click(lambda: safe_call(exp_basis_geometry), inputs=[], outputs=[out_basis_plot, out_basis_text]) # Tab 10: Recipe Neighbors with gr.Tab(t("tab_neighbors", lang)): gr.Markdown(t("neighbors_title", lang)) with gr.Row(): with gr.Column(scale=1): nn_word = gr.Textbox(label=t("neighbors_words", lang), value="king, cat, red") nn_top_n = gr.Slider(5, 50, 10, step=1, label=t("neighbors_top_n", lang)) btn_nn = gr.Button(t("btn_run", lang), variant="primary") with gr.Column(scale=2): out_nn = gr.Textbox(label=t("neighbors_results", lang), lines=25, interactive=False) btn_nn.click(lambda w, n: safe_call(exp_recipe_neighbors, w, n), inputs=[nn_word, nn_top_n], outputs=out_nn) def build_ui(): """Build UI with language tabs.""" with gr.Blocks(title="reFlow Interpretability") as demo: with gr.Tabs() as lang_tabs: with gr.Tab("English"): build_lang_ui("en") with gr.Tab("中文"): build_lang_ui("zh") return demo if __name__ == "__main__": print("Loading model at startup ...") get_model() demo = build_ui() demo.launch(theme=gr.themes.Soft())