"""
Interactive chat with your GPT-2 model — Notebook version.
Run Cell 1 and Cell 2 in order, then use the chat widget.
You can copy the code into Colab cells; make sure you change CHECKPOINT_PATH to your own model's path.
"""
|
|
| |
| import torch |
| import tiktoken |
| import ipywidgets as widgets |
| from IPython.display import display, HTML |
| from Architecture import LanguageModel |
| from GeneratorHF import HFTextGenerator |
|
|
# Location of the trained checkpoint; point this at your own file.
CHECKPOINT_PATH = "/content/drive/MyDrive/GPT2-small/refinedGPT2small.pth"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
def _strip_wrapper_prefixes(key, prefixes=("module.", "_orig_mod.")):
    """Remove every leading wrapper prefix from a state-dict key, in any order."""
    stripped = True
    # Keep going until a full pass removes nothing, so interleaved prefixes
    # such as "_orig_mod.module.x" and "module._orig_mod.x" both reduce to "x".
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


def load_model(checkpoint_path, device):
    """Load a ``LanguageModel`` from a checkpoint and put it in eval mode.

    The checkpoint is indexed as a mapping and must provide ``"config"``
    (passed to ``LanguageModel``) and ``"model_state_dict"``. An optional
    ``"loss"`` entry is echoed in the summary line.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: Used as ``map_location`` for loading and as the target of
            ``model.to``.

    Returns:
        A ``(model, config)`` tuple: the model moved to ``device`` in eval
        mode, and the config object read from the checkpoint.

    Raises:
        KeyError: If ``"config"`` or ``"model_state_dict"`` is missing.
    """
    print(f"Loading model from {checkpoint_path}...")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Kept because the checkpoint stores a config object alongside
    # the weights, but only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt["config"]
    model = LanguageModel(config)

    # Checkpoints saved from a wrapped model carry "module." (DataParallel /
    # DDP) and/or "_orig_mod." (torch.compile) key prefixes; the bare model
    # expects neither, whichever order the wrappers were applied in.
    cleaned = {
        _strip_wrapper_prefixes(key): tensor
        for key, tensor in ckpt["model_state_dict"].items()
    }

    model.load_state_dict(cleaned)
    model.to(device)
    model.eval()

    # Parameter count is reported in millions.
    params = sum(p.numel() for p in model.parameters()) / 1e6
    val_loss = ckpt.get("loss", "?")
    print(f" {params:.1f}M params | val_loss {val_loss} | device {device}")
    return model, config
|
|
# Restore the trained network together with the config stored beside it.
model, config = load_model(CHECKPOINT_PATH, DEVICE)

# Tokenizer: tiktoken's "gpt2" encoding.
enc = tiktoken.get_encoding("gpt2")

# Text generator used by the chat widget; its context window comes from the
# checkpoint config.
gen = HFTextGenerator(
    model,
    enc,
    DEVICE,
    context_size=config["context_length"],
)
print("Ready.")
|
|
|
|
| |
|
|
| |
def _slider_kwargs(label):
    """Return the presentation kwargs shared by every sampling slider.

    A new ``Layout`` is created on each call so the sliders never share one
    layout object.
    """
    return {
        "description": label,
        "style": {"description_width": "90px"},
        "layout": widgets.Layout(width="400px"),
    }


# Sampling temperature passed to the generator.
temp_slider = widgets.FloatSlider(
    value=0.8,
    min=0.1,
    max=2.0,
    step=0.05,
    **_slider_kwargs("Temperature"),
)

# Top-k value passed to the generator.
topk_slider = widgets.IntSlider(
    value=40,
    min=1,
    max=200,
    step=1,
    **_slider_kwargs("Top-k"),
)

# Maximum number of new tokens requested per generation.
tokens_slider = widgets.IntSlider(
    value=150,
    min=10,
    max=500,
    step=10,
    **_slider_kwargs("Max tokens"),
)
|
|
| |
# Multi-line text box where the user types the prompt.
prompt_box = widgets.Textarea(
    layout=widgets.Layout(width="700px", height="70px"),
    placeholder="Type your prompt here...",
)

# Button that triggers a generation.
generate_btn = widgets.Button(
    layout=widgets.Layout(width="120px"),
    button_style="primary",
    description="Generate",
)

# Button that wipes the transcript and the prompt box.
clear_btn = widgets.Button(
    layout=widgets.Layout(width="80px"),
    button_style="warning",
    description="Clear",
)

# Output widget that receives everything printed by the handlers.
output_area = widgets.Output()
|
|
def on_generate(b):
    """Handle a click on the Generate button.

    Reads the prompt box, and if it holds non-whitespace text, prints the
    prompt and the generator's continuation into ``output_area`` using the
    current slider values. Does nothing for an empty prompt.

    Args:
        b: The button instance supplied by the click callback (unused).
    """
    user_text = prompt_box.value.strip()
    if user_text:
        # Snapshot the sampling controls for this request.
        sampling = {
            "max_new_tokens": tokens_slider.value,
            "temperature": temp_slider.value,
            "top_k": topk_slider.value,
        }

        # Anything printed inside this context lands in the output widget.
        with output_area:
            print(f"\nYOU > {user_text}")
            print("MODEL > ", end="", flush=True)

            reply = gen.generate(user_text, **sampling)
            print(reply)
            print("-" * 70)
|
|
def on_clear(b):
    """Handle a click on the Clear button: wipe the transcript and the prompt.

    Args:
        b: The button instance supplied by the click callback (unused).
    """
    output_area.clear_output(wait=False)
    prompt_box.value = ""


# Wire each button to its handler.
generate_btn.on_click(on_generate)
clear_btn.on_click(on_clear)
|
|
| |
# Sampling sliders stacked under their heading.
controls = widgets.VBox([
    widgets.HTML("<b>Sampling settings</b>"),
    temp_slider,
    topk_slider,
    tokens_slider,
])

# Action buttons side by side.
buttons = widgets.HBox([generate_btn, clear_btn])

# Full chat panel: title, settings, prompt, buttons, then the transcript.
ui = widgets.VBox([
    # The title separator was a mis-encoded dash; the HTML entity keeps it
    # correct regardless of how the source file is encoded.
    widgets.HTML("<h3>GPT-2 Small &mdash; Chat</h3>"),
    controls,
    widgets.HTML("<br><b>Prompt</b>"),
    prompt_box,
    buttons,
    widgets.HTML("<br><b>Output</b>"),
    output_area,
])

# Render the panel in the notebook cell.
display(ui)