fr-on-device / app.py
Joseph Pollack
revery cursor fallback strategies on zerogpu decorator
0b02a9a unverified
"""
Baguettotron vs Luth models — Gradio comparison app.
All models, all outputs; tabbed by parameter size.
"""
import gradio as gr
from inference import BAGUETTOTRON_ID, run_all
from model_config import (
combined_footprint,
footprint_table_data,
get_models_by_tier,
MODELS,
)
from ui_strings import get_strings
# Example prompt (French: "tell me more about Japanese gardens in Paris") that
# create_ui() runs through every model once at startup. The resulting outputs are
# cached as the initial values of the output textboxes, and the prompt also
# pre-fills the prompt box. The string is used verbatim, so editing it changes
# the cached startup example.
STARTUP_EXAMPLE_PROMPT = "dites moi en plus sur les jardins japonnais a paris :"
# System prompt used for the startup example: empty, i.e. no system prompt.
STARTUP_EXAMPLE_SYSTEM = ""
import spaces
def build_params_by_model(
    temp_baguettotron: float,
    max_tok_baguettotron: int,
    top_p_baguettotron: float,
    top_k_baguettotron: int,
    rep_baguettotron: float,
    temp_luth: float,
    max_tok_luth: int,
    top_p_luth: float,
    top_k_luth: int,
    rep_luth: float,
) -> dict[str, dict]:
    """Build per-model sampling parameters from the two UI control columns.

    The Baguettotron column applies only to the model whose ``repo_id`` equals
    ``BAGUETTOTRON_ID``; every other entry of ``MODELS`` gets the Luth column.

    ``max_tokens`` and ``top_k`` are coerced to ``int`` (rounded): they come
    from ``gr.Number`` components without a fixed precision, which may hand
    back floats such as ``512.0`` or ``20.6``. Both are integer counts for
    generation. ``None`` is passed through unchanged.

    Returns:
        A dict mapping each model's ``repo_id`` to its own independent params
        dict with the keys ``temperature``, ``max_tokens``, ``top_p``,
        ``top_k`` and ``repeat_penalty``.
    """

    def _as_int(value):
        # Round-then-int, as Gradio itself does for precision=0 Number inputs.
        return None if value is None else int(round(value))

    baguettotron_params = {
        "temperature": temp_baguettotron,
        "max_tokens": _as_int(max_tok_baguettotron),
        "top_p": top_p_baguettotron,
        "top_k": _as_int(top_k_baguettotron),
        "repeat_penalty": rep_baguettotron,
    }
    luth_params = {
        "temperature": temp_luth,
        "max_tokens": _as_int(max_tok_luth),
        "top_p": top_p_luth,
        "top_k": _as_int(top_k_luth),
        "repeat_penalty": rep_luth,
    }
    # dict(...) gives every model its own copy, so mutating one model's params
    # downstream cannot leak into another model's.
    return {
        m.repo_id: dict(baguettotron_params if m.repo_id == BAGUETTOTRON_ID else luth_params)
        for m in MODELS
    }
@spaces.GPU
def generate_all(
    prompt: str,
    system_prompt: str,
    temp_baguettotron: float,
    max_tok_baguettotron: int,
    top_p_baguettotron: float,
    top_k_baguettotron: int,
    rep_baguettotron: float,
    temp_luth: float,
    max_tok_luth: int,
    top_p_luth: float,
    top_k_luth: int,
    rep_luth: float,
) -> tuple[str, str, str, str, str, str]:
    """Generate a reply from every model and return the texts in tab order.

    Args:
        prompt: User prompt. When it is empty or whitespace-only, no model is
            run and six empty strings are returned.
        system_prompt: Forwarded to ``run_all`` as ``system_prompt``.
        temp_* / max_tok_* / top_p_* / top_k_* / rep_*: Sampling controls of
            the Baguettotron and Luth columns; ``build_params_by_model`` turns
            them into per-model parameter dicts.

    Returns:
        One string per model, ordered small tier, then medium, then large
        (model order within a tier as given by ``get_models_by_tier()``).
        A model missing from ``run_all``'s results contributes ``""``.
    """
    if not prompt.strip():
        # Nothing to generate: blank out all six output boxes.
        return ("",) * 6

    per_model_params = build_params_by_model(
        temp_baguettotron=temp_baguettotron,
        max_tok_baguettotron=max_tok_baguettotron,
        top_p_baguettotron=top_p_baguettotron,
        top_k_baguettotron=top_k_baguettotron,
        rep_baguettotron=rep_baguettotron,
        temp_luth=temp_luth,
        max_tok_luth=max_tok_luth,
        top_p_luth=top_p_luth,
        top_k_luth=top_k_luth,
        rep_luth=rep_luth,
    )
    texts_by_repo = run_all(prompt, per_model_params, system_prompt=system_prompt)

    tiers = get_models_by_tier()
    return tuple(
        texts_by_repo.get(model.repo_id, "")
        for tier_name in ("small", "medium", "large")
        for model in tiers[tier_name]
    )
def _ui_updates_for_locale(locale: str) -> tuple:
    """Return every language-dependent value/update for ``locale``.

    This is the handler of the language radio in ``create_ui``. The items are
    positional and must line up one-to-one with that function's
    ``_lang_outputs`` list (35 entries). Markdown components receive plain
    strings. Other components receive ``gr.update(...)`` objects carrying a
    new label, info, placeholder, headers or value.
    """
    strings = get_strings(locale)
    disk_total, vram_total = combined_footprint()
    summary_text = strings["FOOTPRINT_SUMMARY_TEMPLATE"].format(
        total_disk=disk_total, total_vram=vram_total
    )

    def relabel(key: str, **extra):
        # Relabel a component with the localized string stored under `key`.
        return gr.update(label=strings[key], **extra)

    # Max tokens / top-p / top-k carry no per-column info text.
    plain_sampler_keys = ("LABEL_MAX_TOKENS", "LABEL_TOP_P", "LABEL_TOP_K")
    baguettotron_controls = (
        relabel("LABEL_TEMPERATURE", info=strings["INFO_TEMP_BAGUETTOTRON"]),
        *[relabel(key) for key in plain_sampler_keys],
        relabel("LABEL_REPEAT_PENALTY"),
    )
    luth_controls = (
        relabel("LABEL_TEMPERATURE"),
        *[relabel(key) for key in plain_sampler_keys],
        relabel("LABEL_REPEAT_PENALTY", info=strings["INFO_REP_LUTH"]),
    )
    tier_tabs = tuple(relabel(key) for key in ("TIER_SMALL", "TIER_MEDIUM", "TIER_LARGE"))
    output_boxes = tuple(
        relabel(key)
        for key in (
            "LABEL_OUT_BAGUETTOTRON",
            "LABEL_OUT_LUTH_350",
            "LABEL_OUT_LUTH_06",
            "LABEL_OUT_LUTH_07",
            "LABEL_OUT_LUTH_12",
            "LABEL_OUT_LUTH_17",
        )
    )

    return (
        f"# {strings['TITLE']}",
        strings["SUBTITLE"],
        strings["HEADING_FOOTPRINT"],
        strings["FOOTPRINT_INTRO"],
        gr.update(headers=strings["FOOTPRINT_HEADERS"]),
        summary_text,
        strings["HEADING_GENERATION"],
        strings["COL_BAGUETTOTRON_HEADING"],
        strings["COL_LUTH_HEADING"],
        *baguettotron_controls,
        *luth_controls,
        strings["HEADING_LIVE_INFERENCE"],
        *tier_tabs,
        *output_boxes,
        relabel("LABEL_SYSTEM_PROMPT", placeholder=strings["PLACEHOLDER_SYSTEM_PROMPT"]),
        relabel("LABEL_PROMPT", placeholder=strings["PLACEHOLDER_PROMPT"]),
        gr.update(value=strings["BTN_GENERATE"]),
        strings["HEADING_HOW_TO_USE"],
        strings["HOW_TO_USE"],
        strings["JOIN_US"],
    )
# Layout for the right-aligned language/theme toggles. Passed to launch() as css=
# by the wrapper installed in create_ui().
_TOP_TOGGLES_CSS = """
.top-toggles-row { justify-content: flex-end !important; align-items: center; flex-wrap: nowrap !important; }
.top-toggles-row > div:first-child { flex: 1 !important; min-width: 0 !important; }
.top-toggles-group { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; align-items: center !important; gap: 0.5rem !important; }
.top-toggles-group > div { flex: 0 0 auto !important; margin-bottom: 0 !important; }
.top-toggles-group .wrap,
.top-toggles-group [class*="wrap"] { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; margin-bottom: 0 !important; }
.top-toggles-group .wrap-inner,
.top-toggles-group [class*="wrap-inner"],
.top-toggles-group [class*="form"],
.top-toggles-group [role="radiogroup"],
.top-toggles-group div:has(> input[type="radio"]) { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; gap: 0.2rem !important; }
.top-toggles-group label { min-height: unset !important; padding: 0.35rem 0.6rem !important; margin: 0 !important; border-radius: 6px !important; }
"""


def _startup_example_outputs() -> list[str]:
    """Run the startup example once and return its outputs in tab order.

    Uses the same default sampling values as the UI controls, and an empty
    system prompt (``STARTUP_EXAMPLE_SYSTEM``). Models are visited tier by tier
    (small, medium, large) in the order returned by ``get_models_by_tier()``.
    A model missing from the results yields ``""``.
    """
    default_params = build_params_by_model(
        0.5, 512, 0.9, 40, 1.1,  # Baguettotron (matches its slider defaults)
        0.7, 256, 0.9, 40, 1.05,  # Luth (matches its slider defaults)
    )
    startup_results = run_all(
        STARTUP_EXAMPLE_PROMPT,
        default_params,
        system_prompt=STARTUP_EXAMPLE_SYSTEM,
    )
    models_by_tier = get_models_by_tier()
    return [
        startup_results.get(m.repo_id, "")
        for tier in ("small", "medium", "large")
        for m in models_by_tier[tier]
    ]


def create_ui():
    """Build the Gradio Blocks app comparing Baguettotron with the Luth models.

    Building the UI runs ``STARTUP_EXAMPLE_PROMPT`` through every model once,
    so the six output boxes start pre-filled. The UI starts in English.

    Returns:
        The ``gr.Blocks`` app. Its ``launch`` is replaced by a wrapper that
        passes the toggle CSS as ``css=`` unless the caller supplies one.
    """
    s = get_strings("en")
    total_disk, total_vram = combined_footprint()
    footprint_md = s["FOOTPRINT_SUMMARY_TEMPLATE"].format(total_disk=total_disk, total_vram=total_vram)
    # Run startup example once and cache for initial output display
    startup_outputs = _startup_example_outputs()

    def _output_box(label_key: str, index: int):
        """Create one model-output textbox pre-filled with startup output `index`."""
        return gr.Textbox(
            label=s[label_key],
            lines=12,
            max_lines=24,
            value=startup_outputs[index],
        )

    with gr.Blocks(title=s["TITLE"]) as demo:
        # Right-aligned: language + theme on a single row
        with gr.Row(elem_classes="top-toggles-row"):
            gr.HTML('<div style="flex:1; min-width:0;"></div>')
            with gr.Row(scale=0, elem_classes=["top-toggles-group"]):
                lang_radio = gr.Radio(
                    choices=[("🇫🇷", "fr"), ("🇺🇸", "en")],
                    value="en",
                    show_label=False,
                    scale=0,
                )
                theme_radio = gr.Radio(
                    choices=[("☀️", "light"), ("🌙", "dark")],
                    value="light",
                    show_label=False,
                    scale=0,
                )
        # Client-side only (fn=None): toggle the `dark` class on <body>.
        theme_radio.change(
            None,
            inputs=[theme_radio],
            js="(v) => { document.body.classList.toggle('dark', v === 'dark'); }",
        )
        title_md = gr.Markdown(f"# {s['TITLE']}")
        subtitle_md = gr.Markdown(s["SUBTITLE"])

        # Row 1: Single consolidated comparison table
        heading_footprint_md = gr.Markdown(s["HEADING_FOOTPRINT"])
        footprint_intro_md = gr.Markdown(s["FOOTPRINT_INTRO"])
        footprint_df = gr.Dataframe(
            value=footprint_table_data(),
            headers=s["FOOTPRINT_HEADERS"],
            interactive=False,
        )
        footprint_summary_md = gr.Markdown(footprint_md)

        # Row 2: Generation settings — two columns (Baguettotron | Luth)
        heading_generation_md = gr.Markdown(s["HEADING_GENERATION"])
        with gr.Row():
            with gr.Column():
                col_baguettotron_md = gr.Markdown(s["COL_BAGUETTOTRON_HEADING"])
                temp_baguettotron = gr.Slider(0, 2, value=0.5, label=s["LABEL_TEMPERATURE"], info=s["INFO_TEMP_BAGUETTOTRON"])
                max_tok_baguettotron = gr.Number(value=512, label=s["LABEL_MAX_TOKENS"], minimum=64, maximum=2048)
                top_p_baguettotron = gr.Slider(0, 1, value=0.9, label=s["LABEL_TOP_P"])
                top_k_baguettotron = gr.Number(value=40, label=s["LABEL_TOP_K"])
                rep_baguettotron = gr.Slider(1.0, 1.5, value=1.1, label=s["LABEL_REPEAT_PENALTY"])
            with gr.Column():
                col_luth_md = gr.Markdown(s["COL_LUTH_HEADING"])
                temp_luth = gr.Slider(0, 2, value=0.7, label=s["LABEL_TEMPERATURE"])
                max_tok_luth = gr.Number(value=256, label=s["LABEL_MAX_TOKENS"], minimum=64, maximum=2048)
                top_p_luth = gr.Slider(0, 1, value=0.9, label=s["LABEL_TOP_P"])
                top_k_luth = gr.Number(value=40, label=s["LABEL_TOP_K"])
                rep_luth = gr.Slider(1.0, 1.5, value=1.05, label=s["LABEL_REPEAT_PENALTY"], info=s["INFO_REP_LUTH"])

        # Row 3: Live inference — outputs above inputs.
        # NOTE(review): box i shows startup/generate_all output i, both produced
        # in get_models_by_tier() small/medium/large order; this assumes that
        # order matches the labels below — confirm in model_config.
        heading_live_md = gr.Markdown(s["HEADING_LIVE_INFERENCE"])
        with gr.Tabs():
            with gr.Tab(s["TIER_SMALL"], id="tab_small") as tab_small:
                with gr.Row():
                    out_baguettotron = _output_box("LABEL_OUT_BAGUETTOTRON", 0)
                    out_luth_350 = _output_box("LABEL_OUT_LUTH_350", 1)
            with gr.Tab(s["TIER_MEDIUM"], id="tab_medium") as tab_medium:
                with gr.Row():
                    out_luth_06 = _output_box("LABEL_OUT_LUTH_06", 2)
                    out_luth_07 = _output_box("LABEL_OUT_LUTH_07", 3)
            with gr.Tab(s["TIER_LARGE"], id="tab_large") as tab_large:
                with gr.Row():
                    out_luth_12 = _output_box("LABEL_OUT_LUTH_12", 4)
                    out_luth_17 = _output_box("LABEL_OUT_LUTH_17", 5)
        system_prompt_in = gr.Textbox(
            label=s["LABEL_SYSTEM_PROMPT"],
            placeholder=s["PLACEHOLDER_SYSTEM_PROMPT"],
            lines=2,
        )
        prompt_in = gr.Textbox(
            label=s["LABEL_PROMPT"],
            placeholder=s["PLACEHOLDER_PROMPT"],
            lines=3,
            value=STARTUP_EXAMPLE_PROMPT,
        )
        gen_btn = gr.Button(s["BTN_GENERATE"], variant="primary")
        # Order must match generate_all's positional parameters.
        all_inputs = [
            prompt_in,
            system_prompt_in,
            temp_baguettotron,
            max_tok_baguettotron,
            top_p_baguettotron,
            top_k_baguettotron,
            rep_baguettotron,
            temp_luth,
            max_tok_luth,
            top_p_luth,
            top_k_luth,
            rep_luth,
        ]
        # Order must match the tuple generate_all returns.
        all_outputs = [
            out_baguettotron,
            out_luth_350,
            out_luth_06,
            out_luth_07,
            out_luth_12,
            out_luth_17,
        ]
        gen_btn.click(
            fn=generate_all,
            inputs=all_inputs,
            outputs=all_outputs,
        )

        # How to use & join us
        heading_how_to_use_md = gr.Markdown(s["HEADING_HOW_TO_USE"])
        how_to_use_md = gr.Markdown(s["HOW_TO_USE"])
        join_us_md = gr.Markdown(s["JOIN_US"])

        # Language toggle: update all visible strings. The order must match the
        # tuple returned by _ui_updates_for_locale.
        _lang_outputs = [
            title_md,
            subtitle_md,
            heading_footprint_md,
            footprint_intro_md,
            footprint_df,
            footprint_summary_md,
            heading_generation_md,
            col_baguettotron_md,
            col_luth_md,
            temp_baguettotron,
            max_tok_baguettotron,
            top_p_baguettotron,
            top_k_baguettotron,
            rep_baguettotron,
            temp_luth,
            max_tok_luth,
            top_p_luth,
            top_k_luth,
            rep_luth,
            heading_live_md,
            tab_small,
            tab_medium,
            tab_large,
            out_baguettotron,
            out_luth_350,
            out_luth_06,
            out_luth_07,
            out_luth_12,
            out_luth_17,
            system_prompt_in,
            prompt_in,
            gen_btn,
            heading_how_to_use_md,
            how_to_use_md,
            join_us_md,
        ]
        lang_radio.change(
            fn=_ui_updates_for_locale,
            inputs=[lang_radio],
            outputs=_lang_outputs,
        )

    # Inject the toggle CSS at launch time unless the caller passes css=.
    # NOTE(review): this requires a Gradio version whose launch() accepts css=;
    # older versions take css on gr.Blocks() instead — confirm the pinned version.
    demo._custom_css = _TOP_TOGGLES_CSS
    _original_launch = demo.launch

    def _launch_with_css(*args, **kwargs):
        # Forward positional arguments too; the previous wrapper silently
        # rejected them.
        if "css" not in kwargs and demo._custom_css:
            kwargs["css"] = demo._custom_css
        return _original_launch(*args, **kwargs)

    demo.launch = _launch_with_css
    return demo
if __name__ == "__main__":
    # Script entry point. create_ui() builds the app (running the startup
    # example) and returns it with a launch() wrapper that adds the toggle CSS
    # unless css= is given; ssr_mode=False is forwarded to Gradio's launch().
    demo = create_ui()
    demo.launch(ssr_mode=False)