# nanogpt-slm-classifier / nanogpt_classifier_inference.py
# Uploaded by nishantup via huggingface_hub (commit 374c69d, verified)
"""
nanoGPT SLM Classifier -- Standalone Inference
================================================
124M parameter spam classifier, fine-tuned from nanoGPT pretrained SLM.
Binary classification: "spam" vs "not spam" using last-token logits.
Install: pip install torch tiktoken huggingface_hub
Run: python nanogpt_classifier_inference.py
Import: from nanogpt_classifier_inference import classify, classify_batch
"""
import torch, torch.nn as nn, torch.nn.functional as F, math, tiktoken
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
# ==============================================================
# ARCHITECTURE (nanoGPT -- modified for 2-class classification)
# ==============================================================
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term (nanoGPT variant).

    torch.nn.LayerNorm cannot disable its bias, so the affine parameters
    are managed manually and passed to F.layer_norm directly.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        # F.layer_norm accepts bias=None; eps matches nn.LayerNorm's default.
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention (nanoGPT style, fused QKV projection)."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Single fused projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head, self.n_embd = config.n_head, config.n_embd
        # Prefer PyTorch's fused attention kernel when available (torch >= 2.0).
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Fallback path: precompute a lower-triangular causal mask.
            # NOTE: this buffer is named "bias" to match GPT-2 checkpoint keys.
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape each of q/k/v to (B, n_head, T, head_dim) for per-head attention.
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            # Fused kernel; dropout only applied in training mode.
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
        else:
            # Manual scaled dot-product attention with causal masking.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1); att = self.attn_dropout(att); y = att @ v
        # Re-assemble all heads side by side: (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear -> dropout.

    Expands the embedding dimension by 4x internally, as in GPT-2.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
class Block(nn.Module):
    """One transformer layer: pre-norm causal self-attention followed by a
    pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
@dataclass
class GPTConfig:
    """Model hyperparameters (GPT-2 124M defaults with a 256-token context)."""
    block_size: int = 256       # maximum sequence length (context window)
    vocab_size: int = 50257     # GPT-2 BPE vocabulary size
    n_layer: int = 12           # number of transformer blocks
    n_head: int = 12            # attention heads per block
    n_embd: int = 768           # embedding dimension
    dropout: float = 0.0        # dropout probability (0 = inference-friendly)
    bias: bool = True           # use bias in Linear and LayerNorm layers
class GPT(nn.Module):
    """nanoGPT transformer. With `targets` the forward pass returns logits for
    every position plus a cross-entropy loss; without, only the last position's
    logits (inference mode, as used by the classifier below)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # positional embeddings
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),          # final layer norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # weight tying

    def forward(self, idx, targets=None):
        # idx: (b, t) integer token ids; t must not exceed config.block_size,
        # otherwise the positional-embedding lookup below will fail.
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        # Sum token and positional embeddings, then apply embedding dropout.
        x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # Training path: logits at every position; targets of -1 are ignored
            # (e.g. padding positions).
            logits = self.lm_head(x)
            return logits, F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference path: project only the last time step; the [-1] list
            # index keeps the time dimension, so logits is (b, 1, vocab/classes).
            logits = self.lm_head(x[:, [-1], :])
            return logits, None
# ==============================================================
# CLASSIFICATION CONFIG
# ==============================================================
NUM_CLASSES = 2  # binary classifier: spam / not spam
MAX_LENGTH = 120  # Max token length used during training (longest sequence)
PAD_TOKEN = 50256  # <|endoftext|> -- GPT-2's end-of-text token, reused as padding
LABELS = {0: "not spam", 1: "spam"}  # class index -> human-readable label
# ==============================================================
# CLASSIFICATION FUNCTIONS
# ==============================================================
def classify(text, max_length=MAX_LENGTH):
    """
    Classify a single text as 'spam' or 'not spam'.

    Args:
        text: Input text string
        max_length: Pad/truncate to this length (default: 120). Values larger
            than the model's context window are capped so the positional
            embedding lookup cannot be exceeded.

    Returns:
        dict with 'label', 'confidence', and 'probabilities'
    """
    model.eval()
    input_ids = tokenizer.encode(text)
    supported_context_length = model.transformer.wpe.weight.shape[0]
    # Cap the effective sequence length at the model's context window.
    # (Bug fix: the original truncated to the context length but then padded
    # back up to max_length, so max_length > block_size produced a sequence
    # longer than the positional-embedding table and crashed the forward pass.)
    target_length = min(max_length, supported_context_length)
    # Truncate
    input_ids = input_ids[:target_length]
    # Pad
    input_ids += [PAD_TOKEN] * (target_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(input_tensor)
    logits = logits[:, -1, :]  # Last token logits: (1, num_classes)
    probs = torch.softmax(logits, dim=-1).squeeze(0)
    predicted = torch.argmax(probs).item()
    return {
        "label": LABELS[predicted],
        "confidence": probs[predicted].item(),
        "probabilities": {LABELS[i]: probs[i].item() for i in range(NUM_CLASSES)},
    }
def classify_batch(texts, max_length=MAX_LENGTH):
    """Run `classify` over an iterable of texts; return the result dicts as a list."""
    return [classify(entry, max_length) for entry in texts]
def is_spam(text, max_length=MAX_LENGTH):
    """Simple boolean check: returns True if spam, False if not."""
    result = classify(text, max_length)
    return result["label"] == "spam"
# ==============================================================
# LOAD MODEL (auto-downloads from HuggingFace Hub)
# ==============================================================
# Pick the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPTConfig()
tokenizer = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer
# Download (and locally cache) the fine-tuned checkpoint from the Hub.
weights_path = hf_hub_download(repo_id="nishantup/nanogpt-slm-classifier",
                               filename="nanogpt_classifier.pth")
# 1. Build base GPT model
model = GPT(config)
# 2. Replace lm_head with 2-class classification head
#    (must happen BEFORE loading state_dict since saved weights have shape (2, 768))
model.lm_head = nn.Linear(in_features=config.n_embd, out_features=NUM_CLASSES)
# 3. Load fine-tuned classifier weights
#    NOTE(review): torch.load unpickles the downloaded file; consider passing
#    weights_only=True (torch >= 1.13) to avoid arbitrary-code deserialization.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout
total_params = sum(p.numel() for p in model.parameters())
print(f"nanoGPT Spam Classifier loaded: {total_params:,} params on {device}")
print(f"Config: {config.n_layer}L / {config.n_head}H / {config.n_embd}D / ctx={config.block_size}")
print(f"Classification: {NUM_CLASSES} classes ({', '.join(LABELS.values())})")
print(f"Max sequence length: {MAX_LENGTH} tokens\n")
# ==============================================================
# EXAMPLES (only run when executed directly)
# ==============================================================
if __name__ == "__main__":
    # Messages known to be spam, for the demo run.
    spam_texts = [
        "You are a winner you have been specially selected to receive $1000 cash or a $2000 award.",
        "URGENT! You have won a free ticket to the Bahamas. Call now!",
        "Congratulations! You've been selected for a $500 Walmart gift card. Click here to claim.",
        "FREE entry to our prize draw! Text WIN to 80085 now!",
    ]
    # Messages known to be legitimate (ham), for the demo run.
    ham_texts = [
        "Hey, just wanted to check if we're still on for dinner tonight? Let me know!",
        "Can you pick up some milk on your way home? Thanks!",
        "The meeting has been moved to 3pm tomorrow. See you there.",
        "Happy birthday! Hope you have a wonderful day!",
    ]

    def _report(texts):
        """Classify each message and print its prediction and confidence."""
        for message in texts:
            outcome = classify(message)
            pct = outcome['confidence'] * 100
            print(f"\n Text: {message[:80]}...")
            print(f" Prediction: {outcome['label'].upper()} ({pct:.1f}% confidence)")

    rule = "=" * 60
    print(rule)
    print("SPAM DETECTION RESULTS")
    print(rule)
    print("\n-- Known SPAM messages --")
    _report(spam_texts)
    print("\n-- Known HAM (not spam) messages --")
    _report(ham_texts)

    # Accuracy summary (re-classifies the same messages via is_spam).
    print(f"\n{rule}")
    print("ACCURACY SUMMARY")
    print(rule)
    spam_correct = sum(is_spam(message) for message in spam_texts)
    ham_correct = sum(not is_spam(message) for message in ham_texts)
    total = len(spam_texts) + len(ham_texts)
    correct = spam_correct + ham_correct
    print(f" Spam detected: {spam_correct}/{len(spam_texts)}")
    print(f" Ham detected: {ham_correct}/{len(ham_texts)}")
    print(f" Overall accuracy: {correct}/{total} ({correct/total*100:.0f}%)")

    # Boolean API demo
    print(f"\n{rule}")
    print("BOOLEAN API: is_spam()")
    print(rule)
    test = "Click here to claim your free iPhone!"
    print(f" is_spam(\"{test}\")")
    print(f" -> {is_spam(test)}")