""" Interactive chat with your GPT-2 model — Notebook version Run Cell 1 and Cell 2 in order, then use the chat widget. You can copy the code into colab cells. ensure you rename to your model path """ # ── Cell 1: Setup & load model ──────────────────────────────────────────────── import torch import tiktoken import ipywidgets as widgets from IPython.display import display, HTML from Architecture import LanguageModel from GeneratorHF import HFTextGenerator CHECKPOINT_PATH = "/content/drive/MyDrive/GPT2-small/refinedGPT2small.pth" #on google colab DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def load_model(checkpoint_path, device): print(f"Loading model from {checkpoint_path}...") ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) config = ckpt["config"] model = LanguageModel(config) state_dict = ckpt["model_state_dict"] cleaned = {} for k, v in state_dict.items(): for prefix in ["module.", "_orig_mod."]: while k.startswith(prefix): k = k[len(prefix):] cleaned[k] = v model.load_state_dict(cleaned) model.to(device) model.eval() params = sum(p.numel() for p in model.parameters()) / 1e6 val_loss = ckpt.get("loss", "?") print(f" {params:.1f}M params | val_loss {val_loss} | device {device}") return model, config model, config = load_model(CHECKPOINT_PATH, DEVICE) enc = tiktoken.get_encoding("gpt2") gen = HFTextGenerator(model, enc, DEVICE, context_size=config["context_length"]) print("Ready.") # ── Cell 2: Chat widget ─────────────────────────────────────────────────────── # ── Sampling controls ───────────────────────────────────────────────────────── temp_slider = widgets.FloatSlider( value=0.8, min=0.1, max=2.0, step=0.05, description="Temperature", style={"description_width": "90px"}, layout=widgets.Layout(width="400px") ) topk_slider = widgets.IntSlider( value=40, min=1, max=200, step=1, description="Top-k", style={"description_width": "90px"}, layout=widgets.Layout(width="400px") ) tokens_slider = widgets.IntSlider( value=150, min=10, max=500, step=10, description="Max tokens", style={"description_width": "90px"}, layout=widgets.Layout(width="400px") ) # ── Input & buttons ─────────────────────────────────────────────────────────── prompt_box = widgets.Textarea( placeholder="Type your prompt here...", layout=widgets.Layout(width="700px", height="70px") ) generate_btn = widgets.Button( description="Generate", button_style="primary", layout=widgets.Layout(width="120px") ) clear_btn = widgets.Button( description="Clear", button_style="warning", layout=widgets.Layout(width="80px") ) # ── Output area ─────────────────────────────────────────────────────────────── output_area = widgets.Output() def on_generate(b): prompt = prompt_box.value.strip() if not prompt: return with output_area: print(f"\nYOU > {prompt}") print("MODEL > ", end="", flush=True) continuation = gen.generate( prompt, max_new_tokens=tokens_slider.value, temperature=temp_slider.value, top_k=topk_slider.value, ) print(continuation) print("-" * 70) def on_clear(b): output_area.clear_output() prompt_box.value = "" generate_btn.on_click(on_generate) clear_btn.on_click(on_clear) # ── Layout ──────────────────────────────────────────────────────────────────── controls = widgets.VBox([ widgets.HTML("Sampling settings"), temp_slider, topk_slider, tokens_slider, ]) buttons = widgets.HBox([generate_btn, clear_btn]) ui = widgets.VBox([ widgets.HTML("