nishantup committed on
Commit
2e92a84
·
verified ·
1 Parent(s): 94df9e8

Upload nanogpt_slm_tinystories_classifier_inference.py with huggingface_hub

Browse files
nanogpt_slm_tinystories_classifier_inference.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ nanoGPT SLM Classifier -- Standalone Inference
3
+ ================================================
4
+ 124M parameter spam classifier.
5
+ Pretrained on TinyStories (2.1M stories) -> Classification fine-tuned on 60K spam dataset.
6
+ Binary classification: "spam" vs "not spam" using last-token logits.
7
+
8
+ Install: pip install torch tiktoken huggingface_hub
9
+ Run: python nanogpt_slm_tinystories_classifier_inference.py
10
+ Import: from nanogpt_slm_tinystories_classifier_inference import classify, classify_batch
11
+ """
12
+
13
+ import torch, torch.nn as nn, torch.nn.functional as F, math, tiktoken
14
+ from dataclasses import dataclass
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ # ==============================================================
18
+ # ARCHITECTURE (nanoGPT -- modified for 2-class classification)
19
+ # ==============================================================
20
+
21
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    Thin wrapper over F.layer_norm so the bias can be switched off via the
    ``bias`` flag while the scale (``weight``) is always learned.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        # eps of 1e-5 matches nn.LayerNorm's default.
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Uses the fused F.scaled_dot_product_attention when the installed torch
    provides it; otherwise falls back to an explicit masked-softmax
    implementation with a cached lower-triangular mask buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Single projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Causal mask, registered as a buffer so it follows .to(device).
            mask = torch.tril(torch.ones(config.block_size, config.block_size))
            self.register_buffer(
                "bias", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        head_dim = C // self.n_head
        # Project, then reshape each of q/k/v to (B, n_head, T, head_dim).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        if self.flash:
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
        else:
            # Manual path: scaled scores, causal mask, softmax, weighted sum.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # Re-assemble all heads side by side: (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
57
+
58
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Expands the embedding to 4x its width and projects back, as in GPT-2.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
67
+
68
class Block(nn.Module):
    """One transformer layer: pre-norm attention, then pre-norm MLP,
    each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
76
+
77
@dataclass
class GPTConfig:
    """Model hyperparameters (GPT-2 small shape: 12 layers, 12 heads, 768-dim)."""
    block_size: int = 512    # maximum sequence length (positional-embedding rows)
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True        # bias terms in Linear / LayerNorm layers
82
+
83
class GPT(nn.Module):
    """nanoGPT backbone: token + position embeddings, a stack of transformer
    blocks, a final LayerNorm, and a weight-tied output head.

    forward(idx, targets=None):
        idx     -- (B, T) token ids, T <= config.block_size.
        targets -- optional (B, T) target ids. When given, returns
                   (full-sequence logits, cross-entropy loss) with -1 targets
                   ignored. When omitted, returns logits for the last position
                   only, shape (B, 1, out_features), and loss None.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embedding and output head share one parameter.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        x = self.transformer.drop(tok_emb + self.transformer.wpe(pos))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is None:
            # Inference mini-optimization: head applied to the last position only.
            logits = self.lm_head(x[:, [-1], :])
            return logits, None
        logits = self.lm_head(x)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss
110
+
111
+
112
+ # ==============================================================
113
+ # CLASSIFICATION CONFIG
114
+ # ==============================================================
115
+
116
NUM_CLASSES = 2    # binary task: "not spam" vs "spam"
MAX_LENGTH = 512   # Max token length used during training (longest sequence)
PAD_TOKEN = 50256  # <|endoftext|> -- GPT-2's end-of-text token, reused as padding
LABELS = {0: "not spam", 1: "spam"}  # class index -> human-readable label
120
+
121
+ # ==============================================================
122
+ # CLASSIFICATION FUNCTIONS
123
+ # ==============================================================
124
+
125
def classify(text, max_length=MAX_LENGTH):
    """
    Classify a single text as 'spam' or 'not spam'.

    Args:
        text: Input text string
        max_length: Pad/truncate to this length (default: MAX_LENGTH, 512);
            capped at the model's positional-embedding table size.

    Returns:
        dict with 'label', 'confidence', and 'probabilities'
    """
    model.eval()
    input_ids = tokenizer.encode(text)

    # Never exceed the model's positional-embedding table (block_size rows).
    supported_context_length = model.transformer.wpe.weight.shape[0]
    effective_length = min(max_length, supported_context_length)

    # Truncate, then pad with <|endoftext|> up to the effective length.
    # Padding to effective_length (not max_length) guards against a caller
    # passing max_length > model context, which would crash the wpe lookup.
    input_ids = input_ids[:effective_length]
    input_ids += [PAD_TOKEN] * (effective_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    with torch.no_grad():
        logits, _ = model(input_tensor)
        logits = logits[:, -1, :]  # Last token logits: (1, num_classes)
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        predicted = torch.argmax(probs).item()

    return {
        "label": LABELS[predicted],
        "confidence": probs[predicted].item(),
        "probabilities": {LABELS[i]: probs[i].item() for i in range(NUM_CLASSES)},
    }
158
+
159
+
160
def classify_batch(texts, max_length=MAX_LENGTH):
    """Classify several texts at once.

    Args:
        texts: iterable of input strings.
        max_length: forwarded to classify().

    Returns:
        A list with one classify() result dict per input, in order.
    """
    return [classify(item, max_length) for item in texts]
163
+
164
+
165
def is_spam(text, max_length=MAX_LENGTH):
    """Boolean convenience wrapper: True iff classify() labels *text* 'spam'."""
    result = classify(text, max_length)
    return result["label"] == "spam"
168
+
169
+
170
+ # ==============================================================
171
+ # LOAD MODEL (auto-downloads from HuggingFace Hub)
172
+ # ==============================================================
173
+
174
# Runs at import time so `from ... import classify` is immediately usable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPTConfig()
tokenizer = tiktoken.get_encoding("gpt2")

# Fetch (and locally cache) the fine-tuned checkpoint from the Hub.
weights_path = hf_hub_download(repo_id="nishantup/nanogpt-slm-tinystories-classifier",
                               filename="nanogpt_slm_tinystories_classifier.pth")

# 1. Build base GPT model
model = GPT(config)

# 2. Replace lm_head with 2-class classification head
# (must happen BEFORE loading state_dict since saved weights have shape (2, 768))
model.lm_head = nn.Linear(in_features=config.n_embd, out_features=NUM_CLASSES)

# 3. Load fine-tuned classifier weights
# NOTE(review): torch.load unpickles a downloaded file; consider passing
# weights_only=True (torch >= 1.13) to restrict deserialization to tensors.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()

# Startup banner: parameter count and key config so users can sanity-check.
total_params = sum(p.numel() for p in model.parameters())
print(f"nanoGPT Spam Classifier loaded: {total_params:,} params on {device}")
print(f"Config: {config.n_layer}L / {config.n_head}H / {config.n_embd}D / ctx={config.block_size}")
print(f"Classification: {NUM_CLASSES} classes ({', '.join(LABELS.values())})")
print(f"Max sequence length: {MAX_LENGTH} tokens\n")
198
+
199
+ # ==============================================================
200
+ # EXAMPLES (only run when executed directly)
201
+ # ==============================================================
202
+
203
if __name__ == "__main__":

    # Demo inputs with known ground-truth labels.
    spam_texts = [
        "You are a winner you have been specially selected to receive $1000 cash or a $2000 award.",
        "URGENT! You have won a free ticket to the Bahamas. Call now!",
        "Congratulations! You've been selected for a $500 Walmart gift card. Click here to claim.",
        "FREE entry to our prize draw! Text WIN to 80085 now!",
    ]

    ham_texts = [
        "Hey, just wanted to check if we're still on for dinner tonight? Let me know!",
        "Can you pick up some milk on your way home? Thanks!",
        "The meeting has been moved to 3pm tomorrow. See you there.",
        "Happy birthday! Hope you have a wonderful day!",
    ]

    def _report(samples):
        # Print prediction + confidence for each sample in *samples*.
        for sample in samples:
            outcome = classify(sample)
            pct = outcome['confidence'] * 100
            print(f"\n Text: {sample[:80]}...")
            print(f" Prediction: {outcome['label'].upper()} ({pct:.1f}% confidence)")

    print("=" * 60)
    print("SPAM DETECTION RESULTS")
    print("=" * 60)

    print("\n-- Known SPAM messages --")
    _report(spam_texts)

    print(f"\n-- Known HAM (not spam) messages --")
    _report(ham_texts)

    # Accuracy summary over the known-label demo sets.
    print(f"\n{'=' * 60}")
    print("ACCURACY SUMMARY")
    print("=" * 60)
    spam_correct = sum(1 for t in spam_texts if is_spam(t))
    ham_correct = sum(1 for t in ham_texts if not is_spam(t))
    total = len(spam_texts) + len(ham_texts)
    correct = spam_correct + ham_correct
    print(f" Spam detected: {spam_correct}/{len(spam_texts)}")
    print(f" Ham detected: {ham_correct}/{len(ham_texts)}")
    print(f" Overall accuracy: {correct}/{total} ({correct/total*100:.0f}%)")

    # Boolean API demo
    print(f"\n{'=' * 60}")
    print("BOOLEAN API: is_spam()")
    print("=" * 60)
    test = "Click here to claim your free iPhone!"
    print(f" is_spam(\"{test}\")")
    print(f" -> {is_spam(test)}")