Aditi4275 commited on
Commit
11ec1a2
·
verified ·
1 Parent(s): 9eea007

upload main file

Browse files
Files changed (1) hide show
  1. app.py +304 -0
app.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ import os
5
+ import pickle
6
+
7
class LayerNorm(nn.Module):
    """Layer normalization over the last (embedding) dimension.

    Normalizes each token vector to zero mean / unit variance and applies a
    learnable per-feature scale and shift — equivalent to ``nn.LayerNorm``.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # numerical-stability floor added to the variance
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False: population variance (divide by N, not N-1), matching
        # nn.LayerNorm / GPT-2. The torch default (unbiased=True) would slightly
        # over-estimate the variance and diverge from the reference LayerNorm.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
19
+
20
+
21
class GELU(nn.Module):
    """Gaussian Error Linear Unit using the tanh approximation (GPT-2 style)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        coeff = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
27
+
28
+
29
class MultiHeadAttention(nn.Module):
    """Causal (masked) multi-head self-attention with an output projection.

    Splits ``d_out`` across ``num_head`` heads, applies scaled dot-product
    attention with an upper-triangular causal mask, and recombines the heads
    through a final linear projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_head, qkv_bias=False):
        super().__init__()
        assert (d_out % num_head == 0)

        self.d_out = d_out
        self.num_head = num_head
        self.head_dim = d_out // num_head

        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(dropout)
        # Upper-triangular ones mark "future" positions to be masked out.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch, seq_len, _ = x.shape

        # Project, then reshape (b, seq, d_out) -> (b, heads, seq, head_dim).
        q = self.W_query(x).view(batch, seq_len, self.num_head, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(batch, seq_len, self.num_head, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(batch, seq_len, self.num_head, self.head_dim).transpose(1, 2)

        # Raw attention scores: (b, heads, seq, seq).
        scores = q @ k.transpose(2, 3)

        # Mask out future tokens before the softmax.
        causal = self.mask.to(torch.bool)[:seq_len, :seq_len]
        scores = scores.masked_fill(causal, -torch.inf)

        # Scale by sqrt(head_dim), normalize, and (optionally) drop weights.
        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of values, heads merged back: (b, seq, d_out).
        ctx = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(ctx)
76
+
77
+
78
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back."""

    def __init__(self, cfg):
        super().__init__()
        hidden = 4 * cfg["emb_dim"]
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], hidden),
            GELU(),
            nn.Linear(hidden, cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)
89
+
90
+
91
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual skip."""

    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_head=cfg["n_heads"],
            dropout=cfg.get("drop_rate", 0.0),
            qkv_bias=cfg.get("qkv_bias", False)
        )

        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg.get("drop_rate", 0.0))

    def forward(self, x):
        # Attention sub-layer with residual connection.
        residual = x
        x = self.drop_shortcut(self.att(self.norm1(x)))
        x = x + residual

        # Feed-forward sub-layer with residual connection.
        residual = x
        x = self.drop_shortcut(self.ff(self.norm2(x)))
        return x + residual
121
+
122
+
123
class GPTModel(nn.Module):
    """GPT-style decoder: token + position embeddings, a stack of transformer
    blocks, a final LayerNorm, and an (untied) linear head over the vocabulary.
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg.get("drop_rate", 0.0))

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        # in_idx: (batch, seq_len) integer token ids.
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        x = self.tok_emb(in_idx) + self.pos_emb(positions)
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        # Returns logits of shape (batch, seq_len, vocab_size).
        return self.out_head(x)
147
+
148
+
149
model_path = "review_classifier_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"{model_path} not found. Please check the path.")

# Architecture the checkpoint is expected to match (GPT-2 124M layout).
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True,
    "emb_dim": 768,
    "n_layers": 12,
    "n_heads": 12,
}


def _torch_load_trusting(path):
    """torch.load with GPTModel allow-listed for unpickling.

    Prefers the ``safe_globals`` context manager (newer torch), falls back to
    ``add_safe_globals`` (older helper), and finally to a plain load.
    NOTE: ``weights_only=False`` can execute arbitrary code during unpickling —
    only use on checkpoints from a trusted source.
    """
    safe_ctx = getattr(torch.serialization, "safe_globals", None)
    add_safe = getattr(torch.serialization, "add_safe_globals", None)
    if safe_ctx is not None:
        with torch.serialization.safe_globals([GPTModel]):
            return torch.load(path, map_location=torch.device("cpu"), weights_only=False)
    if add_safe is not None:
        # Older helper: register globally, then load.
        torch.serialization.add_safe_globals([GPTModel])
    return torch.load(path, map_location=torch.device("cpu"), weights_only=False)


def _load_model(path):
    """Load ``path`` as either a pickled full GPTModel or a state_dict checkpoint.

    Returns a model instance; raises RuntimeError (with the original exception
    chained) if neither interpretation works.
    """
    try:
        loaded = _torch_load_trusting(path)
    except Exception:
        loaded = None  # fall through to the weights-only path

    # A pickled full model object: use it directly.
    if loaded is not None and hasattr(loaded, "state_dict") and not isinstance(loaded, dict):
        print(f"Loaded full model object from {path}")
        return loaded

    if loaded is None:
        try:
            loaded = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
        except Exception as e:
            # Chain the cause so the original failure is not lost.
            raise RuntimeError(
                f"Unable to load checkpoint as full model or weights-only. Last error: {e}"
            ) from e

    if hasattr(loaded, "state_dict") and not isinstance(loaded, dict):
        return loaded
    if not isinstance(loaded, dict):
        raise RuntimeError("Unrecognized checkpoint format and unable to construct model from checkpoint.")

    print("Attempting to load checkpoint state into a GPTModel instance...")
    m = GPTModel(BASE_CONFIG)
    # Checkpoints may nest the weights under a well-known key.
    state_dict = loaded.get("model_state_dict", loaded.get("state_dict", loaded))
    # strict=False: tolerate extra/missing keys (e.g. a replaced classifier head).
    m.load_state_dict(state_dict, strict=False)
    print("Loaded state_dict into GPTModel instance (non-strict).")
    return m


try:
    model = _load_model(model_path)
except Exception as e:
    raise RuntimeError(f"Failed to load model checkpoint: {e}") from e

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"Model loaded and moved to {device}")
234
+
235
tokenizer_path = "tokenizer.pkl"

# Fail fast if the tokenizer artifact is missing, then unpickle it.
if not os.path.exists(tokenizer_path):
    raise FileNotFoundError(f"{tokenizer_path} not found. Please check the path.")

with open(tokenizer_path, "rb") as f:
    tokenizer = pickle.load(f)
print(f"Tokenizer loaded from {tokenizer_path}")
242
+
243
MAX_SEQUENCE_LENGTH = 120

def classify_review(text, model, tokenizer_obj, device, max_length=MAX_SEQUENCE_LENGTH, pad_token_id=50256):
    """Classify a single message as "spam" or "not spam".

    Encodes ``text``, right-pads (or truncates) to ``max_length`` using
    ``pad_token_id``, and takes the argmax over the logits at the final
    token position.
    """
    model.eval()
    token_ids = tokenizer_obj.encode(text)
    # Pad first, then clip — equivalent to truncating and padding the rest.
    token_ids = (token_ids + [pad_token_id] * max_length)[:max_length]
    batch = torch.tensor(token_ids, device=device).unsqueeze(0)
    with torch.no_grad():
        last_logits = model(batch)[:, -1, :]
    label = torch.argmax(last_logits, dim=-1).item()
    return "spam" if label == 1 else "not spam"
254
+
255
def chatbot_classify(message, history):
    """Gradio chat callback: classify ``message`` with the global model.

    ``history`` is required by the ChatInterface signature but unused.
    """
    return classify_review(
        message, model, tokenizer, device, max_length=MAX_SEQUENCE_LENGTH
    )
264
+
265
print("Launching Gradio interface...")

# Chat UI: each submitted message is classified and the label is the reply.
iface = gr.ChatInterface(
    chatbot_classify,
    title="📬 Spam Detection System",
    description="Enter an SMS message below...",
    # NOTE(review): string themes like "compact" are deprecated in newer
    # Gradio releases — confirm against the version pinned for this Space.
    theme="compact",
    css="""
    /* Customize chat bubble colors */
    .chatbot-message {
        background-color: #e0f7fa; /* Light cyan */
        color: #006064; /* Dark teal text */
        font-weight: 600;
        border-radius: 12px;
        padding: 12px;
    }
    .user-message {
        background-color: #c8e6c9; /* Light green */
        color: #1b5e20; /* Dark green text */
        font-weight: 600;
        border-radius: 12px;
        padding: 12px;
    }
    .chat-ending-message {
        font-style: italic;
        color: #555;
    }
    """,
)

ICON_CDN = "https://img.icons8.com/color/48/mail-envelope.png"

# NOTE(review): custom_head_html is built but never passed to any Gradio
# component — dead code? Verify whether it was meant to go into a Blocks
# `head=` parameter.
custom_head_html = f"""
<link rel="icon" href="{ICON_CDN}" type="image/x-icon">
"""

iface.launch(
    share=True,
    # NOTE(review): favicon_path expects a local file path; passing a URL
    # likely has no effect — confirm against the Gradio launch() docs.
    favicon_path=ICON_CDN
)