| """ |
| nanoGPT SLM Classifier -- Standalone Inference |
| ================================================ |
| 124M parameter spam classifier, fine-tuned from nanoGPT pretrained SLM. |
| Binary classification: "spam" vs "not spam" using last-token logits. |
| |
| Install: pip install torch tiktoken huggingface_hub |
| Run: python nanogpt_classifier_inference.py |
| Import: from nanogpt_classifier_inference import classify, classify_batch |
| """ |
|
|
| import torch, torch.nn as nn, torch.nn.functional as F, math, tiktoken |
| from dataclasses import dataclass |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
|
|
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias (torch's built-in forces bias on)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        # Normalize over the trailing `ndim` dimension(s); eps fixed at 1e-5.
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a single fused q/k/v projection."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One matmul produces q, k and v for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection back into the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Prefer PyTorch's fused SDPA kernel when available (torch >= 2.0).
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Fallback path: precomputed lower-triangular causal mask.
            mask = torch.tril(torch.ones(config.block_size, config.block_size))
            self.register_buffer("bias", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        batch, seq, dim = x.size()
        head_dim = dim // self.n_head

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # Reshape each to (batch, heads, seq, head_dim).
        q = q.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
        k = k.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
        v = v.view(batch, seq, self.n_head, head_dim).transpose(1, 2)

        if self.flash:
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                               dropout_p=drop_p, is_causal=True)
        else:
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
            scores = scores.masked_fill(self.bias[:, :, :seq, :seq] == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v

        # Merge heads back to (batch, seq, dim), then project + dropout.
        y = y.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.resid_dropout(self.c_proj(y))
|
|
class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
|
|
class Block(nn.Module):
    """One transformer block: pre-norm attention, then pre-norm MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        # Same sub-module creation order as elsewhere in the file (keeps init RNG stable).
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
|
|
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT backbone (GPT-2 124M layout, 256-token context)."""

    block_size: int = 256    # maximum sequence length (size of the positional table)
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / residual-stream width
    dropout: float = 0.0     # dropout probability (0.0 for inference)
    bias: bool = True        # include bias terms in Linear / LayerNorm
|
|
class GPT(nn.Module):
    """GPT-2-style transformer LM: token + learned positional embeddings,
    a stack of causal blocks, final LayerNorm, and a linear head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # learned positions
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),          # final pre-head norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: token embedding and LM head share one matrix (GPT-2 convention).
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        """Run the transformer.

        Args:
            idx: LongTensor of token ids, shape (batch, seq).
            targets: optional LongTensor of the same shape; positions with -1
                are ignored by the loss.

        Returns:
            (logits, loss). With targets: logits over every position and the
            cross-entropy loss. Without: logits for the LAST position only,
            shape (batch, 1, out_features), and loss=None.

        Raises:
            ValueError: if the sequence is longer than the model's block size
                (previously this surfaced as a cryptic embedding index error).
        """
        b, t = idx.size()
        if t > self.config.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {t}, "
                f"block size is only {self.config.block_size}"
            )
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # Training path: logits at every position, ignore_index=-1 for padding.
            logits = self.lm_head(x)
            return logits, F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        else:
            # Inference path: only the last position is needed (list index keeps the dim).
            logits = self.lm_head(x[:, [-1], :])
            return logits, None
|
|
|
|
| |
| |
| |
|
|
# Classification-head and preprocessing constants.
NUM_CLASSES = 2  # binary head: index 0 = not spam, index 1 = spam
MAX_LENGTH = 120  # fixed pad/truncate length used by classify()
PAD_TOKEN = 50256  # GPT-2 end-of-text token id, reused here as right-padding
LABELS = {0: "not spam", 1: "spam"}  # class index -> human-readable label
|
|
| |
| |
| |
|
|
def classify(text, max_length=MAX_LENGTH):
    """
    Classify a single text as 'spam' or 'not spam'.

    Args:
        text: Input text string
        max_length: Pad/truncate to this length (default: 120). Values larger
            than the model's context window are clamped to the window.

    Returns:
        dict with 'label', 'confidence', and 'probabilities'
    """
    model.eval()
    input_ids = tokenizer.encode(text)
    supported_context_length = model.transformer.wpe.weight.shape[0]

    # BUG FIX: truncate AND pad to the same clamped length. The old code
    # truncated to min(max_length, context) but padded to max_length, so a
    # max_length above the context window produced a sequence longer than the
    # positional-embedding table and crashed. Default behavior is unchanged.
    target_length = min(max_length, supported_context_length)
    input_ids = input_ids[:target_length]
    input_ids += [PAD_TOKEN] * (target_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    with torch.no_grad():
        logits, _ = model(input_tensor)  # model returns last-position logits only
        logits = logits[:, -1, :]        # -> (1, NUM_CLASSES)
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        predicted = torch.argmax(probs).item()

    return {
        "label": LABELS[predicted],
        "confidence": probs[predicted].item(),
        "probabilities": {LABELS[i]: probs[i].item() for i in range(NUM_CLASSES)},
    }
|
|
|
|
def classify_batch(texts, max_length=MAX_LENGTH):
    """Run classify() over an iterable of texts; returns one result dict per input, in order."""
    return [classify(item, max_length) for item in texts]
|
|
|
|
def is_spam(text, max_length=MAX_LENGTH):
    """Boolean convenience wrapper around classify(): True for spam, False otherwise."""
    result = classify(text, max_length)
    return result["label"] == "spam"
|
|
|
|
| |
| |
| |
|
|
# ---- Model + tokenizer initialization (runs once at import time) ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPTConfig()
tokenizer = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer

# Download the fine-tuned classifier weights from the Hugging Face Hub
# (cached locally after the first run).
weights_path = hf_hub_download(repo_id="nishantup/nanogpt-slm-classifier",
                               filename="nanogpt_classifier.pth")

model = GPT(config)

# Swap the vocab-sized LM head for a 2-way classification head BEFORE loading
# the state dict, so the checkpoint's head shape matches.
# NOTE(review): this breaks the wte/lm_head weight tying set up in GPT.__init__;
# that appears intentional here since the classifier head was trained separately.
model.lm_head = nn.Linear(in_features=config.n_embd, out_features=NUM_CLASSES)

# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# consider weights_only=True if the checkpoint is a plain state_dict -- TODO confirm.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()

# Report what was loaded (parameter count, architecture, classification setup).
total_params = sum(p.numel() for p in model.parameters())
print(f"nanoGPT Spam Classifier loaded: {total_params:,} params on {device}")
print(f"Config: {config.n_layer}L / {config.n_head}H / {config.n_embd}D / ctx={config.block_size}")
print(f"Classification: {NUM_CLASSES} classes ({', '.join(LABELS.values())})")
print(f"Max sequence length: {MAX_LENGTH} tokens\n")
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Demo / smoke test: classify known spam and ham examples and report accuracy.

    # Messages that should be flagged as spam.
    spam_texts = [
        "You are a winner you have been specially selected to receive $1000 cash or a $2000 award.",
        "URGENT! You have won a free ticket to the Bahamas. Call now!",
        "Congratulations! You've been selected for a $500 Walmart gift card. Click here to claim.",
        "FREE entry to our prize draw! Text WIN to 80085 now!",
    ]

    # Legitimate ("ham") messages that should NOT be flagged.
    ham_texts = [
        "Hey, just wanted to check if we're still on for dinner tonight? Let me know!",
        "Can you pick up some milk on your way home? Thanks!",
        "The meeting has been moved to 3pm tomorrow. See you there.",
        "Happy birthday! Hope you have a wonderful day!",
    ]

    print("=" * 60)
    print("SPAM DETECTION RESULTS")
    print("=" * 60)

    # Per-message predictions with confidence; long texts truncated for display.
    print("\n-- Known SPAM messages --")
    for text in spam_texts:
        result = classify(text)
        conf = result['confidence'] * 100
        print(f"\n  Text: {text[:80]}...")
        print(f"  Prediction: {result['label'].upper()} ({conf:.1f}% confidence)")

    print(f"\n-- Known HAM (not spam) messages --")
    for text in ham_texts:
        result = classify(text)
        conf = result['confidence'] * 100
        print(f"\n  Text: {text[:80]}...")
        print(f"  Prediction: {result['label'].upper()} ({conf:.1f}% confidence)")

    # Accuracy over the eight demo messages (each is classified a second time here).
    print(f"\n{'=' * 60}")
    print("ACCURACY SUMMARY")
    print("=" * 60)
    spam_correct = sum(1 for t in spam_texts if is_spam(t))
    ham_correct = sum(1 for t in ham_texts if not is_spam(t))
    total = len(spam_texts) + len(ham_texts)
    correct = spam_correct + ham_correct
    print(f"  Spam detected: {spam_correct}/{len(spam_texts)}")
    print(f"  Ham detected: {ham_correct}/{len(ham_texts)}")
    print(f"  Overall accuracy: {correct}/{total} ({correct/total*100:.0f}%)")

    # Minimal example of the boolean convenience API.
    print(f"\n{'=' * 60}")
    print("BOOLEAN API: is_spam()")
    print("=" * 60)
    test = "Click here to claim your free iPhone!"
    print(f"  is_spam(\"{test}\")")
    print(f"  -> {is_spam(test)}")
|