"""
Interactive chat with your GPT-2 model — Notebook version
Run Cell 1 and Cell 2 in order, then use the chat widget.
You can copy the code into Colab cells. Make sure CHECKPOINT_PATH points to your own model checkpoint.
"""
# ── Cell 1: Setup & load model ────────────────────────────────────────────────
import torch
import tiktoken
import ipywidgets as widgets
from IPython.display import display, HTML
from Architecture import LanguageModel
from GeneratorHF import HFTextGenerator
# Checkpoint location on Google Colab; change this to your own model path.
CHECKPOINT_PATH = "/content/drive/MyDrive/GPT2-small/refinedGPT2small.pth"

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def _strip_wrapper_prefixes(key, prefixes=("module.", "_orig_mod.")):
    """Remove every leading wrapper prefix from a state-dict key.

    Prefixes are stripped repeatedly, in any order and any nesting, until the
    key no longer starts with one of them. A single ordered pass would leave
    ``"_orig_mod.module.w"`` as ``"module.w"`` because ``"module."`` is tested
    before ``"_orig_mod."`` has been removed.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


def load_model(checkpoint_path, device):
    """Load a LanguageModel from a checkpoint file and put it in eval mode.

    The checkpoint is expected to be a dict with the keys ``"config"`` and
    ``"model_state_dict"``, and optionally ``"loss"`` (printed as val_loss).

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device string used both as ``map_location`` and as the target
            of ``model.to``.

    Returns:
        A ``(model, config)`` tuple: the loaded model and the config object
        stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks ``"config"`` or
            ``"model_state_dict"``.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints you created yourself or otherwise trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt["config"]
    model = LanguageModel(config)

    # Keys saved from a wrapped model carry wrapper prefixes ("module.",
    # "_orig_mod."); drop them so the keys match the bare LanguageModel.
    cleaned = {
        _strip_wrapper_prefixes(name): tensor
        for name, tensor in ckpt["model_state_dict"].items()
    }
    model.load_state_dict(cleaned)
    model.to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters()) / 1e6
    val_loss = ckpt.get("loss", "?")
    print(f" {params:.1f}M params | val_loss {val_loss} | device {device}")
    return model, config
# Load the checkpoint, fetch the GPT-2 tiktoken encoding, and build the text
# generator with the context length stored in the checkpoint config.
model, config = load_model(checkpoint_path=CHECKPOINT_PATH, device=DEVICE)
enc = tiktoken.get_encoding("gpt2")
gen = HFTextGenerator(
    model,
    enc,
    DEVICE,
    context_size=config["context_length"],
)
print("Ready.")
# ── Cell 2: Chat widget ───────────────────────────────────────────────────────
# ── Sampling controls ─────────────────────────────────────────────────────────
def _slider_appearance():
    """Return fresh style/layout keyword arguments used by every slider.

    A new dict and a new Layout are built on each call so the sliders do not
    share widget state with one another.
    """
    return {
        "style": {"description_width": "90px"},
        "layout": widgets.Layout(width="400px"),
    }


# Sampling temperature passed to the generator.
temp_slider = widgets.FloatSlider(
    description="Temperature",
    value=0.8,
    min=0.1,
    max=2.0,
    step=0.05,
    **_slider_appearance(),
)

# Top-k value passed to the generator.
topk_slider = widgets.IntSlider(
    description="Top-k",
    value=40,
    min=1,
    max=200,
    step=1,
    **_slider_appearance(),
)

# Maximum number of new tokens requested per generation.
tokens_slider = widgets.IntSlider(
    description="Max tokens",
    value=150,
    min=10,
    max=500,
    step=10,
    **_slider_appearance(),
)
# ── Input & buttons ───────────────────────────────────────────────────────────
def _make_button(label, kind, width):
    """Build a Button with the given caption, button_style and CSS width."""
    return widgets.Button(
        description=label,
        button_style=kind,
        layout=widgets.Layout(width=width),
    )


# Multi-line text box where the user types the prompt.
prompt_box = widgets.Textarea(
    layout=widgets.Layout(width="700px", height="70px"),
    placeholder="Type your prompt here...",
)

# Action buttons; their click handlers are attached further down.
generate_btn = _make_button("Generate", "primary", "120px")
clear_btn = _make_button("Clear", "warning", "80px")

# ── Output area ───────────────────────────────────────────────────────────────
# Captures everything printed while it is used as a context manager.
output_area = widgets.Output()
def on_generate(b):
    """Click handler for the Generate button.

    Reads the prompt box, and if it holds anything other than whitespace,
    prints the prompt and the generator's output into ``output_area`` using
    the current slider values. ``b`` is the clicked button and is not used.
    """
    user_text = prompt_box.value.strip()
    if user_text == "":
        # Nothing to generate from an empty / whitespace-only prompt.
        return

    with output_area:
        print(f"\nYOU > {user_text}")
        # No newline here, so the generated text follows the label directly.
        print("MODEL > ", end="", flush=True)
        sampling = {
            "max_new_tokens": tokens_slider.value,
            "temperature": temp_slider.value,
            "top_k": topk_slider.value,
        }
        print(gen.generate(user_text, **sampling))
        print("-" * 70)
def on_clear(b):
    """Click handler for the Clear button: empty the prompt box and the output.

    ``b`` is the clicked button and is not used.
    """
    prompt_box.value = ""
    output_area.clear_output()
# Attach the click handlers to their buttons.
generate_btn.on_click(on_generate)
clear_btn.on_click(on_clear)

# ── Layout ────────────────────────────────────────────────────────────────────
# Sampling sliders stacked under a heading.
controls = widgets.VBox([
    widgets.HTML("<b>Sampling settings</b>"),
    temp_slider,
    topk_slider,
    tokens_slider,
])

# Generate / Clear side by side.
buttons = widgets.HBox([generate_btn, clear_btn])

# Full page: title, sampling controls, prompt entry, buttons, output.
ui = widgets.VBox([
    # The dash is written as an HTML entity so the title renders correctly
    # regardless of the file's encoding (it previously showed a stray
    # mis-encoded character).
    widgets.HTML("<h3>GPT-2 Small &mdash; Chat</h3>"),
    controls,
    widgets.HTML("<br><b>Prompt</b>"),
    prompt_box,
    buttons,
    widgets.HTML("<br><b>Output</b>"),
    output_area,
])

display(ui)