# fungpt / chat.py
# Uploaded by nagolinc with huggingface_hub (revision a1bf71f, verified).
#!/usr/bin/env python
# chat.py — Gradio chat app with **turn-taking guards**
# ----------------------------------------------------
# pip install torch sentencepiece gradio
# python chat.py
#
# Expects in ./chat_sprint_artifacts (or CHAT_SPRINT_ARTIFACTS env):
# - spm_chat.model
# - tinygpt.pt
# - model_config.json
#
# Fixes for your screenshot:
# • Masks “You:” and tag tokens so the bot can’t emit them mid-reply.
# • Stops generation on EOS / "\n\n" / "\nYou:" so it doesn’t start the next turn.
# • Works with Gradio 5.x (type="messages", queue() w/o kwargs).
import os, json, math
from pathlib import Path
from typing import List, Tuple, Generator
import torch
import torch.nn as nn
import sentencepiece as spm
import gradio as gr
# ---------- paths & device ----------
# Artifact directory; overridable through the CHAT_SPRINT_ARTIFACTS env var.
ART = Path(os.environ.get("CHAT_SPRINT_ARTIFACTS", "chat_sprint_artifacts"))
SPM_PATH = ART / "spm_chat.model"        # SentencePiece tokenizer model
CKPT = ART / "tinygpt.pt"                # trained model weights (state_dict)
CFG_JSON = ART / "model_config.json"     # architecture hyperparameters
LOG_PATH = ART / "chat_transcript.txt"   # best-effort chat transcript log
ART.mkdir(parents=True, exist_ok=True)   # ensure the log can be written even on a fresh checkout
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Allow TF32 matmuls on Ampere+ GPUs: faster inference, negligible accuracy loss here.
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
# ---------- tiny GPT (same as trainer) ----------
class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention, decoder style.

    Attribute names (qkv, proj, attn_drop, resid_drop, mask) are fixed:
    they must match the trainer's checkpoint for strict state_dict loading.
    """

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Single fused projection for queries, keys and values.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Lower-triangular matrix: position t may attend to positions <= t only.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x):
        """Attend over the sequence; input and output are (B, T, n_embd)."""
        batch, seq, dim = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        heads = (batch, seq, self.n_head, self.head_dim)
        q = q.view(*heads).transpose(1, 2)
        k = k.view(*heads).transpose(1, 2)
        v = v.view(*heads).transpose(1, 2)
        # Scaled dot-product scores, future positions masked to -inf pre-softmax.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :seq, :seq] == 0, float("-inf"))
        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.resid_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual.

    Attribute names (ln1, attn, ln2, mlp) match the trainer checkpoint.
    """

    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        # Standard 4x-expansion feed-forward with GELU.
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply attention then the MLP, each as a pre-norm residual update."""
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
class TinyGPT(nn.Module):
    """Minimal GPT-style decoder with constrained streaming sampling.

    Token + learned positional embeddings, a stack of pre-LN Blocks, a final
    LayerNorm and an untied linear head — identical to the trainer's model so
    the checkpoint loads with strict=True.
    """

    def __init__(self, vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init)

    def _init(self, m):
        # GPT-2-style init: N(0, 0.02) weights, zero biases, identity LayerNorm.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        """Return next-token logits (B, T, vocab_size) for token ids idx (B, T)."""
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def sample_stream(
        self, idx, sp: "spm.SentencePieceProcessor",
        forbid_ids=None, stop_ids=None, stop_strings=None,
        max_new_tokens=200, temperature=0.8, top_k=60, top_p=0.95, repetition_penalty=1.0
    ) -> Generator[str, None, None]:
        """Yield decoded text chunks with simple decode-time constraints.

        Args:
            idx: (1, T) prompt token ids (batch size 1 is assumed for decoding).
            sp: SentencePiece processor used only for its decode(list[int]) -> str.
            forbid_ids: token ids whose logits are masked out every step.
            stop_ids: token ids that end generation (the token is NOT emitted).
            stop_strings: substrings that end generation; text after the stop
                (and the stop itself) is never yielded.
            max_new_tokens / temperature / top_k / top_p / repetition_penalty:
                usual sampling knobs; top_k=0 and top_p>=0.9999 disable filtering.
        """
        forbid_ids = set(forbid_ids or [])
        stop_ids = set(stop_ids or [])
        stop_strings = list(stop_strings or [])
        prev_text = sp.decode(idx[0].tolist())
        # Re-scan this many chars before the last chunk boundary each step so a
        # stop string that straddles two decoded chunks is still detected.
        max_stop_len = max((len(s) for s in stop_strings), default=0)
        for _ in range(int(max_new_tokens)):
            idx_cond = idx[:, -self.block_size:]
            logits = self.forward(idx_cond)[:, -1, :]
            # Role/tag masking: never emit "You:", "Bot:" or control tags mid-reply.
            if forbid_ids:
                mask_idx = torch.tensor(sorted(forbid_ids), device=logits.device)
                logits[:, mask_idx] = -1e9
            # Repetition penalty, sign-aware (CTRL-style). Plain division would
            # *raise* the probability of already-seen tokens with negative logits.
            if repetition_penalty != 1.0:
                seen = torch.unique(idx_cond[0])
                picked = logits[:, seen]
                logits[:, seen] = torch.where(
                    picked > 0, picked / repetition_penalty, picked * repetition_penalty
                )
            # Temperature, then top-k, then nucleus (top-p) filtering.
            logits = logits / max(1e-8, float(temperature))
            if top_k and int(top_k) > 0:
                v, _ = torch.topk(logits, min(int(top_k), logits.size(-1)))
                cutoff = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < cutoff, torch.full_like(logits, -1e9), logits)
            if top_p and float(top_p) < 0.9999:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cdf = torch.cumsum(probs, dim=-1)
                mask = cdf > float(top_p)
                mask[:, 0] = False  # always keep at least the most likely token
                sorted_logits[mask] = -1e9
                logits = torch.zeros_like(logits).scatter(1, sorted_idx, sorted_logits)
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            # Token-level stops (EOS, "You:") end the reply without emitting the token.
            if int(next_id) in stop_ids:
                break
            idx = torch.cat([idx, next_id], dim=1)
            # Text-level stops: re-decode and look for the earliest stop string,
            # scanning slightly back across the chunk boundary (see max_stop_len).
            full_text = sp.decode(idx[0].tolist())
            scan_from = max(0, len(prev_text) - max_stop_len + 1)
            cut_at = None
            for s in stop_strings:
                pos = full_text.find(s, scan_from)
                if pos != -1:
                    cut_at = pos if cut_at is None else min(cut_at, pos)
            if cut_at is not None:
                # Yield only the new text preceding the stop sequence (empty if
                # the stop began inside an already-yielded chunk), then finish.
                yield full_text[len(prev_text):cut_at]
                break
            yield full_text[len(prev_text):]
            prev_text = full_text
# ---------- artifacts ----------
def load_everything():
    """Load tokenizer, model and config from the artifacts folder.

    Returns:
        (sp, model, cfg): SentencePiece processor, TinyGPT in eval mode on
        DEVICE, and the parsed config dict.
    Raises:
        FileNotFoundError: when any required artifact is absent.
    """
    required = ((SPM_PATH, "tokenizer"), (CKPT, "weights"), (CFG_JSON, "config"))
    for path, what in required:
        if not path.exists():
            raise FileNotFoundError(f"Missing {what}: {path}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(SPM_PATH))
    cfg = json.loads(CFG_JSON.read_text())
    model = TinyGPT(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embd=cfg["n_embd"],
        block_size=cfg["block_size"],
        dropout=cfg.get("dropout", 0.0),
    ).to(DEVICE)
    state = torch.load(CKPT, map_location=DEVICE)
    # strict=True: the architecture above must match the checkpoint exactly.
    model.load_state_dict(state, strict=True)
    model.eval()
    return sp, model, cfg
SP, MODEL, CFG = load_everything()

# Special pieces we want to control at decode-time. piece_to_id returning -1
# means the piece is absent from the vocab; such ids are filtered out below.
EOS_ID = SP.eos_id()
YOU_ID = SP.piece_to_id("You:")
BOT_ID = SP.piece_to_id("Bot:")
TAGS = ["[STYLE=Snark]", "[FORM=TWEET]", "[FORM=HEADLINE]", "[MOOD=Unhinged]", "[MOOD=Cheeky]"]
TAG_IDS = [tid for tid in (SP.piece_to_id(t) for t in TAGS) if tid != -1]
# Forbidden inside assistant text: role markers plus the style/control tags.
FORBID_IDS = {tok for tok in (YOU_ID, BOT_ID, *TAG_IDS) if tok != -1}
STOP_IDS = {tok for tok in (EOS_ID, YOU_ID) if tok != -1}
STOP_STRS = ["\nYou:", "\n\n"]  # treat next turn / blank line as end of reply
# ---------- prompt building ----------
STYLE_TAGS = ["", *TAGS]  # dropdown choices; "" means no style tag
def history_to_pairs(history_messages) -> List[Tuple[str, str]]:
    """Collapse a messages-style history into (user, assistant) tuples.

    A user turn is paired with the next assistant turn; user turns without a
    following assistant reply are dropped.
    """
    pairs: List[Tuple[str, str]] = []
    pending_user = None
    for msg in history_messages:
        role = msg.get("role")
        content = msg.get("content", "")
        if role == "user":
            pending_user = content
        elif role == "assistant" and pending_user is not None:
            pairs.append((pending_user, content))
            pending_user = None
    return pairs
def build_prompt(history_pairs: List[Tuple[str, str]], user: str, style: str) -> str:
    """Render the history plus the new user turn in the trainer's transcript format.

    Each past exchange becomes "You: ...\\nBot: ...\\n" (blank line between
    exchanges); the prompt ends with "Bot:" so the model continues as the bot.
    """
    lines: List[str] = []
    for past_user, past_bot in history_pairs:
        lines.extend((f"You: {past_user}", f"Bot: {past_bot}", ""))
    lines.append(f"You: {user}")
    if style:
        # The tag is *inserted* into the prompt, yet forbidden inside the reply.
        lines.append(style)
    lines.append("Bot:")
    return "\n".join(lines)
def encode_ctx(text: str, block_size: int) -> torch.Tensor:
    """Tokenize text and keep only the last block_size ids as a (1, T) tensor on DEVICE."""
    token_ids = SP.encode(text, out_type=int)
    clipped = token_ids[-block_size:]
    return torch.tensor([clipped], dtype=torch.long, device=DEVICE)
# ---------- gradio handler ----------
def respond(message, history, temperature, top_k, top_p, repetition_penalty, max_new_tokens, style):
    """Gradio streaming handler: build the prompt, stream the reply, log the turn.

    Yields the accumulated reply text after each decoded chunk (Gradio replaces
    the message with each yield). Appends the finished turn to LOG_PATH.
    """
    # Gradio may deliver the message as a dict in messages mode.
    if isinstance(message, dict):
        message = message.get("content", "")
    prompt = build_prompt(history_to_pairs(history), message, style)
    ctx = encode_ctx(prompt, CFG["block_size"])
    reply = ""
    for piece in MODEL.sample_stream(
        ctx, SP,
        forbid_ids=FORBID_IDS,
        stop_ids=STOP_IDS,
        stop_strings=STOP_STRS,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
    ):
        reply += piece
        yield reply
    # Best-effort transcript logging; never let logging break the chat.
    try:
        with LOG_PATH.open("a", encoding="utf-8") as f:
            f.write(f"You: {message}\n")
            if style:
                f.write(style + "\n")
            f.write(f"Bot: {reply}\n\n")
    except Exception:
        pass
# ---------- app ----------
def main():
    """Build and launch the Gradio chat UI on port 7860."""
    title = "TinyGPT — Fun Chat (turn-taking fixed)"
    desc = (
        f"Device: {DEVICE.type.upper()} • vocab={CFG['vocab_size']} • "
        f"layers={CFG['n_layer']} heads={CFG['n_head']} dim={CFG['n_embd']} • block={CFG['block_size']}"
    )
    # Extra controls surfaced under the chat box; order matches respond()'s
    # parameters after (message, history).
    controls = [
        gr.Slider(0.2, 1.5, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(0, 200, value=60, step=1, label="Top-K (0=off ⇒ set 0)"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-P"),
        gr.Slider(1.0, 1.5, value=1.08, step=0.01, label="Repetition penalty"),
        gr.Slider(16, 512, value=200, step=8, label="Max new tokens"),
        gr.Dropdown(STYLE_TAGS, value="", label="Style tag"),
    ]
    iface = gr.ChatInterface(
        fn=respond,
        title=title,
        description=desc,
        additional_inputs=controls,
        type="messages",
    )
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860, show_api=False, inbrowser=False)


if __name__ == "__main__":
    main()