# echobot / app.py
# simonguest's picture
# Deploy
# afee189
import math
import time
import torch
import gradio as gr
import spaces
import pandas as pd
from datasets import DATASETS
from model import load_fresh_model, train_model, infer, TOKENIZER
from logo_b64 import LOGO_B64
# ---------------------------------------------------------------------------
# Per-session state factory
# ---------------------------------------------------------------------------
def make_state():
    """Build a fresh per-session state dict; gr.State calls this for each new browser session.

    Keys:
        model: the trained model, or None while EchoBot is untrained.
        trained_on: name of the dataset the model was trained on, or None.
    """
    return dict(model=None, trained_on=None)
def _detect_device():
if torch.cuda.is_available():
return "cuda"
elif torch.mps.is_available():
return "mps"
return "cpu"
# ---------------------------------------------------------------------------
# Event handlers
# ---------------------------------------------------------------------------
def on_dataset_change(dataset_name):
    """Return the chosen dataset's (input, output) pairs as two-column table rows."""
    selected_pairs = DATASETS[dataset_name]
    return [[prompt, reply] for prompt, reply in selected_pairs]
def _overfitting_warning(loss_records):
"""Return a warning string if the final loss is extremely low, or None."""
if not loss_records:
return None
final_loss = 10 ** loss_records[-1]["Log Loss"]
if final_loss < 0.01:
return (
"> **Possible overfitting:** the loss is extremely low, which on a small "
"dataset usually means the model has memorized the examples rather than "
"learned the pattern. Try fewer epochs or a lower learning rate."
)
return None
@spaces.GPU(duration=300)
def on_train(dataset_name, epochs, lr, state):
    """Train a fresh model on the chosen dataset, streaming progress to the UI.

    Generator: every yield is a 5-tuple matching the train button's outputs,
    in the order (loss-plot data, session state, status markdown, train_btn
    update, reset_btn update). Both buttons are disabled while training runs.
    They are re-enabled when training finishes, and also when it fails, so
    the session is never left locked.

    Args:
        dataset_name: Key into ``DATASETS`` selecting the (input, output) pairs.
        epochs: Epoch count from the slider; coerced to ``int``.
        lr: Learning rate as a string such as "3e-4"; converted with ``float``.
        state: Per-session dict from ``make_state``. On success its "model"
            (moved back to CPU) and "trained_on" entries are replaced. On
            failure they are left untouched.
    """
    # Slider values may arrive as floats; keep "Epoch n/N" and train_model's arg integral.
    epochs = int(epochs)
    device = _detect_device()
    state["device"] = device
    yield (
        None,
        state,
        "**Status:** Loading model...",
        gr.update(interactive=False),
        gr.update(interactive=False),
    )

    loss_records = []
    try:
        model = load_fresh_model()
        model.to(device)  # type:ignore
        tuples = DATASETS[dataset_name]
        for epoch_num, loss in train_model(model, TOKENIZER, tuples, device, epochs=epochs, lr=float(lr)):
            loss = float(loss)
            # log10(0) raises ValueError; clamp so a fully memorized batch
            # cannot abort the run (the overfitting warning still fires).
            loss_records.append({"Epoch": epoch_num, "Log Loss": math.log10(max(loss, 1e-12))})
            yield (
                pd.DataFrame(loss_records),
                state,
                f"**Status:** Training... Epoch {epoch_num}/{epochs} | Loss: {loss:.4f}",
                gr.update(interactive=False),
                gr.update(interactive=False),
            )
    except Exception as err:  # UI boundary: report the failure instead of leaving buttons disabled
        import traceback

        traceback.print_exc()
        yield (
            pd.DataFrame(loss_records) if loss_records else None,
            state,
            f"**Status:** Training failed: {err}",
            gr.update(interactive=True),
            # Reset only makes sense if a previously trained model is still loaded.
            gr.update(interactive=state.get("model") is not None),
        )
        return

    # Keep the trained weights on CPU in session state between requests.
    state["model"] = model.cpu()
    state["trained_on"] = dataset_name
    status = f"**Status:** Trained on '{dataset_name}'"
    warning = _overfitting_warning(loss_records)
    if warning:
        status += f"\n\n{warning}"
    yield (
        pd.DataFrame(loss_records),
        state,
        status,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def on_reset(state):
    """Forget the trained model so EchoBot goes back to echoing.

    Returns updates for (state, status_display, train_btn, reset_btn, loss_plot).
    Training is re-enabled, reset is disabled, and the loss plot is cleared.
    """
    state.update(model=None, trained_on=None)
    untrained_status = "**Status:** Untrained (echoing)"
    enable_train = gr.update(interactive=True)
    disable_reset = gr.update(interactive=False)
    cleared_plot = None
    return state, untrained_status, enable_train, disable_reset, cleared_plot
def on_user_message(message, history):
    """Show the user's message in the chat right away and clear the textbox.

    Whitespace-only input is ignored: the history comes back unchanged and
    the textbox keeps its text. Otherwise a new history list is returned with
    a {"role": "user"} entry appended, and the textbox is cleared.
    """
    if message.strip():
        user_turn = {"role": "user", "content": message}
        return history + [user_turn], ""
    return history, message
@spaces.GPU
def on_bot_response(history, num_beams, state):
    """Generate the assistant reply for the latest user turn and append it.

    Args:
        history: Chat history in messages format (list of {"role", "content"} dicts).
        num_beams: Beam count from the slider; coerced to ``int`` for ``infer``.
        state: Per-session dict. When "model" is None the bot is untrained and
            echoes the message back after a one-second pause.

    Returns:
        A new history list with an assistant entry appended. If ``history`` is
        empty or its last entry is not a user message, ``history`` is
        returned unchanged.
    """
    if not history or history[-1]["role"] != "user":
        return history
    message = history[-1]["content"]

    if state["model"] is None:
        # NOTE(review): this path is inside the @spaces.GPU function, so it presumably
        # still requests a ZeroGPU slot just to echo; consider a separate GPU helper.
        time.sleep(1)
        response = message
    else:
        device = _detect_device()
        model = state["model"].to(device)
        try:
            results = infer(model, TOKENIZER, message, device, num_beams=int(num_beams))
        finally:
            # Always move back to CPU before ZeroGPU releases the allocation, even
            # if inference raises, so the session's model is not stranded on the GPU.
            model.cpu()
        response = results[0]
    return history + [{"role": "assistant", "content": response}]
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Default selection: the first dataset in DATASETS' iteration order.
first_dataset = list(DATASETS)[0]
with gr.Blocks(title="EchoBot", css=".align-bottom { margin-top: auto; margin-bottom: auto }") as demo:
    # Per-session state; make_state is called for each new browser session.
    state = gr.State(make_state)
    gr.HTML(
        '<div style="text-align:center">'
        f'<img src="data:image/png;base64,{LOGO_B64}" style="display:block;margin:0 auto;height:300px">'
        '<p>Select a dataset, train the model, then chat to see how EchoBot responds!</p>'
        '</div>'
    )
    with gr.Row():
        # ---- Column 1: dataset explorer ----
        with gr.Column(scale=1):
            gr.Markdown("## Dataset")
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASETS.keys()),
                value=first_dataset,
                label="Select Dataset",
            )
            dataset_table = gr.Dataframe(
                # Same conversion the dropdown handler uses, so initial and later rows match.
                value=on_dataset_change(first_dataset),
                headers=["Input", "Output"],
                interactive=False,
                label="Input / Output Pairs",
                wrap=True,
            )
        # ---- Column 2: training controls ----
        with gr.Column(scale=1):
            gr.Markdown("## Training")
            status_display = gr.Markdown("**Status:** Untrained (echoing)")
            epochs_slider = gr.Slider(
                minimum=1, maximum=50, step=1, value=10,
                label="Epochs",
            )
            lr_dropdown = gr.Dropdown(
                choices=[
                    ("1e-3 — high (aggressive)", "1e-3"),
                    ("3e-4 — medium (default)", "3e-4"),
                    ("1e-4 — low (cautious)", "1e-4"),
                    ("1e-5 — very low (stable)", "1e-5"),
                ],
                value="3e-4",
                label="Learning Rate",
            )
            num_beams_slider = gr.Slider(
                minimum=1, maximum=20, step=1, value=10,
                label="Inference Beams",
            )
            train_btn = gr.Button("Train EchoBot", variant="primary")
            loss_plot = gr.LinePlot(
                value=None,
                x="Epoch",
                y="Log Loss",
                label="Training Loss (log scale)",
                min_width=200,
            )
            reset_btn = gr.Button("Reset EchoBot", variant="secondary", interactive=False)
        # ---- Column 3: chat ----
        with gr.Column(scale=1):
            gr.Markdown("## Chat with EchoBot")
            chatbot = gr.Chatbot(type="messages", height=520)
            with gr.Row():
                chat_input = gr.Textbox(
                    placeholder="Type a message and press Enter...",
                    show_label=False,
                    scale=4,
                )
                send_btn = gr.Button("Send", scale=1, elem_classes=["align-bottom"])

    # ---- Event wiring ----
    dataset_dropdown.change(
        fn=on_dataset_change,
        inputs=[dataset_dropdown],
        outputs=[dataset_table],
    )
    train_btn.click(
        fn=on_train,
        inputs=[dataset_dropdown, epochs_slider, lr_dropdown, state],
        outputs=[loss_plot, state, status_display, train_btn, reset_btn],
    )
    reset_btn.click(
        fn=on_reset,
        inputs=[state],
        outputs=[state, status_display, train_btn, reset_btn, loss_plot],
    )
    # The Send button and Enter in the textbox share one pipeline: first append
    # the user turn and clear the box, then generate the bot's reply.
    gr.on(
        triggers=[send_btn.click, chat_input.submit],
        fn=on_user_message,
        inputs=[chat_input, chatbot],
        outputs=[chatbot, chat_input],
    ).then(
        fn=on_bot_response,
        inputs=[chatbot, num_beams_slider, state],
        outputs=[chatbot],
    )
# Enable the request queue (needed for generator handlers like on_train) and
# serve on all interfaces; queue() returns the Blocks, so the calls chain.
demo.queue().launch(server_name="0.0.0.0")