import os, json, math
from pathlib import Path
from typing import List, Tuple, Generator

import torch
import torch.nn as nn
import sentencepiece as spm
import gradio as gr


ART = Path(os.environ.get("CHAT_SPRINT_ARTIFACTS", "chat_sprint_artifacts"))
SPM_PATH = ART / "spm_chat.model"
CKPT = ART / "tinygpt.pt"
CFG_JSON = ART / "model_config.json"
LOG_PATH = ART / "chat_transcript.txt"
ART.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit lower-triangular mask."""
    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y


class Block(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.0, block_size=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT(nn.Module):
    """Decoder-only transformer with learned positional embeddings."""
    def __init__(self, vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init)

    def _init(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def sample_stream(
        self, idx, sp: spm.SentencePieceProcessor,
        forbid_ids=None, stop_ids=None, stop_strings=None,
        max_new_tokens=200, temperature=0.8, top_k=60, top_p=0.95, repetition_penalty=1.0
    ) -> Generator[str, None, None]:
        """Yield decoded text chunks with simple constraints."""
        forbid_ids = set(forbid_ids or [])
        stop_ids = set(stop_ids or [])
        stop_strings = list(stop_strings or [])
        prev_text = sp.decode(idx[0].tolist())

        for _ in range(int(max_new_tokens)):
            idx_cond = idx[:, -self.block_size:]
            logits = self.forward(idx_cond)[:, -1, :]

            # Never sample structural tokens (role markers, style tags).
            if forbid_ids:
                mask_idx = torch.tensor(list(forbid_ids), device=logits.device)
                logits[:, mask_idx] = -1e9

            # Simple repetition penalty: scale down logits of tokens already in the context.
            if repetition_penalty != 1.0:
                uniq = torch.unique(idx_cond[0])
                logits[:, uniq] /= repetition_penalty

            # Temperature, top-k, and nucleus (top-p) filtering.
            logits = logits / max(1e-8, float(temperature))
            if top_k and int(top_k) > 0:
                v, _ = torch.topk(logits, min(int(top_k), logits.size(-1)))
                cutoff = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < cutoff, torch.full_like(logits, -1e9), logits)
            if top_p and float(top_p) < 0.9999:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cdf = torch.cumsum(probs, dim=-1)
                mask = cdf > float(top_p)
                mask[:, 0] = False
                sorted_logits[mask] = -1e9
                logits = torch.zeros_like(logits).scatter(1, sorted_idx, sorted_logits)

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

            if int(next_id) in stop_ids:
                break

            idx = torch.cat([idx, next_id], dim=1)

            # Decode incrementally and scan only the newly generated text for stop strings.
            full_text = sp.decode(idx[0].tolist())
            new_chunk = full_text[len(prev_text):]
            cut_at = None
            for s in stop_strings:
                pos = full_text.find(s, len(prev_text))
                if pos != -1:
                    cut_at = pos
                    break
            if cut_at is not None:
                # Yield only the text before the stop string, then stop generating.
                yield full_text[len(prev_text):cut_at]
                break

            yield new_chunk
            prev_text = full_text

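# Illustrative only (not executed by this script): a toy forward pass through the
# architecture above, with made-up hyperparameters. The real values are read from
# model_config.json in load_everything() below.
#
#   toy = TinyGPT(vocab_size=512, n_layer=2, n_head=2, n_embd=64, block_size=32)
#   dummy = torch.randint(0, 512, (1, 16))
#   print(toy(dummy).shape)   # torch.Size([1, 16, 512]) -- one logit row per position
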

def load_everything():
    if not SPM_PATH.exists():
        raise FileNotFoundError(f"Missing tokenizer: {SPM_PATH}")
    if not CKPT.exists():
        raise FileNotFoundError(f"Missing weights: {CKPT}")
    if not CFG_JSON.exists():
        raise FileNotFoundError(f"Missing config: {CFG_JSON}")

    sp = spm.SentencePieceProcessor()
    sp.load(str(SPM_PATH))
    cfg = json.loads(CFG_JSON.read_text())
    model = TinyGPT(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"], n_head=cfg["n_head"], n_embd=cfg["n_embd"],
        block_size=cfg["block_size"], dropout=cfg.get("dropout", 0.0)
    ).to(DEVICE)
    sd = torch.load(CKPT, map_location=DEVICE)
    model.load_state_dict(sd, strict=True)
    model.eval()
    return sp, model, cfg

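# For reference, model_config.json is expected to carry the keys read above; the
# numbers here are placeholders, not the values from any particular training run:
#
#   {"vocab_size": 8000, "n_layer": 4, "n_head": 4, "n_embd": 256,
#    "block_size": 256, "dropout": 0.1}
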

SP, MODEL, CFG = load_everything()

# Token ids and strings that keep the model inside the "You:" / "Bot:" turn format:
# the model must never emit role markers or style tags itself, and generation stops
# at end-of-sequence or when the model tries to start the user's next turn.
EOS_ID = SP.eos_id()
YOU_ID = SP.piece_to_id("You:")
BOT_ID = SP.piece_to_id("Bot:")
TAGS = ["[STYLE=Snark]", "[FORM=TWEET]", "[FORM=HEADLINE]", "[MOOD=Unhinged]", "[MOOD=Cheeky]"]
TAG_IDS = [SP.piece_to_id(t) for t in TAGS if SP.piece_to_id(t) != -1]
FORBID_IDS = {x for x in [YOU_ID, BOT_ID] + TAG_IDS if x != -1}
STOP_IDS = {i for i in [EOS_ID, YOU_ID] if i != -1}
STOP_STRS = ["\nYou:", "\n\n"]

# Style options exposed in the UI; the empty string means "no style tag".
STYLE_TAGS = ["", *TAGS]


def history_to_pairs(history_messages) -> List[Tuple[str, str]]:
    pairs: List[Tuple[str, str]] = []
    last_user = None
    for m in history_messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user":
            last_user = content
        elif role == "assistant" and last_user is not None:
            pairs.append((last_user, content))
            last_user = None
    return pairs

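# Example of the conversion this helper performs (illustrative values):
#
#   history_to_pairs([
#       {"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello"},
#   ])
#   # -> [("hi", "hello")]
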

def build_prompt(history_pairs: List[Tuple[str, str]], user: str, style: str) -> str:
    lines = []
    for u, b in history_pairs:
        lines.append(f"You: {u}")
        lines.append(f"Bot: {b}")
        lines.append("")
    lines.append(f"You: {user}")
    if style:
        lines.append(style)
    lines.append("Bot:")
    return "\n".join(lines)

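# With one prior exchange, no style tag, and the new message "how are you?",
# build_prompt produces (illustrative transcript):
#
#   You: hi
#   Bot: hello
#
#   You: how are you?
#   Bot:
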

def encode_ctx(text: str, block_size: int) -> torch.Tensor:
    ids = SP.encode(text, out_type=int)
    return torch.tensor([ids[-block_size:]], dtype=torch.long, device=DEVICE)

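# encode_ctx keeps only the most recent block_size token ids, so long histories are
# right-truncated; e.g. if SP.encode returned [11, 12, 13, 14, 15] and block_size
# were 3, the model would see tensor([[13, 14, 15]]) (values illustrative).
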

def respond(message, history, temperature, top_k, top_p, repetition_penalty, max_new_tokens, style):
    if isinstance(message, dict):
        message = message.get("content", "")
    pairs = history_to_pairs(history)
    prompt = build_prompt(pairs, message, style)
    x = encode_ctx(prompt, CFG["block_size"])

    stream = MODEL.sample_stream(
        x, SP,
        forbid_ids=FORBID_IDS,
        stop_ids=STOP_IDS,
        stop_strings=STOP_STRS,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
    )

    # Stream partial text to the UI as it is generated.
    acc = ""
    for chunk in stream:
        acc += chunk
        yield acc

    # Append the finished exchange to the transcript log; never let logging break the chat.
    try:
        with LOG_PATH.open("a", encoding="utf-8") as f:
            f.write(f"You: {message}\n")
            if style:
                f.write(style + "\n")
            f.write(f"Bot: {acc}\n\n")
    except Exception:
        pass


def main():
    title = "TinyGPT — Fun Chat (turn-taking fixed)"
    desc = f"Device: {DEVICE.type.upper()} • vocab={CFG['vocab_size']} • layers={CFG['n_layer']} heads={CFG['n_head']} dim={CFG['n_embd']} • block={CFG['block_size']}"

    iface = gr.ChatInterface(
        fn=respond,
        title=title,
        description=desc,
        additional_inputs=[
            gr.Slider(0.2, 1.5, value=0.8, step=0.05, label="Temperature"),
            gr.Slider(0, 200, value=60, step=1, label="Top-K (0 = off)"),
            gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-P"),
            gr.Slider(1.0, 1.5, value=1.08, step=0.01, label="Repetition penalty"),
            gr.Slider(16, 512, value=200, step=8, label="Max new tokens"),
            gr.Dropdown(STYLE_TAGS, value="", label="Style tag"),
        ],
        type="messages",
    )
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860, show_api=False, inbrowser=False)


if __name__ == "__main__":
    main()

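# Usage sketch (chat_app.py is a placeholder name; use whatever this file is saved as):
#
#   python chat_app.py                              # uses ./chat_sprint_artifacts by default
#   CHAT_SPRINT_ARTIFACTS=/path/to/run python chat_app.py
#   # then open http://localhost:7860 in a browser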