File size: 4,713 Bytes
ea2eee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Interactive chat with your GPT-2 model — Notebook version
Run Cell 1 and Cell 2 in order, then use the chat widget.
You can copy the code into Colab cells. Make sure you change CHECKPOINT_PATH to your own model path.
"""

# ── Cell 1: Setup & load model ────────────────────────────────────────────────
import torch
import tiktoken
import ipywidgets as widgets
from IPython.display import display, HTML
from Architecture import LanguageModel
from GeneratorHF import HFTextGenerator

# Checkpoint file handed to load_model() below. The default is a Google Drive
# location as seen from Colab; point it at your own .pth file.
CHECKPOINT_PATH = "/content/drive/MyDrive/GPT2-small/refinedGPT2small.pth" #on google colab
# Use the GPU when torch reports one is available, otherwise fall back to CPU.
DEVICE          = "cuda" if torch.cuda.is_available() else "cpu"

# Key prefixes left behind by wrapper modules when a state dict is saved
# (presumably DataParallel/DDP for "module." and torch.compile for "_orig_mod.").
_WRAPPER_PREFIXES = ("module.", "_orig_mod.")


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of *state_dict* with every leading wrapper prefix removed from its keys."""
    cleaned = {}
    for key, value in state_dict.items():
        # Keep stripping until no wrapper prefix is left, so the result does
        # not depend on the order in which the wrappers were nested
        # (e.g. both "module._orig_mod.x" and "_orig_mod.module.x" -> "x").
        while key.startswith(_WRAPPER_PREFIXES):
            for prefix in _WRAPPER_PREFIXES:
                if key.startswith(prefix):
                    key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def load_model(checkpoint_path, device):
    """Load a LanguageModel from a checkpoint file and put it in eval mode.

    The checkpoint must be a mapping with a "config" entry (passed to
    ``LanguageModel``) and a "model_state_dict" entry; an optional "loss"
    entry is only used for the summary line that gets printed.

    Args:
        checkpoint_path: Path of the checkpoint file given to ``torch.load``.
        device: Device string used both as ``map_location`` and as the
            target of ``model.to``.

    Returns:
        A ``(model, config)`` tuple: the loaded model and the config object
        stored in the checkpoint.

    Raises:
        KeyError: If "config" or "model_state_dict" is missing from the checkpoint.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints you trust.
    ckpt       = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config     = ckpt["config"]
    model      = LanguageModel(config)

    # Saved keys may carry wrapper prefixes that a bare LanguageModel lacks.
    model.load_state_dict(_strip_wrapper_prefixes(ckpt["model_state_dict"]))
    model.to(device)
    model.eval()

    params   = sum(p.numel() for p in model.parameters()) / 1e6
    val_loss = ckpt.get("loss", "?")
    print(f"  {params:.1f}M params | val_loss {val_loss} | device {device}")
    return model, config

# Runs once when the cell executes: load the checkpoint, fetch the "gpt2"
# tiktoken encoding, and build the generator that on_generate() calls.
model, config = load_model(CHECKPOINT_PATH, DEVICE)
enc           = tiktoken.get_encoding("gpt2")
# The generator's context_size comes from the checkpoint's own config.
gen           = HFTextGenerator(model, enc, DEVICE, context_size=config["context_length"])
print("Ready.")


# ── Cell 2: Chat widget ───────────────────────────────────────────────────────

# ── Sampling controls ─────────────────────────────────────────────────────────
def _make_slider(slider_cls, label, **range_opts):
    """Create one sampling slider with the shared label width and slider width.

    ``range_opts`` carries the slider-specific ``value``/``min``/``max``/``step``.
    A new ``Layout`` is built on every call so the sliders do not share one.
    """
    return slider_cls(
        **range_opts,
        description=label,
        style={"description_width": "90px"},
        layout=widgets.Layout(width="400px"),
    )


# The three sliders whose values on_generate() reads as sampling settings.
temp_slider = _make_slider(
    widgets.FloatSlider, "Temperature", value=0.8, min=0.1, max=2.0, step=0.05
)
topk_slider = _make_slider(
    widgets.IntSlider, "Top-k", value=40, min=1, max=200, step=1
)
tokens_slider = _make_slider(
    widgets.IntSlider, "Max tokens", value=150, min=10, max=500, step=10
)

# ── Input & buttons ───────────────────────────────────────────────────────────
def _make_button(label, kind, width):
    """Create a button with the given caption, button style name, and CSS width."""
    return widgets.Button(
        description=label,
        button_style=kind,
        layout=widgets.Layout(width=width),
    )


# Multi-line text box the user types the prompt into.
prompt_box = widgets.Textarea(
    placeholder="Type your prompt here...",
    layout=widgets.Layout(width="700px", height="70px"),
)
# Generate runs the model on the prompt; Clear resets prompt and output.
generate_btn = _make_button("Generate", "primary", "120px")
clear_btn = _make_button("Clear", "warning", "80px")

# ── Output area ───────────────────────────────────────────────────────────────
# Output widget that on_generate() prints into and on_clear() wipes.
output_area = widgets.Output()

def on_generate(b):
    """Click handler for the Generate button.

    Reads the prompt box and, unless it is empty or whitespace-only, prints
    the prompt followed by the model's continuation into ``output_area``,
    using the current slider values as sampling settings.

    ``b`` is the argument supplied by ``on_click``; it is not used here.
    """
    user_text = prompt_box.value.strip()
    if user_text == "":
        return  # nothing to send to the model

    with output_area:
        print(f"\nYOU   > {user_text}")
        # Label is printed and flushed before generation starts, so the
        # continuation lands on the same line once it is ready.
        print("MODEL > ", end="", flush=True)

        sampling = {
            "max_new_tokens": tokens_slider.value,
            "temperature": temp_slider.value,
            "top_k": topk_slider.value,
        }
        reply = gen.generate(user_text, **sampling)
        print(reply)
        print("-" * 70)

def on_clear(b):
    """Click handler for the Clear button: wipe the output area and empty the prompt box.

    ``b`` is the argument supplied by ``on_click``; it is not used here.
    """
    output_area.clear_output()
    prompt_box.value = ""

# Register the click handlers; each handler takes a single argument (`b`).
generate_btn.on_click(on_generate)
clear_btn.on_click(on_clear)

# ── Layout ────────────────────────────────────────────────────────────────────
# Sampling sliders stacked vertically under a bold heading.
controls = widgets.VBox([
    widgets.HTML("<b>Sampling settings</b>"),
    temp_slider,
    topk_slider,
    tokens_slider,
])
# Generate and Clear buttons side by side.
buttons = widgets.HBox([generate_btn, clear_btn])
# Whole UI as one column: title, settings, prompt, buttons, output.
ui      = widgets.VBox([
    # The title previously contained a mis-encoded em dash that rendered as
    # garbage characters; the HTML entity keeps this source line ASCII-only.
    widgets.HTML("<h3>GPT-2 Small &mdash; Chat</h3>"),
    controls,
    widgets.HTML("<br><b>Prompt</b>"),
    prompt_box,
    buttons,
    widgets.HTML("<br><b>Output</b>"),
    output_area,
])

# Render the assembled widget tree in the notebook cell output.
display(ui)