#!/usr/bin/env python
# chat.py — Gradio chat app with **turn-taking guards**
# ----------------------------------------------------
# pip install torch sentencepiece gradio
# python chat.py
#
# Expects in ./chat_sprint_artifacts (or CHAT_SPRINT_ARTIFACTS env):
#   - spm_chat.model
#   - tinygpt.pt
#   - model_config.json
#
# Turn-taking guards:
#   • Masks "You:" and tag tokens so the bot can't emit them mid-reply.
#   • Stops generation on EOS / "\n\n" / "\nYou:" so it doesn't start the next turn.
#   • Works with Gradio 5.x (type="messages", queue() w/o kwargs).

import os, json, math
from pathlib import Path
from typing import List, Tuple, Generator

import torch
import torch.nn as nn
import sentencepiece as spm
import gradio as gr

# ---------- paths & device ----------
ART = Path(os.environ.get("CHAT_SPRINT_ARTIFACTS", "chat_sprint_artifacts"))
SPM_PATH = ART / "spm_chat.model"
CKPT = ART / "tinygpt.pt"
CFG_JSON = ART / "model_config.json"
LOG_PATH = ART / "chat_transcript.txt"
ART.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # TF32 matmuls: faster on Ampere+, negligible quality impact for chat.
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True


# ---------- tiny GPT (same as trainer) ----------
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fixed lower-triangular mask."""

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # (1, 1, block_size, block_size) causal mask; sliced to T at runtime.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention + 4x-GELU MLP, residual."""

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT(nn.Module):
    """Minimal GPT: learned token + position embeddings, N blocks, tied-free head."""

    def __init__(self, vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init)

    def _init(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def sample_stream(
        self, idx, sp: spm.SentencePieceProcessor,
        forbid_ids=None, stop_ids=None, stop_strings=None,
        max_new_tokens=200, temperature=0.8, top_k=60, top_p=0.95,
        repetition_penalty=1.0
    ) -> Generator[str, None, None]:
        """Autoregressively sample and yield decoded text chunks.

        Args:
            idx: (1, T) LongTensor of prompt token ids.
            sp: SentencePiece processor used for incremental decoding.
            forbid_ids: token ids that may never be sampled (role/tag tokens).
            stop_ids: token ids that end generation immediately (not emitted).
            stop_strings: substrings that end generation; text before the stop
                string is yielded, the stop string itself is not.
            max_new_tokens, temperature, top_k, top_p, repetition_penalty:
                standard sampling knobs (top_k=0 and top_p>=0.9999 disable).
        """
        forbid_ids = set(forbid_ids or [])
        stop_ids = set(stop_ids or [])
        stop_strings = list(stop_strings or [])
        # A stop string can straddle two decode steps ("\n" then "You:"),
        # so re-scan up to (longest stop - 1) chars before the boundary.
        max_stop_len = max((len(s) for s in stop_strings), default=0)
        prev_text = sp.decode(idx[0].tolist())

        for _ in range(int(max_new_tokens)):
            idx_cond = idx[:, -self.block_size:]
            logits = self.forward(idx_cond)[:, -1, :]

            # role/tag masking: make forbidden tokens unsampleable
            if forbid_ids:
                mask_idx = torch.tensor(sorted(forbid_ids), device=logits.device)
                logits[:, mask_idx] = -1e9

            # repetition penalty (CTRL-style). FIX: the old code divided every
            # seen logit by the penalty, which *boosts* negative logits
            # (i.e. rewards repetition). Divide positives, multiply negatives.
            if repetition_penalty != 1.0:
                seen = torch.unique(idx_cond[0])
                seen_logits = logits[:, seen]
                logits[:, seen] = torch.where(
                    seen_logits > 0,
                    seen_logits / repetition_penalty,
                    seen_logits * repetition_penalty,
                )

            # temperature / top-k / top-p filtering
            logits = logits / max(1e-8, float(temperature))
            if top_k and int(top_k) > 0:
                v, _ = torch.topk(logits, min(int(top_k), logits.size(-1)))
                cutoff = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < cutoff, torch.full_like(logits, -1e9), logits)
            if top_p and float(top_p) < 0.9999:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cdf = torch.cumsum(probs, dim=-1)
                mask = cdf > float(top_p)
                mask[:, 0] = False  # always keep the most likely token
                sorted_logits[mask] = -1e9
                logits = torch.zeros_like(logits).scatter(1, sorted_idx, sorted_logits)

            # sample one token
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

            # token-level stops: stop token is neither appended nor emitted
            if int(next_id) in stop_ids:
                break

            idx = torch.cat([idx, next_id], dim=1)

            # text-level stops: re-decode the whole sequence and scan the tail
            full_text = sp.decode(idx[0].tolist())
            search_from = (
                max(0, len(prev_text) - max_stop_len + 1) if max_stop_len else len(prev_text)
            )
            cut_at = None
            for s in stop_strings:
                pos = full_text.find(s, search_from)
                if pos != -1:
                    cut_at = pos if cut_at is None else min(cut_at, pos)
            if cut_at is not None:
                # FIX: the old slice used `cut_at - len(prev_text)` as the end
                # index against an absolute start index, so the final chunk
                # before a stop string was silently dropped. Yield everything
                # up to (not including) the stop string, then finish.
                if cut_at > len(prev_text):
                    yield full_text[len(prev_text):cut_at]
                break

            yield full_text[len(prev_text):]
            prev_text = full_text


# ---------- artifacts ----------
def load_everything():
    """Load tokenizer, model weights, and config from ART; fail fast if missing.

    Returns:
        (sp, model, cfg): SentencePieceProcessor, eval-mode TinyGPT on DEVICE,
        and the parsed model_config.json dict.

    Raises:
        FileNotFoundError: when any of the three artifact files is absent.
    """
    if not SPM_PATH.exists():
        raise FileNotFoundError(f"Missing tokenizer: {SPM_PATH}")
    if not CKPT.exists():
        raise FileNotFoundError(f"Missing weights: {CKPT}")
    if not CFG_JSON.exists():
        raise FileNotFoundError(f"Missing config: {CFG_JSON}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(SPM_PATH))
    cfg = json.loads(CFG_JSON.read_text())
    model = TinyGPT(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embd=cfg["n_embd"],
        block_size=cfg["block_size"],
        dropout=cfg.get("dropout", 0.0),
    ).to(DEVICE)
    sd = torch.load(CKPT, map_location=DEVICE)
    model.load_state_dict(sd, strict=True)
    model.eval()
    return sp, model, cfg


SP, MODEL, CFG = load_everything()


# special pieces we want to control at decode-time
def _piece_id(piece: str) -> int:
    """Return *piece*'s vocab id, or -1 when it is not a known single piece.

    NOTE(review): SentencePiece's piece_to_id returns the <unk> id — not -1 —
    for out-of-vocab pieces, so plain `!= -1` guards never fire and <unk>
    could leak into the forbid/stop sets (worst case: stopping on any <unk>).
    Treat the unk id as "missing". TODO confirm these tags exist as single
    pieces in spm_chat.model.
    """
    pid = SP.piece_to_id(piece)
    return -1 if pid == SP.unk_id() else pid


EOS_ID = SP.eos_id()
YOU_ID = _piece_id("You:")
BOT_ID = _piece_id("Bot:")
TAGS = ["[STYLE=Snark]", "[FORM=TWEET]", "[FORM=HEADLINE]", "[MOOD=Unhinged]", "[MOOD=Cheeky]"]
TAG_IDS = [pid for pid in (_piece_id(t) for t in TAGS) if pid != -1]
FORBID_IDS = {x for x in [YOU_ID, BOT_ID] + TAG_IDS if x != -1}  # never emit inside a reply
STOP_IDS = {i for i in [EOS_ID, YOU_ID] if i != -1}
STOP_STRS = ["\nYou:", "\n\n"]  # treat next turn / blank-line as stop

# ---------- prompt building ----------
STYLE_TAGS = ["", *TAGS]


def history_to_pairs(history_messages) -> List[Tuple[str, str]]:
    """Collapse Gradio `type="messages"` history into (user, assistant) pairs.

    Unpaired trailing user messages are dropped; assistant messages with no
    preceding user message are ignored.
    """
    pairs: List[Tuple[str, str]] = []
    last_user = None
    for m in history_messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user":
            last_user = content
        elif role == "assistant" and last_user is not None:
            pairs.append((last_user, content))
            last_user = None
    return pairs


def build_prompt(history_pairs: List[Tuple[str, str]], user: str, style: str) -> str:
    """Render the transcript in the "You:/Bot:" format the model was trained on."""
    lines = []
    for u, b in history_pairs:
        lines.append(f"You: {u}")
        lines.append(f"Bot: {b}")
        lines.append("")
    lines.append(f"You: {user}")
    if style:
        lines.append(style)  # we *insert* tags here, but forbid them inside the reply
    lines.append("Bot:")
    return "\n".join(lines)


def encode_ctx(text: str, block_size: int) -> torch.Tensor:
    """Encode *text* and keep only the last block_size tokens as (1, T) on DEVICE."""
    ids = SP.encode(text, out_type=int)
    return torch.tensor([ids[-block_size:]], dtype=torch.long, device=DEVICE)


# ---------- gradio handler ----------
def respond(message, history, temperature, top_k, top_p, repetition_penalty,
            max_new_tokens, style):
    """Gradio streaming handler: yield the accumulating assistant reply.

    After the stream ends, append the turn to LOG_PATH (best-effort: logging
    failures are deliberately swallowed so they never break the chat).
    """
    if isinstance(message, dict):  # multimodal textbox sends a dict
        message = message.get("content", "")
    pairs = history_to_pairs(history)
    prompt = build_prompt(pairs, message, style)
    x = encode_ctx(prompt, CFG["block_size"])
    stream = MODEL.sample_stream(
        x, SP,
        forbid_ids=FORBID_IDS, stop_ids=STOP_IDS, stop_strings=STOP_STRS,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
    )
    acc = ""
    for chunk in stream:
        acc += chunk
        yield acc
    try:
        with LOG_PATH.open("a", encoding="utf-8") as f:
            f.write(f"You: {message}\n")
            if style:
                f.write(style + "\n")
            f.write(f"Bot: {acc}\n\n")
    except Exception:
        pass  # best-effort transcript; never surface logging errors to the UI


# ---------- app ----------
def main():
    """Build and launch the Gradio ChatInterface (Gradio 5.x conventions)."""
    title = "TinyGPT — Fun Chat (turn-taking fixed)"
    desc = (
        f"Device: {DEVICE.type.upper()} • vocab={CFG['vocab_size']} • "
        f"layers={CFG['n_layer']} heads={CFG['n_head']} dim={CFG['n_embd']} • "
        f"block={CFG['block_size']}"
    )
    iface = gr.ChatInterface(
        fn=respond,
        title=title,
        description=desc,
        additional_inputs=[
            gr.Slider(0.2, 1.5, value=0.8, step=0.05, label="Temperature"),
            gr.Slider(0, 200, value=60, step=1, label="Top-K (0=off ⇒ set 0)"),
            gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-P"),
            gr.Slider(1.0, 1.5, value=1.08, step=0.01, label="Repetition penalty"),
            gr.Slider(16, 512, value=200, step=8, label="Max new tokens"),
            gr.Dropdown(STYLE_TAGS, value="", label="Style tag"),
        ],
        type="messages",
    )
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860, show_api=False, inbrowser=False)


if __name__ == "__main__":
    main()