reFlow / app.py
”reuAC“
Change default max_tokens to 16
08537ff
"""reFlow Interactive Interpretability Demo — Gradio Blocks app."""
import gradio as gr
from model_loader import get_model
from experiments import (
exp_semantic_galaxy,
exp_semantic_algebra,
exp_typo_resilience,
exp_sparsity_profile,
exp_layer_evolution,
exp_causal_ablation,
exp_concept_inception,
exp_generate,
exp_basis_geometry,
exp_recipe_neighbors,
exp_task_crystallization,
)
from i18n import t
def safe_call(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and turn any exception into an error string.

    This is the error boundary for every UI-triggered experiment: a failure is
    reported back to the caller instead of propagating out of the event
    handler.

    Args:
        fn: the experiment function to run.
        *args, **kwargs: forwarded unchanged to ``fn``.

    Returns:
        Whatever ``fn`` returns on success. On failure, a string consisting of
        ``"Error: "`` plus the exception message, a blank line, and the
        formatted traceback.
    """
    try:
        return fn(*args, **kwargs)
    except Exception as e:
        import traceback

        tb = traceback.format_exc()
        # The returned string only reaches the browser, and some outputs (e.g.
        # a gr.Plot) cannot display it at all; keep a copy in the server log
        # (stderr) so failures are never lost.
        traceback.print_exc()
        return f"Error: {e}\n\n{tb}"
def _plot_result(result):
    """Adapt a ``safe_call`` result for an event whose only output is a ``gr.Plot``.

    ``safe_call`` reports a failure by returning a plain string, which a
    ``gr.Plot`` cannot render as a figure, so the message would never reach the
    user. Re-raise it as ``gr.Error`` so Gradio displays it instead.
    """
    if isinstance(result, str):
        raise gr.Error(result)
    return result


def _plot_text_result(result):
    """Adapt a ``safe_call`` result for a ``(gr.Plot, gr.Textbox)`` output pair.

    On failure ``safe_call`` returns a single string, but the event expects two
    output values; clear the plot and show the message in the textbox instead.
    Successful results are passed through unchanged.
    """
    if isinstance(result, str):
        return None, result
    return result


def build_lang_ui(lang):
    """Build the complete set of experiment tabs for one UI language.

    Must be called inside an active Gradio Blocks context (``build_ui`` calls
    it once per language tab). Every visible label is looked up via
    ``t(key, lang)``. Each button runs its experiment through ``safe_call``;
    failures are routed to an output that can actually display them: the text
    box, or a ``gr.Error`` popup for plot-only outputs.

    Args:
        lang: language code passed to ``t`` (``build_ui`` uses "en" and "zh").
    """
    with gr.Column():
        gr.Markdown(f"# {t('title', lang)}\n\n{t('intro', lang)}")

        # Tab 0: Text Generation
        with gr.Tab(t("tab_generate", lang)):
            gr.Markdown(t("gen_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    gen_prompt = gr.Textbox(label=t("gen_prompt", lang), value="Once upon a time", lines=3)
                    gen_samples = gr.Slider(1, 8, 1, step=1, label=t("gen_samples", lang))
                    gen_max_tokens = gr.Slider(1, 512, 16, step=1, label=t("gen_max_tokens", lang))
                    gen_temperature = gr.Slider(0.01, 2.0, 0.8, step=0.01, label=t("gen_temperature", lang))
                    gen_top_k = gr.Slider(0, 500, 50, step=1, label=t("gen_top_k", lang))
                    gen_rep_penalty = gr.Slider(1.0, 2.0, 1.1, step=0.05, label=t("gen_rep_penalty", lang))
                    btn_gen = gr.Button(t("btn_generate", lang), variant="primary")
                with gr.Column(scale=2):
                    out_gen = gr.Textbox(label=t("gen_output", lang), lines=20, interactive=False)
            # Single Textbox output: safe_call's error string can be shown directly.
            btn_gen.click(lambda *a: safe_call(exp_generate, *a),
                          inputs=[gen_prompt, gen_samples, gen_max_tokens, gen_temperature, gen_top_k, gen_rep_penalty],
                          outputs=out_gen)

        # Tab 1: Semantic Galaxy
        with gr.Tab(t("tab_galaxy", lang)):
            gr.Markdown(t("galaxy_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    ck_countries = gr.Checkbox(value=True, label=t("galaxy_countries", lang))
                    ck_animals = gr.Checkbox(value=True, label=t("galaxy_animals", lang))
                    ck_numbers = gr.Checkbox(value=True, label=t("galaxy_numbers", lang))
                    ck_colors = gr.Checkbox(value=True, label=t("galaxy_colors", lang))
                    ck_emotions = gr.Checkbox(value=True, label=t("galaxy_emotions", lang))
                    custom_words = gr.Textbox(label=t("galaxy_custom", lang), placeholder=t("galaxy_custom_ph", lang))
                    btn_galaxy = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_galaxy = gr.Plot()
            btn_galaxy.click(lambda *a: _plot_result(safe_call(exp_semantic_galaxy, *a)),
                             inputs=[ck_countries, ck_animals, ck_numbers, ck_colors, ck_emotions, custom_words],
                             outputs=out_galaxy)

        # Tab 2: Semantic Algebra
        with gr.Tab(t("tab_algebra", lang)):
            gr.Markdown(t("algebra_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    pos_words = gr.Textbox(label=t("algebra_pos", lang), value="Paris, China")
                    neg_words = gr.Textbox(label=t("algebra_neg", lang), value="France")
                    btn_algebra = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_algebra = gr.Textbox(label=t("algebra_results", lang), lines=20, interactive=False)
            btn_algebra.click(lambda p, n: safe_call(exp_semantic_algebra, p, n),
                              inputs=[pos_words, neg_words], outputs=out_algebra)

        # Tab 3: Task Crystallization
        with gr.Tab(t("tab_crystallization", lang)):
            gr.Markdown(t("crystallization_title", lang))
            # Each preset is (prompt, target, max_alpha, start_layer), in the
            # same order as the outputs of the preset_dropdown.change event below.
            presets = {
                "Capital of France → London": ("The capital of France is", "London", 50.0, 0),
                "Cat sat on → moon": ("The cat sat on the", "moon", 50.0, 0),
                "Sky color → red": ("The color of the clear sky is", "red", 50.0, 0),
                "Open door → car": ("To open the locked door, you need a", "car", 50.0, 0),
            }
            with gr.Row():
                with gr.Column(scale=1):
                    preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label=t("crystallization_preset", lang), value=None)
                    cryst_prompt = gr.Textbox(label=t("crystallization_prompt", lang), value="The capital of France is")
                    cryst_target = gr.Textbox(label=t("crystallization_target", lang), value="London")
                    cryst_max_alpha = gr.Slider(10, 100, 50, step=5, label=t("crystallization_max_alpha", lang))
                    cryst_start_layer = gr.Slider(0, 35, 0, step=1, label=t("crystallization_start_layer", lang))
                    btn_cryst = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_cryst_plot = gr.Plot()
                    out_cryst_text = gr.Textbox(label=t("crystallization_info", lang), lines=5, interactive=False)

            def load_preset(preset_name):
                """Fill the four crystallization inputs from the chosen preset.

                If nothing valid is selected (e.g. the dropdown was cleared),
                return no-op updates so the current inputs are left untouched.
                """
                if preset_name and preset_name in presets:
                    p, tgt, alpha, layer = presets[preset_name]
                    return p, tgt, alpha, layer
                return gr.update(), gr.update(), gr.update(), gr.update()

            preset_dropdown.change(fn=load_preset, inputs=[preset_dropdown], outputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer])
            btn_cryst.click(lambda p, tg, a, layer: _plot_text_result(safe_call(exp_task_crystallization, p, tg, a, layer)),
                            inputs=[cryst_prompt, cryst_target, cryst_max_alpha, cryst_start_layer],
                            outputs=[out_cryst_plot, out_cryst_text])

        # Tab 4: Typo Resilience
        with gr.Tab(t("tab_typo", lang)):
            gr.Markdown(t("typo_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    sent_normal = gr.Textbox(label=t("typo_normal", lang), value="The scientist is very intelligent")
                    sent_typo = gr.Textbox(label=t("typo_misspelled", lang), value="The scientsit is vary intellgent")
                    sent_diff = gr.Textbox(label=t("typo_unrelated", lang), value="The dog runs in the park")
                    btn_typo = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_typo = gr.Plot()
            btn_typo.click(lambda a, b, c: _plot_result(safe_call(exp_typo_resilience, a, b, c)),
                           inputs=[sent_normal, sent_typo, sent_diff], outputs=out_typo)

        # Tab 5: Signal Sparsity
        with gr.Tab(t("tab_sparsity", lang)):
            gr.Markdown(t("sparsity_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    word_inspect = gr.Textbox(label=t("sparsity_word", lang), placeholder="e.g. cat")
                    btn_sparse = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_sparse_plot = gr.Plot()
                    out_sparse_text = gr.Textbox(label=t("sparsity_stats", lang), lines=10, interactive=False)
            btn_sparse.click(lambda w: _plot_text_result(safe_call(exp_sparsity_profile, w)),
                             inputs=[word_inspect], outputs=[out_sparse_plot, out_sparse_text])

        # Tab 6: Layer Evolution
        with gr.Tab(t("tab_evolution", lang)):
            gr.Markdown(t("evolution_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_evo = gr.Textbox(label=t("evolution_prompt", lang), value="The capital of France is")
                    btn_evo = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_evo = gr.Plot()
            btn_evo.click(lambda p: _plot_result(safe_call(exp_layer_evolution, p)), inputs=[prompt_evo], outputs=out_evo)

        # Tab 7: Causal Ablation
        with gr.Tab(t("tab_ablation", lang)):
            gr.Markdown(t("ablation_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_abl = gr.Textbox(label=t("ablation_prompt", lang), value="The capital of France is")
                    btn_abl = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_abl = gr.Plot()
            btn_abl.click(lambda p: _plot_result(safe_call(exp_causal_ablation, p)), inputs=[prompt_abl], outputs=out_abl)

        # Tab 8: Concept Inception
        with gr.Tab(t("tab_inception", lang)):
            gr.Markdown(t("inception_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_inc = gr.Textbox(label=t("inception_prompt", lang), value="The capital of France is")
                    target_inc = gr.Textbox(label=t("inception_target", lang), value="London")
                    alpha_inc = gr.Slider(10, 500, 200, step=10, label=t("inception_alpha", lang))
                    btn_inc = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_inc_plot = gr.Plot()
                    out_inc_text = gr.Textbox(label=t("inception_info", lang), lines=5, interactive=False)
            btn_inc.click(lambda p, tg, a: _plot_text_result(safe_call(exp_concept_inception, p, tg, a)),
                          inputs=[prompt_inc, target_inc, alpha_inc], outputs=[out_inc_plot, out_inc_text])

        # Tab 9: Basis Geometry (no inputs: the experiment runs on the loaded model only)
        with gr.Tab(t("tab_basis", lang)):
            gr.Markdown(t("basis_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    btn_basis = gr.Button(t("btn_run", lang), variant="primary")
                    out_basis_text = gr.Textbox(label=t("basis_stats", lang), lines=5, interactive=False)
                with gr.Column(scale=2):
                    out_basis_plot = gr.Plot()
            btn_basis.click(lambda: _plot_text_result(safe_call(exp_basis_geometry)),
                            inputs=[], outputs=[out_basis_plot, out_basis_text])

        # Tab 10: Recipe Neighbors
        with gr.Tab(t("tab_neighbors", lang)):
            gr.Markdown(t("neighbors_title", lang))
            with gr.Row():
                with gr.Column(scale=1):
                    nn_word = gr.Textbox(label=t("neighbors_words", lang), value="king, cat, red")
                    nn_top_n = gr.Slider(5, 50, 10, step=1, label=t("neighbors_top_n", lang))
                    btn_nn = gr.Button(t("btn_run", lang), variant="primary")
                with gr.Column(scale=2):
                    out_nn = gr.Textbox(label=t("neighbors_results", lang), lines=25, interactive=False)
            btn_nn.click(lambda w, n: safe_call(exp_recipe_neighbors, w, n),
                         inputs=[nn_word, nn_top_n], outputs=out_nn)
def build_ui():
    """Build the top-level Gradio app, with one tab per UI language.

    Each language tab contains a complete, independent copy of the experiment
    UI produced by ``build_lang_ui``.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    with gr.Blocks(title="reFlow Interpretability") as demo:
        # The Tabs container is never referenced afterwards, so it is not bound.
        with gr.Tabs():
            with gr.Tab("English"):
                build_lang_ui("en")
            with gr.Tab("中文"):
                build_lang_ui("zh")
    return demo
if __name__ == "__main__":
    # Load the model up front rather than on the first UI request
    # (presumably get_model() caches it for later calls; confirm in model_loader).
    print("Loading model at startup ...")
    get_model()

    demo = build_ui()
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)