import math import time import torch import gradio as gr import spaces import pandas as pd from datasets import DATASETS from model import load_fresh_model, train_model, infer, TOKENIZER from logo_b64 import LOGO_B64 # --------------------------------------------------------------------------- # Per-session state factory # --------------------------------------------------------------------------- def make_state(): """Called by gr.State for each new browser session.""" return {"model": None, "trained_on": None} def _detect_device(): if torch.cuda.is_available(): return "cuda" elif torch.mps.is_available(): return "mps" return "cpu" # --------------------------------------------------------------------------- # Event handlers # --------------------------------------------------------------------------- def on_dataset_change(dataset_name): pairs = [[inp, out] for inp, out in DATASETS[dataset_name]] return pairs def _overfitting_warning(loss_records): """Return a warning string if the final loss is extremely low, or None.""" if not loss_records: return None final_loss = 10 ** loss_records[-1]["Log Loss"] if final_loss < 0.01: return ( "> **Possible overfitting:** the loss is extremely low, which on a small " "dataset usually means the model has memorized the examples rather than " "learned the pattern. Try fewer epochs or a lower learning rate." ) return None @spaces.GPU(duration=300) def on_train(dataset_name, epochs, lr, state): """Generator — yields (progress, state, status, train_btn, reset_btn) after each step.""" device = _detect_device() state["device"] = device yield ( None, state, "**Status:** Loading model...", gr.update(interactive=False), gr.update(interactive=False), ) model = load_fresh_model() model.to(device) # type:ignore tuples = DATASETS[dataset_name] loss_records = [] for epoch_num, loss in train_model(model, TOKENIZER, tuples, device, epochs=epochs, lr=float(lr)): loss_records.append({"Epoch": epoch_num, "Log Loss": math.log10(loss)}) df = pd.DataFrame(loss_records) yield ( df, state, f"**Status:** Training... Epoch {epoch_num}/{epochs} | Loss: {loss:.4f}", gr.update(interactive=False), gr.update(interactive=False), ) state["model"] = model.cpu() state["trained_on"] = dataset_name status = f"**Status:** Trained on '{dataset_name}'" warning = _overfitting_warning(loss_records) if warning: status += f"\n\n{warning}" yield ( pd.DataFrame(loss_records), state, status, gr.update(interactive=True), gr.update(interactive=True), ) def on_reset(state): state["model"] = None state["trained_on"] = None return ( state, "**Status:** Untrained (echoing)", gr.update(interactive=True), gr.update(interactive=False), None, ) def on_user_message(message, history): """Immediately append the user message and clear the input box.""" if not message.strip(): return history, message return history + [{"role": "user", "content": message}], "" @spaces.GPU def on_bot_response(history, num_beams, state): """Run inference and append the assistant reply.""" if not history or history[-1]["role"] != "user": return history message = history[-1]["content"] if state["model"] is None: time.sleep(1) response = message else: device = _detect_device() model = state["model"].to(device) results = infer(model, TOKENIZER, message, device, num_beams=num_beams) model.cpu() # move back to CPU before ZeroGPU releases the allocation response = results[0] return history + [{"role": "assistant", "content": response}] # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- first_dataset = list(DATASETS.keys())[0] with gr.Blocks(title="EchoBot", css=".align-bottom { margin-top: auto; margin-bottom: auto }") as demo: state = gr.State(make_state) gr.HTML( '
Select a dataset, train the model, then chat to see how EchoBot responds!
' '