# LikeGPT2small / chat.py
# Author: Zemulax
# Files for inference (revision ea2eee0)
"""
Interactive chat with your GPT-2 model β€” Notebook version
Run Cell 1 and Cell 2 in order, then use the chat widget.
You can copy the code into colab cells. ensure you rename to your model path
"""
# ── Cell 1: Setup & load model ────────────────────────────────────────────────
import torch
import tiktoken
import ipywidgets as widgets
from IPython.display import display, HTML
from Architecture import LanguageModel
from GeneratorHF import HFTextGenerator
CHECKPOINT_PATH = "/content/drive/MyDrive/GPT2-small/refinedGPT2small.pth" #on google colab
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _strip_wrapper_prefixes(key, prefixes=("module.", "_orig_mod.")):
    """Remove wrapper prefixes from a state-dict key, in any order or nesting.

    Wrapped models typically add ``module.`` (DataParallel / DDP) and
    ``_orig_mod.`` (torch.compile) to parameter names. Depending on how the
    wrappers were stacked, either prefix can come first (``module._orig_mod.x``
    or ``_orig_mod.module.x``), so keep stripping until none of them match.
    """
    while True:
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        else:
            # No prefix matched on this pass: the key is clean.
            return key


def load_model(checkpoint_path, device):
    """Load a trained LanguageModel checkpoint and prepare it for inference.

    The checkpoint must be a mapping with ``"config"`` (passed to
    ``LanguageModel``) and ``"model_state_dict"``. It may also contain
    ``"loss"``, which is only printed.

    Args:
        checkpoint_path: Path to the saved checkpoint file.
        device: Device to load onto (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        ``(model, config)``: the model moved to ``device`` and put in eval
        mode, plus the config object stored in the checkpoint.
    """
    print(f"Loading model from {checkpoint_path}...")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints you produced or trust.
    # NOTE(review): weights_only=True may reject this file if the stored
    # config holds custom objects; TODO try it if the config is plain data.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt["config"]
    model = LanguageModel(config)

    # Checkpoints saved from DDP- and/or torch.compile-wrapped models carry
    # prefixes that the bare model's parameter names do not have.
    cleaned = {
        _strip_wrapper_prefixes(k): v
        for k, v in ckpt["model_state_dict"].items()
    }
    model.load_state_dict(cleaned)
    model.to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters()) / 1e6
    val_loss = ckpt.get("loss", "?")
    if isinstance(val_loss, torch.Tensor) and val_loss.numel() == 1:
        # Print the number itself rather than "tensor(...)".
        val_loss = val_loss.item()
    print(f" {params:.1f}M params | val_loss {val_loss} | device {device}")
    return model, config
# Load the weights, create the GPT-2 tiktoken encoding, and wrap both in the
# text generator that the chat widget calls. The generator's context size
# comes from the checkpoint's config.
model, config = load_model(checkpoint_path=CHECKPOINT_PATH, device=DEVICE)
enc = tiktoken.get_encoding("gpt2")
gen = HFTextGenerator(
    model,
    enc,
    DEVICE,
    context_size=config["context_length"],
)
print("Ready.")
# ── Cell 2: Chat widget ───────────────────────────────────────────────────────
# ── Sampling controls ─────────────────────────────────────────────────────────
def _slider_kwargs():
    """Return the label style and layout shared by every sampling slider.

    A fresh ``Layout`` is built on each call, so no two sliders share one.
    """
    return {
        "style": {"description_width": "90px"},
        "layout": widgets.Layout(width="400px"),
    }


# Passed to gen.generate as `temperature`.
temp_slider = widgets.FloatSlider(
    value=0.8, min=0.1, max=2.0, step=0.05,
    description="Temperature",
    **_slider_kwargs(),
)
# Passed to gen.generate as `top_k`.
topk_slider = widgets.IntSlider(
    value=40, min=1, max=200, step=1,
    description="Top-k",
    **_slider_kwargs(),
)
# Passed to gen.generate as `max_new_tokens`.
tokens_slider = widgets.IntSlider(
    value=150, min=10, max=500, step=10,
    description="Max tokens",
    **_slider_kwargs(),
)
# ── Input & buttons ───────────────────────────────────────────────────────────
# Free-text prompt; on_generate reads and strips its value.
prompt_box = widgets.Textarea(
    placeholder="Type your prompt here...",
    layout=widgets.Layout(height="70px", width="700px"),
)
# Action buttons, built from (label, style, width) specs in display order.
generate_btn, clear_btn = (
    widgets.Button(
        description=label,
        button_style=style,
        layout=widgets.Layout(width=width),
    )
    for label, style, width in (
        ("Generate", "primary", "120px"),
        ("Clear", "warning", "80px"),
    )
)
# ── Output area ───────────────────────────────────────────────────────────────
# The click handlers print the chat transcript into this widget.
output_area = widgets.Output()
def on_generate(b):
    """Click handler for the Generate button: run the model on the prompt.

    Echoes the prompt and the generated continuation into ``output_area``,
    using the current slider values. Blank prompts are ignored.

    Args:
        b: The button that fired the event (unused; required by ``on_click``).
    """
    prompt = prompt_box.value.strip()
    if not prompt:
        return
    with output_area:
        print(f"\nYOU > {prompt}")
        print("MODEL > ", end="", flush=True)
        # Disable Generate while the model runs. Clicks made while the kernel
        # is busy are typically processed afterwards, so repeated clicks would
        # otherwise start extra generations. The finally block re-enables the
        # button even if generation raises.
        generate_btn.disabled = True
        try:
            continuation = gen.generate(
                prompt,
                max_new_tokens=tokens_slider.value,
                temperature=temp_slider.value,
                top_k=topk_slider.value,
            )
        finally:
            generate_btn.disabled = False
        print(continuation)
        print("-" * 70)
def on_clear(b):
    """Click handler for the Clear button: wipe the transcript and the prompt.

    Args:
        b: The button that fired the event (unused; required by ``on_click``).
    """
    output_area.clear_output()
    prompt_box.value = ""
# Attach the click handlers to their buttons.
generate_btn.on_click(on_generate)
clear_btn.on_click(on_clear)
# ── Layout ────────────────────────────────────────────────────────────────────
# Sampling sliders grouped under a heading.
controls = widgets.VBox([
    widgets.HTML("<b>Sampling settings</b>"),
    temp_slider,
    topk_slider,
    tokens_slider,
])
buttons = widgets.HBox([generate_btn, clear_btn])
# Full chat panel, top to bottom: title, settings, prompt, buttons, transcript.
# The title uses the ASCII-only "&mdash;" HTML entity rather than a literal
# em dash, so a wrong source encoding cannot garble it. The previous literal
# had been mis-decoded and rendered as "β€”".
ui = widgets.VBox([
    widgets.HTML("<h3>GPT-2 Small &mdash; Chat</h3>"),
    controls,
    widgets.HTML("<br><b>Prompt</b>"),
    prompt_box,
    buttons,
    widgets.HTML("<br><b>Output</b>"),
    output_area,
])
display(ui)