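"""Gradio demo for a Mixture-of-Experts language model trained on Tiny Shakespeare.

Loads a training checkpoint, samples with temperature and top-k filtering, and
serves an interactive UI with example prompts. The checkpoint file and the
model class are assumed to come from the accompanying train.py.
"""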
import time

import gradio as gr
import tiktoken
import torch

# Config must be importable so torch.load can unpickle the config object
# stored inside the checkpoint.
from train import MoENullModel, Config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = "checkpoint_3000.pt"

enc = tiktoken.get_encoding("gpt2")
def encode(text):
    return enc.encode(text, allowed_special={"<|endoftext|>"})


def decode(ids):
    return enc.decode(ids)
def load_model():
    # weights_only=False is required because the checkpoint pickles a full
    # Config object, not just tensors.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    model = MoENullModel(ckpt["config"])
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval().to(DEVICE)
    return model, ckpt.get("step", "?")


model, ckpt_step = load_model()
@torch.no_grad()  # inference only; skip autograd bookkeeping
def generate(prompt_ids, max_new_tokens, temperature, top_k):
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    seq_len = getattr(getattr(model, "config", None), "seq_len", 128)
    for _ in range(max_new_tokens):
        # Crop the context to the model's window and take logits at the last position.
        logits = model(idx[:, -seq_len:])
        if isinstance(logits, tuple):  # some models return (logits, aux_loss)
            logits = logits[0]
        logits = logits[:, -1, :] / max(temperature, 1e-6)
        if top_k > 0:
            # Mask out everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
        idx = torch.cat([idx, torch.multinomial(torch.softmax(logits, dim=-1), 1)], dim=1)
    return idx[0].tolist()
def run_generation(prompt, max_new_tokens, temperature, top_k):
    if not prompt.strip():
        return "⚠️ Empty prompt."
    prompt_ids = encode(prompt)  # encode once; reused below to slice off the prompt
    start = time.time()
    try:
        out = generate(prompt_ids, int(max_new_tokens), temperature, int(top_k))
    except Exception as e:
        return f"❌ Generation error:\n{e}"
    elapsed = max(time.time() - start, 1e-9)  # guard the tok/s division on very fast runs
    new_ids = out[len(prompt_ids):]
    stats = (
        f"\n\n─────────────────────────────\n"
        f"⚡ {len(new_ids)} tokens in {elapsed:.2f}s ({len(new_ids)/elapsed:.1f} tok/s)"
        f" | step {ckpt_step} | {DEVICE}"
    )
    return prompt + decode(new_ids) + stats
EXAMPLE_PROMPTS = [
    "ROMEO:",
    "To be, or not to be,",
    "HAMLET:\nWhat a piece of work is",
    "KING LEAR:\nBlow, winds, and crack your cheeks!",
    "First Citizen:\nBefore we proceed any further,",
    "All the world's a stage,",
]
with gr.Blocks(
    title="Shakespeare MoE",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css="#output-box textarea { font-family: 'Georgia', serif; font-size: 15px; line-height: 1.7; }",
) as demo:
    gr.HTML("<h1 style='text-align:center;margin-bottom:4px'>🎭 Shakespeare MoE</h1>")
    gr.HTML("<p style='text-align:center;color:#888;margin-bottom:20px'>Mixture-of-Experts language model trained on Tiny Shakespeare</p>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Settings")
            max_tokens = gr.Slider(10, 128, value=100, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
            top_k = gr.Slider(0, 100, value=40, step=5, label="Top-k (0 = disabled)")
        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Prompt")
            prompt_box = gr.Textbox(placeholder="Type a prompt or click an example below…", lines=4, show_label=False)
            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")
            gr.Markdown("### 📜 Output")
            output_box = gr.Textbox(lines=14, show_label=False, interactive=False, elem_id="output-box")
| gr.Markdown("### π‘ Click a prompt, then hit Generate") | |
| with gr.Row(): | |
| prompt_btns = [gr.Button(p.replace("\n", " "), size="sm") for p in EXAMPLE_PROMPTS] | |
| for btn, prompt_text in zip(prompt_btns, EXAMPLE_PROMPTS): | |
| btn.click(fn=lambda p=prompt_text: p, outputs=prompt_box) | |
| generate_btn.click(run_generation, inputs=[prompt_box, max_tokens, temperature, top_k], outputs=output_box) | |
| prompt_box.submit(run_generation, inputs=[prompt_box, max_tokens, temperature, top_k], outputs=output_box) | |
| if __name__ == "__main__": | |
| demo.launch(share=False) |