"""Gradio demo that samples text from a Mixture-of-Experts checkpoint.

Loads ``CHECKPOINT_PATH`` once at import time, then serves a small web UI
with a prompt box, sampling controls and a handful of example prompts.
"""

import time

import gradio as gr
import tiktoken
import torch

# NOTE(review): ``Config`` is not referenced by name below; presumably it must
# be importable so the pickled ``ckpt["config"]`` can be restored — confirm
# before removing.
from train import MoENullModel, Config  # noqa: F401

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = "checkpoint_3000.pt"

enc = tiktoken.get_encoding("gpt2")


def encode(text):
    """Return the GPT-2 token ids for *text*, allowing ``<|endoftext|>``."""
    return enc.encode(text, allowed_special={"<|endoftext|>"})


def decode(ids):
    """Return the text corresponding to the token ids *ids*."""
    return enc.decode(ids)


def load_model():
    """Load the checkpoint at ``CHECKPOINT_PATH`` and build the model.

    Returns:
        A ``(model, step)`` tuple: the model in eval mode on ``DEVICE`` and
        the checkpoint's ``"step"`` entry (``"?"`` when absent).
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects. It is
    # kept because the checkpoint stores a config object, but only load
    # checkpoint files from a trusted source.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    net = MoENullModel(ckpt["config"])
    net.load_state_dict(ckpt["model_state_dict"])
    net.eval().to(DEVICE)
    return net, ckpt.get("step", "?")


model, ckpt_step = load_model()


@torch.no_grad()
def generate(prompt_ids, max_new_tokens, temperature, top_k):
    """Sample ``max_new_tokens`` tokens autoregressively after *prompt_ids*.

    Args:
        prompt_ids: Token ids of the prompt.
        max_new_tokens: Number of tokens to sample.
        temperature: Divisor applied to the logits (floored at ``1e-6``).
        top_k: Keep only the ``top_k`` highest logits; ``0`` disables this.

    Returns:
        The prompt ids followed by the sampled ids, as a flat list.
    """
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    # Context window: use the model's config when it exposes one, else 128.
    seq_len = getattr(getattr(model, "config", None), "seq_len", 128)
    # Loop-invariant: guard against a zero/negative temperature once.
    temp = max(temperature, 1e-6)

    for _ in range(max_new_tokens):
        logits = model(idx[:, -seq_len:])
        # The model may return a tuple; the logits are its first element.
        if isinstance(logits, tuple):
            logits = logits[0]
        logits = logits[:, -1, :] / temp
        if top_k > 0:
            top_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Everything below the k-th largest logit gets zero probability.
            logits[logits < top_vals[:, [-1]]] = float("-inf")
        next_id = torch.multinomial(torch.softmax(logits, dim=-1), 1)
        idx = torch.cat([idx, next_id], dim=1)

    return idx[0].tolist()


def run_generation(prompt, max_new_tokens, temperature, top_k):
    """UI callback: generate a continuation of *prompt* and append stats.

    Returns the prompt plus the generated text and a footer with throughput,
    or a warning/error message when the prompt is blank or generation fails.
    """
    if not prompt.strip():
        return "⚠️ Empty prompt."

    start = time.perf_counter()
    try:
        # Tokenize once; the same ids are reused to strip the prompt below.
        prompt_ids = encode(prompt)
        out = generate(prompt_ids, int(max_new_tokens), temperature, int(top_k))
    except Exception as e:  # UI boundary: report the failure to the user.
        return f"❌ Generation error:\n{e}"
    # Clamp so the throughput division can never hit zero.
    elapsed = max(time.perf_counter() - start, 1e-9)

    new_ids = out[len(prompt_ids):]
    stats = f"\n\n─────────────────────────────\n⚡ {len(new_ids)} tokens in {elapsed:.2f}s ({len(new_ids)/elapsed:.1f} tok/s) | step {ckpt_step} | {DEVICE}"
    return prompt + decode(new_ids) + stats


EXAMPLE_PROMPTS = [
    "ROMEO:",
    "To be, or not to be,",
    "HAMLET:\nWhat a piece of work is",
    "KING LEAR:\nBlow, winds, and crack your cheeks!",
    "First Citizen:\nBefore we proceed any further,",
    "All the world's a stage,",
]

# NOTE(review): the nesting of the layout below was reconstructed from
# statement order (the source had lost its indentation) — confirm that the
# Row/Column grouping matches the intended UI.
with gr.Blocks(
    title="Shakespeare MoE",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css="#output-box textarea { font-family: 'Georgia', serif; font-size: 15px; line-height: 1.7; }",
) as demo:
    # NOTE(review): these HTML strings contain only text and blank lines; any
    # markup that may have wrapped them is not present — confirm.
    gr.HTML("\n\n🎭 Shakespeare MoE\n\n")
    gr.HTML("\n\nMixture-of-Experts language model trained on Tiny Shakespeare\n\n")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Settings")
            max_tokens = gr.Slider(10, 128, value=100, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
            top_k = gr.Slider(0, 100, value=40, step=5, label="Top-k (0 = disabled)")
        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Prompt")
            prompt_box = gr.Textbox(
                placeholder="Type a prompt or click an example below…",
                lines=4,
                show_label=False,
            )
            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")
            gr.Markdown("### 📜 Output")
            output_box = gr.Textbox(
                lines=14,
                show_label=False,
                interactive=False,
                elem_id="output-box",
            )

    gr.Markdown("### 💡 Click a prompt, then hit Generate")
    with gr.Row():
        # Button labels are single-line; the full prompt keeps its newlines.
        prompt_btns = [gr.Button(p.replace("\n", " "), size="sm") for p in EXAMPLE_PROMPTS]

    for btn, prompt_text in zip(prompt_btns, EXAMPLE_PROMPTS):
        # Default argument binds the current prompt (avoids late binding).
        btn.click(fn=lambda p=prompt_text: p, outputs=prompt_box)

    generate_btn.click(
        run_generation,
        inputs=[prompt_box, max_tokens, temperature, top_k],
        outputs=output_box,
    )
    prompt_box.submit(
        run_generation,
        inputs=[prompt_box, max_tokens, temperature, top_k],
        outputs=output_box,
    )

if __name__ == "__main__":
    demo.launch(share=False)