Spaces:
Running on T4
Running on T4
| """ | |
| Baguettotron vs Luth models — Gradio comparison app. | |
| All models, all outputs; tabbed by parameter size. | |
| """ | |
| import gradio as gr | |
| from inference import BAGUETTOTRON_ID, run_all | |
| from model_config import ( | |
| combined_footprint, | |
| footprint_table_data, | |
| get_models_by_tier, | |
| MODELS, | |
| ) | |
| from ui_strings import get_strings | |
# Prompt pair executed a single time while the UI is being built; the resulting
# generations are cached and shown as the initial contents of the output boxes.
# The system prompt is deliberately empty (no system prompt for the example).
STARTUP_EXAMPLE_PROMPT: str = "dites moi en plus sur les jardins japonnais a paris :"
STARTUP_EXAMPLE_SYSTEM: str = ""
| import spaces | |
def _as_int(value):
    """Round a numeric UI value to ``int``, passing ``None`` (an emptied field) through unchanged."""
    return None if value is None else int(round(value))


def _sampling_params(temperature, max_tokens, top_p, top_k, repeat_penalty) -> dict:
    """Pack one column of sampling controls into the per-model params dict passed to ``run_all``."""
    return {
        "temperature": temperature,
        # These arrive from gr.Number widgets, which yield floats unless created
        # with precision=0; normalize so downstream generation receives ints.
        "max_tokens": _as_int(max_tokens),
        "top_p": top_p,
        "top_k": _as_int(top_k),
        "repeat_penalty": repeat_penalty,
    }


def build_params_by_model(
    temp_baguettotron: float,
    max_tok_baguettotron: int,
    top_p_baguettotron: float,
    top_k_baguettotron: int,
    rep_baguettotron: float,
    temp_luth: float,
    max_tok_luth: int,
    top_p_luth: float,
    top_k_luth: int,
    rep_luth: float,
) -> dict[str, dict]:
    """Build a params dict keyed by model repo id from the Baguettotron and Luth controls.

    The model whose ``repo_id`` equals ``BAGUETTOTRON_ID`` gets the Baguettotron
    settings; every other entry in ``MODELS`` gets the Luth settings.
    ``max_tokens`` and ``top_k`` are rounded to ``int`` (``None`` is kept as-is).

    Returns:
        ``{repo_id: {"temperature", "max_tokens", "top_p", "top_k", "repeat_penalty"}}``
    """
    baguettotron_params = _sampling_params(
        temp_baguettotron, max_tok_baguettotron, top_p_baguettotron, top_k_baguettotron, rep_baguettotron
    )
    luth_params = _sampling_params(temp_luth, max_tok_luth, top_p_luth, top_k_luth, rep_luth)
    # Each model receives an independent copy, so mutating one model's dict
    # cannot affect another model's settings.
    return {
        m.repo_id: dict(baguettotron_params if m.repo_id == BAGUETTOTRON_ID else luth_params)
        for m in MODELS
    }
def generate_all(
    prompt: str,
    system_prompt: str,
    temp_baguettotron: float,
    max_tok_baguettotron: int,
    top_p_baguettotron: float,
    top_k_baguettotron: int,
    rep_baguettotron: float,
    temp_luth: float,
    max_tok_luth: int,
    top_p_luth: float,
    top_k_luth: int,
    rep_luth: float,
) -> tuple[str, str, str, str, str, str]:
    """Run all 6 models on one prompt; return outputs in tab order: small (2), medium (2), large (2).

    ``prompt`` and ``system_prompt`` may be ``None`` (Gradio passes ``None`` for a
    null Textbox payload); both are treated as empty strings. A blank prompt
    short-circuits without running any model and clears all six output boxes.
    A model missing from the ``run_all`` results yields ``""`` in its slot.
    """
    prompt = prompt or ""
    # "" is the same "no system prompt" value the startup example uses.
    system_prompt = system_prompt or ""
    if not prompt.strip():
        return ("",) * 6
    params = build_params_by_model(
        temp_baguettotron,
        max_tok_baguettotron,
        top_p_baguettotron,
        top_k_baguettotron,
        rep_baguettotron,
        temp_luth,
        max_tok_luth,
        top_p_luth,
        top_k_luth,
        rep_luth,
    )
    results = run_all(prompt, params, system_prompt=system_prompt)
    models_by_tier = get_models_by_tier()
    # Order must match the output components wired to the Generate button.
    return tuple(
        results.get(m.repo_id, "")
        for tier in ("small", "medium", "large")
        for m in models_by_tier[tier]
    )
def _ui_updates_for_locale(locale: str) -> tuple:
    """Produce every component update needed to re-render the UI in ``locale``.

    The result is positional: its order must match the ``_lang_outputs`` list
    built inside ``create_ui`` and passed as ``outputs`` to the language radio's
    ``change`` event. Markdown components receive plain strings; other
    components receive ``gr.update`` dicts (label/info/placeholder/headers/value).

    Args:
        locale: Locale code forwarded to ``get_strings`` (the toggle offers "fr" and "en").
    """
    txt = get_strings(locale)

    def relabel(key: str, **extra):
        # Update the component label, plus any extra props (info, placeholder).
        return gr.update(label=txt[key], **extra)

    disk, vram = combined_footprint()
    summary = txt["FOOTPRINT_SUMMARY_TEMPLATE"].format(total_disk=disk, total_vram=vram)

    header_section = (
        f"# {txt['TITLE']}",
        txt["SUBTITLE"],
        txt["HEADING_FOOTPRINT"],
        txt["FOOTPRINT_INTRO"],
        gr.update(headers=txt["FOOTPRINT_HEADERS"]),
        summary,
    )
    generation_section = (
        txt["HEADING_GENERATION"],
        txt["COL_BAGUETTOTRON_HEADING"],
        txt["COL_LUTH_HEADING"],
        # Baguettotron column
        relabel("LABEL_TEMPERATURE", info=txt["INFO_TEMP_BAGUETTOTRON"]),
        relabel("LABEL_MAX_TOKENS"),
        relabel("LABEL_TOP_P"),
        relabel("LABEL_TOP_K"),
        relabel("LABEL_REPEAT_PENALTY"),
        # Luth column
        relabel("LABEL_TEMPERATURE"),
        relabel("LABEL_MAX_TOKENS"),
        relabel("LABEL_TOP_P"),
        relabel("LABEL_TOP_K"),
        relabel("LABEL_REPEAT_PENALTY", info=txt["INFO_REP_LUTH"]),
    )
    tier_keys = ("TIER_SMALL", "TIER_MEDIUM", "TIER_LARGE")
    output_keys = (
        "LABEL_OUT_BAGUETTOTRON",
        "LABEL_OUT_LUTH_350",
        "LABEL_OUT_LUTH_06",
        "LABEL_OUT_LUTH_07",
        "LABEL_OUT_LUTH_12",
        "LABEL_OUT_LUTH_17",
    )
    inference_section = (
        txt["HEADING_LIVE_INFERENCE"],
        *(relabel(key) for key in tier_keys),
        *(relabel(key) for key in output_keys),
        relabel("LABEL_SYSTEM_PROMPT", placeholder=txt["PLACEHOLDER_SYSTEM_PROMPT"]),
        relabel("LABEL_PROMPT", placeholder=txt["PLACEHOLDER_PROMPT"]),
        gr.update(value=txt["BTN_GENERATE"]),
    )
    footer_section = (
        txt["HEADING_HOW_TO_USE"],
        txt["HOW_TO_USE"],
        txt["JOIN_US"],
    )
    return header_section + generation_section + inference_section + footer_section
# Default sampling settings per model family. They seed both the startup example
# run and the initial widget values, so the two cannot drift apart.
_BAGUETTOTRON_DEFAULTS = {
    "temperature": 0.5,
    "max_tokens": 512,
    "top_p": 0.9,
    "top_k": 40,
    "repeat_penalty": 1.1,
}
_LUTH_DEFAULTS = {
    "temperature": 0.7,
    "max_tokens": 256,
    "top_p": 0.9,
    "top_k": 40,
    "repeat_penalty": 1.05,
}


def create_ui():
    """Build the Gradio Blocks app comparing Baguettotron with the Luth models.

    Side effects:
        - Runs every model once on ``STARTUP_EXAMPLE_PROMPT`` (via ``run_all``)
          so the output boxes open pre-filled.
        - Replaces ``demo.launch`` with a wrapper that injects the toggle-row CSS
          unless the caller passes ``css`` explicitly.

    Returns:
        The ``gr.Blocks`` instance (English strings by default); call ``.launch()`` on it.
    """
    s = get_strings("en")
    total_disk, total_vram = combined_footprint()
    footprint_md = s["FOOTPRINT_SUMMARY_TEMPLATE"].format(total_disk=total_disk, total_vram=total_vram)

    bag, luth = _BAGUETTOTRON_DEFAULTS, _LUTH_DEFAULTS
    # Run the startup example once and cache its outputs for the initial display.
    default_params = build_params_by_model(
        bag["temperature"], bag["max_tokens"], bag["top_p"], bag["top_k"], bag["repeat_penalty"],
        luth["temperature"], luth["max_tokens"], luth["top_p"], luth["top_k"], luth["repeat_penalty"],
    )
    startup_results = run_all(
        STARTUP_EXAMPLE_PROMPT,
        default_params,
        system_prompt=STARTUP_EXAMPLE_SYSTEM,
    )
    models_by_tier = get_models_by_tier()
    # Tab order, same as generate_all: small (2), medium (2), large (2).
    startup_outputs: list[str] = [
        startup_results.get(m.repo_id, "")
        for tier in ("small", "medium", "large")
        for m in models_by_tier[tier]
    ]

    TOP_TOGGLES_CSS = """
.top-toggles-row { justify-content: flex-end !important; align-items: center; flex-wrap: nowrap !important; }
.top-toggles-row > div:first-child { flex: 1 !important; min-width: 0 !important; }
.top-toggles-group { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; align-items: center !important; gap: 0.5rem !important; }
.top-toggles-group > div { flex: 0 0 auto !important; margin-bottom: 0 !important; }
.top-toggles-group .wrap,
.top-toggles-group [class*="wrap"] { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; margin-bottom: 0 !important; }
.top-toggles-group .wrap-inner,
.top-toggles-group [class*="wrap-inner"],
.top-toggles-group [class*="form"],
.top-toggles-group [role="radiogroup"],
.top-toggles-group div:has(> input[type="radio"]) { display: flex !important; flex-direction: row !important; flex-wrap: nowrap !important; gap: 0.2rem !important; }
.top-toggles-group label { min-height: unset !important; padding: 0.35rem 0.6rem !important; margin: 0 !important; border-radius: 6px !important; }
"""

    with gr.Blocks(title=s["TITLE"]) as demo:
        # Right-aligned: language + theme on a single row
        with gr.Row(elem_classes="top-toggles-row"):
            gr.HTML('<div style="flex:1; min-width:0;"></div>')
            with gr.Row(scale=0, elem_classes=["top-toggles-group"]):
                lang_radio = gr.Radio(
                    choices=[("🇫🇷", "fr"), ("🇺🇸", "en")],
                    value="en",
                    show_label=False,
                    scale=0,
                )
                theme_radio = gr.Radio(
                    choices=[("☀️", "light"), ("🌙", "dark")],
                    value="light",
                    show_label=False,
                    scale=0,
                )
        # fn=None: handled entirely by the js snippet, which toggles the 'dark' class on <body>.
        theme_radio.change(
            None,
            inputs=[theme_radio],
            js="(v) => { document.body.classList.toggle('dark', v === 'dark'); }",
        )
        title_md = gr.Markdown(f"# {s['TITLE']}")
        subtitle_md = gr.Markdown(s["SUBTITLE"])

        # Row 1: Single consolidated comparison table
        heading_footprint_md = gr.Markdown(s["HEADING_FOOTPRINT"])
        footprint_intro_md = gr.Markdown(s["FOOTPRINT_INTRO"])
        footprint_df = gr.Dataframe(
            value=footprint_table_data(),
            headers=s["FOOTPRINT_HEADERS"],
            interactive=False,
        )
        footprint_summary_md = gr.Markdown(footprint_md)

        # Row 2: Generation settings — two columns (Baguettotron | Luth).
        # precision=0 on the gr.Number inputs makes them hand ints to the handler.
        heading_generation_md = gr.Markdown(s["HEADING_GENERATION"])
        with gr.Row():
            with gr.Column():
                col_baguettotron_md = gr.Markdown(s["COL_BAGUETTOTRON_HEADING"])
                temp_baguettotron = gr.Slider(
                    0, 2, value=bag["temperature"], label=s["LABEL_TEMPERATURE"], info=s["INFO_TEMP_BAGUETTOTRON"]
                )
                max_tok_baguettotron = gr.Number(
                    value=bag["max_tokens"], label=s["LABEL_MAX_TOKENS"], minimum=64, maximum=2048, precision=0
                )
                top_p_baguettotron = gr.Slider(0, 1, value=bag["top_p"], label=s["LABEL_TOP_P"])
                top_k_baguettotron = gr.Number(value=bag["top_k"], label=s["LABEL_TOP_K"], precision=0)
                rep_baguettotron = gr.Slider(1.0, 1.5, value=bag["repeat_penalty"], label=s["LABEL_REPEAT_PENALTY"])
            with gr.Column():
                col_luth_md = gr.Markdown(s["COL_LUTH_HEADING"])
                temp_luth = gr.Slider(0, 2, value=luth["temperature"], label=s["LABEL_TEMPERATURE"])
                max_tok_luth = gr.Number(
                    value=luth["max_tokens"], label=s["LABEL_MAX_TOKENS"], minimum=64, maximum=2048, precision=0
                )
                top_p_luth = gr.Slider(0, 1, value=luth["top_p"], label=s["LABEL_TOP_P"])
                top_k_luth = gr.Number(value=luth["top_k"], label=s["LABEL_TOP_K"], precision=0)
                rep_luth = gr.Slider(
                    1.0, 1.5, value=luth["repeat_penalty"], label=s["LABEL_REPEAT_PENALTY"], info=s["INFO_REP_LUTH"]
                )

        # Row 3: Live inference — outputs above inputs
        heading_live_md = gr.Markdown(s["HEADING_LIVE_INFERENCE"])

        def output_box(label_key: str, slot: int):
            # Result textbox pre-filled with the cached startup output for this tab-order slot.
            return gr.Textbox(
                label=s[label_key],
                lines=12,
                max_lines=24,
                value=startup_outputs[slot],
            )

        with gr.Tabs():
            with gr.Tab(s["TIER_SMALL"], id="tab_small") as tab_small:
                with gr.Row():
                    out_baguettotron = output_box("LABEL_OUT_BAGUETTOTRON", 0)
                    out_luth_350 = output_box("LABEL_OUT_LUTH_350", 1)
            with gr.Tab(s["TIER_MEDIUM"], id="tab_medium") as tab_medium:
                with gr.Row():
                    out_luth_06 = output_box("LABEL_OUT_LUTH_06", 2)
                    out_luth_07 = output_box("LABEL_OUT_LUTH_07", 3)
            with gr.Tab(s["TIER_LARGE"], id="tab_large") as tab_large:
                with gr.Row():
                    out_luth_12 = output_box("LABEL_OUT_LUTH_12", 4)
                    out_luth_17 = output_box("LABEL_OUT_LUTH_17", 5)

        system_prompt_in = gr.Textbox(
            label=s["LABEL_SYSTEM_PROMPT"],
            placeholder=s["PLACEHOLDER_SYSTEM_PROMPT"],
            lines=2,
        )
        prompt_in = gr.Textbox(
            label=s["LABEL_PROMPT"],
            placeholder=s["PLACEHOLDER_PROMPT"],
            lines=3,
            value=STARTUP_EXAMPLE_PROMPT,
        )
        gen_btn = gr.Button(s["BTN_GENERATE"], variant="primary")

        # Order must match generate_all's positional parameters.
        all_inputs = [
            prompt_in,
            system_prompt_in,
            temp_baguettotron,
            max_tok_baguettotron,
            top_p_baguettotron,
            top_k_baguettotron,
            rep_baguettotron,
            temp_luth,
            max_tok_luth,
            top_p_luth,
            top_k_luth,
            rep_luth,
        ]
        # Order must match generate_all's return tuple (tab order).
        all_outputs = [
            out_baguettotron,
            out_luth_350,
            out_luth_06,
            out_luth_07,
            out_luth_12,
            out_luth_17,
        ]
        gen_btn.click(
            fn=generate_all,
            inputs=all_inputs,
            outputs=all_outputs,
        )

        # How to use & join us
        heading_how_to_use_md = gr.Markdown(s["HEADING_HOW_TO_USE"])
        how_to_use_md = gr.Markdown(s["HOW_TO_USE"])
        join_us_md = gr.Markdown(s["JOIN_US"])

        # Language toggle: update all visible strings.
        # Order must match the tuple returned by _ui_updates_for_locale.
        _lang_outputs = [
            title_md,
            subtitle_md,
            heading_footprint_md,
            footprint_intro_md,
            footprint_df,
            footprint_summary_md,
            heading_generation_md,
            col_baguettotron_md,
            col_luth_md,
            temp_baguettotron,
            max_tok_baguettotron,
            top_p_baguettotron,
            top_k_baguettotron,
            rep_baguettotron,
            temp_luth,
            max_tok_luth,
            top_p_luth,
            top_k_luth,
            rep_luth,
            heading_live_md,
            tab_small,
            tab_medium,
            tab_large,
            out_baguettotron,
            out_luth_350,
            out_luth_06,
            out_luth_07,
            out_luth_12,
            out_luth_17,
            system_prompt_in,
            prompt_in,
            gen_btn,
            heading_how_to_use_md,
            how_to_use_md,
            join_us_md,
        ]
        lang_radio.change(
            fn=_ui_updates_for_locale,
            inputs=[lang_radio],
            outputs=_lang_outputs,
        )

    # The CSS is supplied at launch time rather than to gr.Blocks(); wrap launch
    # so callers get it automatically unless they pass their own css.
    demo._custom_css = TOP_TOGGLES_CSS
    _original_launch = demo.launch

    def _launch_with_css(*args, **kwargs):
        if "css" not in kwargs and demo._custom_css:
            kwargs["css"] = demo._custom_css
        return _original_launch(*args, **kwargs)

    demo.launch = _launch_with_css
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface (which also runs the cached startup
    # example) and serve it with server-side rendering turned off.
    launch_options = {"ssr_mode": False}
    demo = create_ui()
    demo.launch(**launch_options)