File size: 4,109 Bytes
aad4104
 
 
 
 
 
 
 
 
c29780f
aad4104
 
 
 
 
 
 
 
 
c29780f
 
aad4104
 
 
c29780f
 
 
aad4104
 
c29780f
aad4104
 
 
 
 
 
 
 
 
 
 
 
 
c29780f
 
 
aad4104
 
c29780f
aad4104
 
 
c29780f
 
aad4104
 
 
 
 
 
 
 
 
 
 
 
 
 
c29780f
aad4104
 
c29780f
 
aad4104
 
 
 
 
 
 
 
 
 
c29780f
aad4104
 
c29780f
aad4104
c29780f
aad4104
 
 
 
 
 
c29780f
 
aad4104
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import time

import gradio as gr
import tiktoken
import torch

from train import MoENullModel, Config

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint file read by load_model() (called once at import time).
CHECKPOINT_PATH = "checkpoint_3000.pt"

# GPT-2 byte-pair tokenizer shared by encode() and decode().
enc = tiktoken.get_encoding("gpt2")

def encode(text):
    """Tokenize *text* into GPT-2 token ids.

    The literal string ``<|endoftext|>`` is permitted in *text* and is encoded
    as the special end-of-text token.
    """
    permitted_specials = {"<|endoftext|>"}
    token_ids = enc.encode(text, allowed_special=permitted_specials)
    return token_ids

def decode(ids):
    """Turn a sequence of GPT-2 token ids back into text."""
    text = enc.decode(ids)
    return text

def load_model(path=None, device=None):
    """Load the MoE model from a training checkpoint and prepare it for inference.

    Args:
        path: Checkpoint file to read. ``None`` (the default) uses the
            module-level ``CHECKPOINT_PATH``.
        device: Device tensors are mapped to and the model is moved to.
            ``None`` (the default) uses the module-level ``DEVICE``.

    Returns:
        ``(model, step)``: the ``MoENullModel`` in eval mode on *device*, and
        the checkpoint's ``"step"`` entry, or ``"?"`` if the checkpoint has none.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are still honoured, exactly as when they were hard-coded.
    path = CHECKPOINT_PATH if path is None else path
    device = DEVICE if device is None else device
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects (needed here for the stored "config" entry), which can execute
    # code. Only load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model = MoENullModel(ckpt["config"])
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval().to(device)
    return model, ckpt.get("step", "?")

# Loaded once at import time; generate() reads `model` and run_generation()
# reads `ckpt_step` from these module-level globals on every call.
model, ckpt_step = load_model()

@torch.no_grad()
def generate(prompt_ids, max_new_tokens, temperature, top_k):
    """Autoregressively sample ``max_new_tokens`` tokens after ``prompt_ids``.

    Runs the module-level ``model`` on ``DEVICE``. Each step:

    - feeds the model at most the last ``model.config.seq_len`` tokens (128 when
      the model exposes no such attribute);
    - divides the final-position logits by the temperature (floored at 1e-6);
    - optionally keeps only the ``top_k`` largest logits;
    - draws the next token from the softmax of what remains.

    Args:
        prompt_ids: Token ids of the prompt.
        max_new_tokens: Number of tokens to append.
        temperature: Sampling temperature; lower is more deterministic.
        top_k: Keep only this many highest logits; 0 or negative disables it.

    Returns:
        The prompt ids followed by the sampled ids, as a plain list of ints.
    """
    model_cfg = getattr(model, "config", None)
    window = getattr(model_cfg, "seq_len", 128)  # fallback context length
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    scale = max(temperature, 1e-6)  # avoids dividing by zero at temperature 0

    for _step in range(max_new_tokens):
        raw = model(tokens[:, -window:])
        # If the model returns a tuple, its first element holds the logits.
        all_logits = raw[0] if isinstance(raw, tuple) else raw
        last = all_logits[:, -1, :] / scale
        if top_k > 0:
            k = min(top_k, last.size(-1))
            kth_largest = torch.topk(last, k).values[:, -1:]
            # Anything strictly below the k-th largest logit can never be sampled.
            last = last.masked_fill(last < kth_largest, float("-inf"))
        probs = torch.softmax(last, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)

    return tokens[0].tolist()

def run_generation(prompt, max_new_tokens, temperature, top_k):
    """Gradio callback: continue *prompt* and append a timing/stats footer.

    Args:
        prompt: Text from the prompt box. A ``None`` or blank value is rejected.
        max_new_tokens: Number of tokens to sample; converted to ``int`` because
            slider values may arrive as floats.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate`` (0 disables it).

    Returns:
        The prompt, the generated continuation and a one-line stats footer.
        On bad input or a generation error, a short message is returned instead.
        Errors are returned rather than raised so they show up in the output box.
    """
    if not prompt or not prompt.strip():
        return "⚠️  Empty prompt."

    # Encode once: the ids feed generation and tell us where the prompt ends.
    prompt_ids = encode(prompt)

    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # clock adjustments and is coarse on some platforms.
    start = time.perf_counter()
    try:
        out = generate(prompt_ids, int(max_new_tokens), temperature, int(top_k))
    except Exception as e:  # UI boundary: surface the failure to the user
        return f"❌  Generation error:\n{e}"
    elapsed = time.perf_counter() - start

    new_ids = out[len(prompt_ids):]
    # Guard the throughput figure against a zero-length interval.
    rate = len(new_ids) / elapsed if elapsed > 0 else 0.0
    stats = (
        "\n\n─────────────────────────────\n"
        f"⚑ {len(new_ids)} tokens in {elapsed:.2f}s  ({rate:.1f} tok/s)"
        f"  |  step {ckpt_step}  |  {DEVICE}"
    )
    return prompt + decode(new_ids) + stats

# Starter prompts rendered as buttons below the main panel. Clicking one
# copies the text (including any embedded newline) into the prompt box; the
# button label shows it with newlines replaced by spaces.
EXAMPLE_PROMPTS = [
    "ROMEO:",
    "To be, or not to be,",
    (
        "HAMLET:\n"
        "What a piece of work is"
    ),
    (
        "KING LEAR:\n"
        "Blow, winds, and crack your cheeks!"
    ),
    (
        "First Citizen:\n"
        "Before we proceed any further,"
    ),
    "All the world's a stage,",
]

with gr.Blocks(
    title="Shakespeare MoE",
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css="#output-box textarea { font-family: 'Georgia', serif; font-size: 15px; line-height: 1.7; }",
) as demo:
    # Layout:
    #   - title header;
    #   - a settings column beside a prompt/output column;
    #   - a row of example-prompt buttons.

    gr.HTML("<h1 style='text-align:center;margin-bottom:4px'>🎭 Shakespeare MoE</h1>")
    gr.HTML("<p style='text-align:center;color:#888;margin-bottom:20px'>Mixture-of-Experts language model trained on Tiny Shakespeare</p>")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Settings")
            max_tokens  = gr.Slider(10, 128, value=100, step=10,   label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
            top_k       = gr.Slider(0, 100,  value=40,  step=5,    label="Top-k  (0 = disabled)")

        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Prompt")
            prompt_box   = gr.Textbox(placeholder="Type a prompt or click an example below…", lines=4, show_label=False)
            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")
            gr.Markdown("### 📜 Output")
            output_box   = gr.Textbox(lines=14, show_label=False, interactive=False, elem_id="output-box")

    gr.Markdown("### 💡 Click a prompt, then hit Generate")
    with gr.Row():
        prompt_btns = [gr.Button(p.replace("\n", " "), size="sm") for p in EXAMPLE_PROMPTS]

    # Bind each prompt as a default argument so every lambda returns its own
    # text instead of the loop's final value (late-binding closure pitfall).
    for btn, prompt_text in zip(prompt_btns, EXAMPLE_PROMPTS):
        btn.click(fn=lambda p=prompt_text: p, outputs=prompt_box)

    # The Generate button and submitting the prompt box run the same callback
    # with the same inputs, kept in one list so the two cannot drift apart.
    generation_inputs = [prompt_box, max_tokens, temperature, top_k]
    generate_btn.click(run_generation, inputs=generation_inputs, outputs=output_box)
    prompt_box.submit(run_generation, inputs=generation_inputs, outputs=output_box)

if __name__ == "__main__":
    # Serve the demo locally; share=False means no public Gradio share link is created.
    demo.launch(share=False)