Spaces:
Sleeping
Sleeping
| import json | |
| import math | |
| import re | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository and the two artifact names fetched from it.
MODEL_REPO = "aagzamov/tiny-chatbot-model"
CKPT_FILENAME = "ckpt.pt"
VOCAB_FILENAME = "vocab.json"

# Tokenizer pattern; alternatives are tried left to right at each position:
#   1. an angle-bracket tag such as "<eos>",
#   2. a run of ASCII letters and digits,
#   3. any other single non-whitespace character.
TOKEN_RE = re.compile(r"<[^>]+>|[A-Za-z0-9]+|[^\sA-Za-z0-9]", re.UNICODE)

# Special tokens that are looked up in the vocabulary.
PAD = "<pad>"
EOS = "<eos>"
UNK = "<unk>"
def tokenize(text: str) -> list[str]:
    """Split *text* into tokens with ``TOKEN_RE``.

    Surrounding whitespace is stripped first; whitespace between tokens is
    never returned because no alternative of the pattern matches it.
    """
    trimmed = text.strip()
    tokens = TOKEN_RE.findall(trimmed)
    return tokens
def load_vocab(local_path: str) -> tuple[dict[str, int], list[str]]:
    """Read a vocabulary JSON file and build both lookup directions.

    The file must hold a JSON object whose ``"itos"`` entry lists the tokens
    in id order.

    Returns:
        ``(stoi, itos)`` where ``itos`` is that list and ``stoi`` maps each
        token to its index (the last index wins if a token is repeated).

    Raises:
        KeyError: if the JSON object has no ``"itos"`` entry.
    """
    with open(local_path, encoding="utf-8") as handle:
        payload = json.load(handle)

    itos = payload["itos"]
    stoi = {}
    for index, token in enumerate(itos):
        stoi[token] = index
    return stoi, itos
def encode(stoi: dict[str, int], unk_id: int, text: str) -> list[int]:
    """Tokenize *text* and map every token to its vocabulary id.

    Tokens that are missing from *stoi* are replaced by *unk_id*.
    """
    lookup = stoi.get
    return [lookup(token, unk_id) for token in tokenize(text)]
def decode(itos: list[str], eos_id: int, ids: list[int]) -> str:
    """Turn token ids back into text.

    Decoding stops at the first *eos_id*. Ids outside ``itos`` and the
    ``PAD`` token are dropped. The remaining tokens are joined with single
    spaces, after which the space in front of ``,`` ``.`` ``!`` ``?`` is
    removed.
    """
    vocab_size = len(itos)
    pieces: list[str] = []
    for token_id in ids:
        if token_id == eos_id:
            break
        if token_id < 0 or token_id >= vocab_size:
            # Unknown id: nothing sensible to emit, skip it.
            continue
        piece = itos[token_id]
        if piece != PAD:
            pieces.append(piece)

    text = " ".join(pieces)
    # Re-attach punctuation that tokenization separated from its word.
    for mark in (",", ".", "!", "?"):
        text = text.replace(" " + mark, mark)
    return text
# Keyword routing table: (intent name, substrings that trigger it).
# route_intent() walks this list top to bottom and returns the first intent
# with a matching substring, so the order of the entries matters.
RULES: list[tuple[str, list[str]]] = [
    ("refund", ["refund", "money back", "chargeback"]),
    ("return_process", ["return", "exchange"]),
    ("damaged", ["damaged", "broken", "cracked", "defect"]),
    ("shipping_time", ["shipping time", "delivery time", "how long", "arrive"]),
    ("express", ["express", "fast delivery", "1-2 day", "1–2 day"]),
    ("international", ["international", "other country", "abroad"]),
    (
        "tracking",
        [
            "tracking",
            "track",
            "track my order",
            "order tracking",
            "tracking link",
            "tracking number",
        ],
    ),
    (
        "payment_methods",
        ["payment methods", "how can i pay", "pay with", "payment option"],
    ),
    (
        "payment_failed",
        ["payment failed", "cant pay", "can’t pay", "declined", "checkout payment error"],
    ),
    ("discount", ["discount", "coupon", "promo code"]),
    ("account_create", ["create account", "sign up", "register"]),
    (
        "reset_password",
        [
            "forgot password",
            "reset password",
            "cant login",
            "can’t login",
            "cannot login",
            "login problem",
            "login issue",
        ],
    ),
    ("cancel_order", ["cancel", "cancellation"]),
    ("address_change", ["change address", "update address"]),
    ("delivered_not_received", ["delivered but", "says delivered", "not received"]),
    ("warranty", ["warranty", "guarantee"]),
    ("size", ["size chart", "size guide", "which size"]),
    ("support", ["contact", "support", "help"]),
]
def route_intent(question: str) -> str:
    """Pick an intent label for *question* by keyword lookup in ``RULES``.

    The question is lower-cased, stripped, and has every whitespace run
    collapsed to one space. The first rule (in ``RULES`` order) with a
    keyword that occurs as a substring wins; ``"unknown"`` is returned when
    no rule matches.
    """
    normalized = re.sub(r"\s+", " ", question.lower().strip())
    hits = (
        intent
        for intent, keywords in RULES
        if any(keyword in normalized for keyword in keywords)
    )
    return next(hits, "unknown")
class GPTConfig:
    """Plain holder for the model hyperparameters.

    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(
        self,
        vocab_size: int,
        ctx_len: int,
        n_layers: int,
        n_heads: int,
        d_model: int,
        ff_mult: int = 4,
        dropout: float = 0.0,
    ):
        # Embedding table sizes.
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        # Transformer stack geometry.
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.ff_mult = ff_mult
        # Stored only; none of the modules in this file reads it.
        self.dropout = dropout
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees itself and
    earlier positions.

    ``cfg.d_model`` must be divisible by ``cfg.n_heads``.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.n_heads

        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=True)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=True)

        # Lower-triangular ones: entry (i, j) is 1 when j <= i.
        # Non-persistent, so it is not part of the state dict.
        size = cfg.ctx_len
        causal = torch.ones(size, size).tril().view(1, 1, size, size)
        self.register_buffer("mask", causal, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention to *x* of shape ``(batch, time, d_model)`` and
        return a tensor of the same shape."""
        batch, steps, width = x.shape
        n_heads = self.cfg.n_heads
        head_dim = self.head_dim

        def to_heads(part: torch.Tensor) -> torch.Tensor:
            # (batch, time, d_model) -> (batch, heads, time, head_dim)
            return part.view(batch, steps, n_heads, head_dim).transpose(1, 2)

        queries, keys, values = (
            to_heads(part) for part in self.qkv(x).split(width, dim=2)
        )

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim)
        # Future positions get -inf so softmax assigns them zero weight.
        hidden_future = self.mask[:, :, :steps, :steps] == 0
        weights = F.softmax(scores.masked_fill(hidden_future, float("-inf")), dim=-1)

        mixed = torch.matmul(weights, values)
        merged = mixed.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.proj(merged)
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    The hidden width is ``cfg.d_model * cfg.ff_mult``; input and output
    widths are both ``cfg.d_model``.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width = cfg.d_model
        inner = width * cfg.ff_mult
        self.fc1 = nn.Linear(width, inner, bias=True)
        self.fc2 = nn.Linear(inner, width, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, apply GELU, and project back to the input width."""
        expanded = self.fc1(x)
        activated = F.gelu(expanded)
        return self.fc2(activated)
class Block(nn.Module):
    """One transformer layer: attention then MLP, each applied to a
    layer-normalized input and added back to the residual stream."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width = cfg.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = MLP(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return *x* after the attention and MLP residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
class TinyGPT(nn.Module):
    """Decoder-only language model: token + position embeddings, a stack of
    ``Block`` layers, a final LayerNorm, and an output projection that
    shares its weight matrix with the token embedding."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.d_model

        self.tok_emb = nn.Embedding(cfg.vocab_size, width)
        self.pos_emb = nn.Embedding(cfg.ctx_len, width)

        layers = [Block(cfg) for _ in range(cfg.n_layers)]
        self.blocks = nn.ModuleList(layers)

        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, cfg.vocab_size, bias=False)
        # Weight tying: the output head reuses the token embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, idx: torch.Tensor):
        """Map token ids of shape ``(batch, time)`` to logits of shape
        ``(batch, time, vocab_size)``.

        ``time`` must not exceed ``cfg.ctx_len``, the size of the position
        embedding table.
        """
        _, steps = idx.shape
        positions = torch.arange(steps, device=idx.device)[None, :]
        hidden = self.tok_emb(idx) + self.pos_emb(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))
def sample(model: TinyGPT, prompt_ids: list[int], eos_id: int, max_new: int, temperature: float, top_k: int, device: str) -> list[int]:
    """Autoregressively sample a continuation of *prompt_ids*.

    Args:
        model: Language model; called on the most recent ``model.cfg.ctx_len``
            ids and expected to return logits of shape (1, time, vocab).
        prompt_ids: Non-empty list of token ids to condition on.
        eos_id: Sampling stops right after this id has been produced.
        max_new: Upper bound on the number of generated ids.
        temperature: Logits are divided by this value (floored at 1e-6).
        top_k: When > 0, only the ``top_k`` highest logits stay eligible.
            Values larger than the vocabulary are clamped to its size.
        device: Torch device on which the id tensor is created.

    Returns:
        The prompt ids followed by the generated ids (including the final
        ``eos_id`` if one was sampled).

    Raises:
        ValueError: if *prompt_ids* is empty.
    """
    if len(prompt_ids) == 0:
        # The model cannot produce a "last position" for an empty sequence.
        raise ValueError("prompt_ids must contain at least one token")

    # UI sliders may hand over floats; range() and topk() need real ints.
    max_new = int(max_new)
    top_k = int(top_k)
    temperature = max(1e-6, float(temperature))

    ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Pure inference: do not record an autograd graph for every step.
    with torch.no_grad():
        for _ in range(max_new):
            # Keep only what fits into the model's context window.
            ids_cond = ids[:, -model.cfg.ctx_len:]
            logits = model(ids_cond)[:, -1, :] / temperature

            if top_k > 0:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                cutoff = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < cutoff, float("-inf"))

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, next_id], dim=1)
            if next_id.item() == eos_id:
                break

    return ids[0].tolist()
def load_assets():
    """Download the vocabulary and checkpoint from the Hub and build the model.

    Returns:
        ``(model, device, stoi, itos, pad_id, eos_id, unk_id)``. The model
        has its weights loaded, lives on ``device`` ("cuda" when available,
        otherwise "cpu") and is in eval mode.

    Raises:
        KeyError: if the vocabulary lacks one of the PAD/EOS/UNK tokens or
            the checkpoint lacks its ``"cfg"`` / ``"state_dict"`` entries.
    """
    vocab_path, ckpt_path = (
        hf_hub_download(repo_id=MODEL_REPO, filename=name)
        for name in (VOCAB_FILENAME, CKPT_FILENAME)
    )

    stoi, itos = load_vocab(vocab_path)
    pad_id, eos_id, unk_id = (stoi[token] for token in (PAD, EOS, UNK))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): torch.load unpickles the downloaded file, and unpickling
    # can execute arbitrary code. Only point MODEL_REPO at a trusted repo;
    # consider weights_only=True if the installed torch supports it.
    checkpoint = torch.load(ckpt_path, map_location=device)

    model = TinyGPT(GPTConfig(**checkpoint["cfg"]))
    model = model.to(device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    return model, device, stoi, itos, pad_id, eos_id, unk_id
# Download and build everything once at import time; chat_fn reads these
# module-level values on every request.
(
    MODEL,
    DEVICE,
    STOI,
    ITOS,
    PAD_ID,
    EOS_ID,
    UNK_ID,
) = load_assets()
def chat_fn(message: str, temperature: float, top_k: int, max_new: int) -> str:
    """Produce the bot reply for one user message.

    The message is routed to an intent tag, wrapped into the
    ``<intent>\\nUser: ...\\nBot:`` prompt, continued by the model, and the
    generated part is decoded and cleaned up.

    Args:
        message: Text typed by the user.
        temperature: Sampling temperature passed on to ``sample``.
        top_k: Top-k cutoff passed on to ``sample`` (0 disables it).
        max_new: Maximum number of tokens to generate.

    Returns:
        The cleaned reply, or a fixed fallback sentence when the message is
        blank or the model produced nothing usable.
    """
    fallback = "I can help with store support topics like orders, shipping, refunds, payments, and account access."

    # Nothing to answer: skip the model entirely.
    if not message or not message.strip():
        return fallback

    intent = route_intent(message)
    prompt = f"<{intent}>\nUser: {message}\nBot:"
    prompt_ids = encode(STOI, UNK_ID, prompt)
    out_ids = sample(
        MODEL,
        prompt_ids,
        EOS_ID,
        # Slider values may arrive as floats; sample() needs integers here.
        max_new=int(max_new),
        temperature=float(temperature),
        top_k=int(top_k),
        device=DEVICE,
    )
    # sample() returns prompt + continuation; keep only the continuation.
    gen = decode(ITOS, EOS_ID, out_ids[len(prompt_ids):])

    # The tokenizer splits "User:" into "User" and ":" and decode() joins
    # tokens with spaces, so a generated turn marker shows up as "User :".
    # Allow optional whitespace so the cut-off works for both spellings.
    gen = re.split(r"User\s*:", gen, maxsplit=1)[0]
    gen = re.sub(r"Bot\s*:", "", gen).strip()

    return gen or fallback
# Gradio UI: one text input plus three sampling controls, wired to chat_fn
# in the same order as its parameters.
demo = gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0, maximum=200, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=20, maximum=300, value=120, step=5, label="Max new tokens"),
    ],
    outputs=gr.Textbox(label="Bot reply"),
    title="Tiny FAQ Chatbot",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()