"""Tiny FAQ chatbot.

Rule-based intent routing prepends an ``<intent>`` tag to the user message,
then a small decoder-only transformer (TinyGPT) generates the reply.
Model weights and vocabulary are pulled from the Hugging Face Hub and the
app is served through a Gradio interface.
"""

import json
import math
import re
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

MODEL_REPO = "aagzamov/tiny-chatbot-model"
CKPT_FILENAME = "ckpt.pt"
VOCAB_FILENAME = "vocab.json"

# A token is: an angle-bracket tag (e.g. "<refund>"), a run of alphanumerics,
# or a single non-space punctuation character.
TOKEN_RE = re.compile(r"<[^>]+>|[A-Za-z0-9]+|[^\sA-Za-z0-9]", re.UNICODE)

# Special vocabulary tokens.
# BUG FIX: these were all the empty string, which made PAD/EOS/UNK resolve to
# the *same* vocab entry (or raise KeyError). The tokenizer's "<...>" branch and
# the "<{intent}>" prompt tags show the vocab stores tag-style specials.
# NOTE(review): confirm these literals match the entries in vocab.json.
PAD = "<pad>"
EOS = "<eos>"
UNK = "<unk>"


def tokenize(text: str) -> list[str]:
    """Split ``text`` into tag / word / punctuation tokens."""
    return TOKEN_RE.findall(text.strip())


def load_vocab(local_path: str) -> tuple[dict[str, int], list[str]]:
    """Read ``{"itos": [...]}`` from ``local_path``; return (token->id, id->token)."""
    obj = json.loads(Path(local_path).read_text(encoding="utf-8"))
    itos = obj["itos"]
    stoi = {t: i for i, t in enumerate(itos)}
    return stoi, itos


def encode(stoi: dict[str, int], unk_id: int, text: str) -> list[int]:
    """Tokenize ``text`` and map each token to an id, falling back to ``unk_id``."""
    return [stoi.get(tok, unk_id) for tok in tokenize(text)]


def decode(itos: list[str], eos_id: int, ids: list[int]) -> str:
    """Map ids back to text: stop at ``eos_id``, drop PAD, re-attach punctuation."""
    toks = []
    for i in ids:
        if i == eos_id:
            break
        if 0 <= i < len(itos):
            tok = itos[i]
            if tok == PAD:
                continue
            toks.append(tok)
    # The tokenizer splits punctuation into its own tokens; glue it back on.
    out = " ".join(toks)
    for p in (",", ".", "!", "?"):
        out = out.replace(f" {p}", p)
    return out


# Keyword rules evaluated in order; first match wins, so put the more specific
# intents (e.g. "refund") before the broader ones (e.g. "support").
RULES: list[tuple[str, list[str]]] = [
    ("refund", ["refund", "money back", "chargeback"]),
    ("return_process", ["return", "exchange"]),
    ("damaged", ["damaged", "broken", "cracked", "defect"]),
    ("shipping_time", ["shipping time", "delivery time", "how long", "arrive"]),
    ("express", ["express", "fast delivery", "1-2 day", "1–2 day"]),
    ("international", ["international", "other country", "abroad"]),
    ("tracking", ["tracking", "track", "track my order", "order tracking", "tracking link", "tracking number"]),
    ("payment_methods", ["payment methods", "how can i pay", "pay with", "payment option"]),
    ("payment_failed", ["payment failed", "cant pay", "can’t pay", "declined", "checkout payment error"]),
    ("discount", ["discount", "coupon", "promo code"]),
    ("account_create", ["create account", "sign up", "register"]),
    ("reset_password", ["forgot password", "reset password", "cant login", "can’t login", "cannot login", "login problem", "login issue"]),
    ("cancel_order", ["cancel", "cancellation"]),
    ("address_change", ["change address", "update address"]),
    ("delivered_not_received", ["delivered but", "says delivered", "not received"]),
    ("warranty", ["warranty", "guarantee"]),
    ("size", ["size chart", "size guide", "which size"]),
    ("support", ["contact", "support", "help"]),
]


def route_intent(question: str) -> str:
    """Return the first intent whose keyword list matches ``question`` (else "unknown")."""
    q = re.sub(r"\s+", " ", question.lower().strip())
    for intent, keys in RULES:
        if any(k in q for k in keys):
            return intent
    return "unknown"


class GPTConfig:
    """Hyperparameters for TinyGPT; field names must match the "cfg" dict in ckpt.pt."""

    def __init__(self, vocab_size: int, ctx_len: int, n_layers: int, n_heads: int,
                 d_model: int, ff_mult: int = 4, dropout: float = 0.0):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.ff_mult = ff_mult
        self.dropout = dropout


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=True)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=True)
        # Precomputed causal mask; persistent=False keeps it out of state_dict.
        mask = torch.tril(torch.ones(cfg.ctx_len, cfg.ctx_len)).view(1, 1, cfg.ctx_len, cfg.ctx_len)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(c, dim=2)
        # (b, t, heads, head_dim) -> (b, heads, t, head_dim)
        q = q.view(b, t, self.cfg.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.cfg.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.cfg.n_heads, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        hidden = cfg.d_model * cfg.ff_mult
        self.fc1 = nn.Linear(cfg.d_model, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, cfg.d_model, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))


class Block(nn.Module):
    """Pre-LayerNorm transformer block with residual connections."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.mlp = MLP(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT(nn.Module):
    """Minimal decoder-only transformer with learned positional embeddings."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.ctx_len, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: output projection shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, idx: torch.Tensor):
        b, t = idx.shape
        pos = torch.arange(0, t, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.lm_head(x)


@torch.no_grad()
def sample(model: TinyGPT, prompt_ids: list[int], eos_id: int, max_new: int,
           temperature: float, top_k: int, device: str) -> list[int]:
    """Autoregressively sample up to ``max_new`` tokens; stop early at ``eos_id``.

    Returns the full id sequence (prompt + generated tokens).
    """
    ids = torch.tensor(np.array(prompt_ids, dtype=np.int64), device=device).unsqueeze(0)
    for _ in range(max_new):
        ids_cond = ids[:, -model.cfg.ctx_len:]
        # Guard against temperature ~ 0 dividing by zero.
        logits = model(ids_cond)[:, -1, :] / max(1e-6, temperature)
        if top_k > 0:
            # Clamp so a UI-selected top_k larger than the vocab cannot crash topk.
            k = min(int(top_k), logits.size(-1))
            v, _ = torch.topk(logits, k)
            cutoff = v[:, -1].unsqueeze(1)
            logits = torch.where(logits < cutoff, torch.full_like(logits, float("-inf")), logits)
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
        if int(next_id.item()) == eos_id:
            break
    return ids[0].detach().cpu().numpy().astype(int).tolist()


def load_assets():
    """Download vocab + checkpoint from the Hub and build the model.

    Returns (model, device, stoi, itos, pad_id, eos_id, unk_id).
    Raises KeyError if the special tokens are missing from the vocabulary.
    """
    vocab_path = hf_hub_download(repo_id=MODEL_REPO, filename=VOCAB_FILENAME)
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=CKPT_FILENAME)
    stoi, itos = load_vocab(vocab_path)
    pad_id = stoi[PAD]
    eos_id = stoi[EOS]
    unk_id = stoi[UNK]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted repos.
    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = GPTConfig(**ckpt["cfg"])
    model = TinyGPT(cfg).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model, device, stoi, itos, pad_id, eos_id, unk_id


MODEL, DEVICE, STOI, ITOS, PAD_ID, EOS_ID, UNK_ID = load_assets()


def chat_fn(message: str, temperature: float, top_k: int, max_new: int) -> str:
    """Gradio callback: route intent, build the prompt, sample, and clean the reply."""
    intent = route_intent(message)
    prompt = f"<{intent}>\nUser: {message}\nBot:"
    prompt_ids = encode(STOI, UNK_ID, prompt)
    out_ids = sample(MODEL, prompt_ids, EOS_ID, max_new=max_new,
                     temperature=temperature, top_k=top_k, device=DEVICE)
    # Decode only the newly generated tail, not the prompt.
    gen = decode(ITOS, EOS_ID, out_ids[len(prompt_ids):])
    # Trim anything the model hallucinated past its own turn.
    if "User:" in gen:
        gen = gen.split("User:")[0]
    gen = gen.replace("Bot:", "").strip()
    return gen if gen else "I can help with store support topics like orders, shipping, refunds, payments, and account access."


demo = gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(0, 200, value=50, step=1, label="Top-k"),
        gr.Slider(20, 300, value=120, step=5, label="Max new tokens"),
    ],
    outputs=gr.Textbox(label="Bot reply"),
    title="Tiny FAQ Chatbot",
)

if __name__ == "__main__":
    demo.launch()