AmitHirpara committed on
Commit
de46a17
Β·
1 Parent(s): 94f7fb3

add binary files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,6 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.json filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  *.pt filter=lfs diff=lfs merge=lfs -text
3
+ *.pkl filter=lfs diff=lfs merge=lfs -text
4
+ saved_lstm/** filter=lfs diff=lfs merge=lfs -text
5
+ saved_transformer/** filter=lfs diff=lfs merge=lfs -text
6
+ *.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import pickle
6
+ import os
7
+ import math
8
+ from typing import List, Tuple
9
+ from collections import Counter
10
+ import warnings
11
+ warnings.filterwarnings('ignore')
12
+
13
# Vocabulary used at inference time; must match the class pickled at training.
class Vocabulary:
    """Maps tokens/labels to integer ids and back.

    Ids 0-3 are reserved for the special tokens <pad>, <unk>, <start>
    and <end>; remaining ids are assigned by descending frequency.
    """

    def __init__(self, max_size=100000):
        specials = ['<pad>', '<unk>', '<start>', '<end>']
        self.word2idx = {w: i for i, w in enumerate(specials)}
        self.idx2word = {i: w for i, w in enumerate(specials)}
        self.word_count = Counter()
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Accumulate lower-cased token frequencies from one sentence."""
        self.word_count.update(word.lower() for word in sentence)

    def build(self):
        """Assign ids to the most frequent words, up to max_size entries total."""
        budget = self.max_size - len(self.word2idx)
        for word, _ in self.word_count.most_common(budget):
            if word in self.word2idx:
                continue
            new_id = len(self.word2idx)
            self.word2idx[word] = new_id
            self.idx2word[new_id] = word

    def __len__(self):
        return len(self.word2idx)

    def encode(self, sentence):
        """Convert tokens to ids; unknown tokens map to <unk>."""
        unk = self.word2idx['<unk>']
        return [self.word2idx.get(word.lower(), unk) for word in sentence]

    def decode(self, indices):
        """Convert ids back to tokens; unknown ids map to '<unk>'."""
        return [self.idx2word.get(idx, '<unk>') for idx in indices]
+
43
# Custom Transformer components; submodule names (w_q/w_k/w_v/w_o) must stay
# as-is so the pickled state_dict keys still line up with the saved model.
class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled dot-product attention with optional padding mask."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        return projected.view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend `query` over `key`/`value`; `mask` is True at padding positions."""
        batch = query.size(0)

        q = self._split_heads(self.w_q(query), batch)
        k = self._split_heads(self.w_k(key), batch)
        v = self._split_heads(self.w_v(value), batch)

        # Scaled dot-product scores; masked positions are pushed to a large
        # negative value so softmax assigns them ~zero weight.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(1), -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge heads back into a single (batch, seq, d_model) tensor.
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(
            batch, -1, self.d_model
        )
        return self.w_o(merged)
+
87
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(hidden))
+
97
class EncoderLayer(nn.Module):
    """One post-norm transformer encoder layer: self-attention + FFN sublayers."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Residual connection + post-layer-norm around each sublayer.
        attended = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x
+
117
class TransformerEncoder(nn.Module):
    """Stack of EncoderLayer blocks followed by a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        for block in self.layers:
            x = block(x, mask)
        return self.norm(x)
+
131
class PositionalEncoding(nn.Module):
    """Adds sinusoidal position information; also scales input by sqrt(d_model)."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Buffer (not a parameter): moves with .to(device) but is never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        scale = torch.sqrt(torch.tensor(self.d_model, dtype=x.dtype))
        return x * scale + self.pe[:, :x.size(1)]
+
146
class TransformerPIIDetector(nn.Module):
    """Token-level PII tagger: embedding + positional encoding + encoder + linear head."""

    def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                 d_ff=512, num_layers=4, dropout=0.1, max_len=512):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # Attribute name chosen to match the checkpoint's state_dict keys.
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)

        # Hand-rolled encoder matching the saved model's structure.
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """x: (batch, seq) token ids; returns per-token class logits."""
        padding_mask = x.eq(0)  # id 0 is <pad>
        h = self.positional_encoding(self.embedding(x))
        h = self.encoder(self.dropout(h), padding_mask)
        return self.classifier(h)
+
167
def create_transformer_pii_model(**kwargs):
    """Factory used by the checkpoint loader; forwards config to the model class."""
    return TransformerPIIDetector(**kwargs)
+
170
class PIIDetector:
    """Wraps the trained transformer checkpoint with tokenize/predict/render helpers.

    Loads vocabularies, model config and weights from `model_dir` on
    construction and raises if any artifact is missing or unreadable.
    """

    def __init__(self, model_dir='saved_transformer'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_dir = model_dir
        self.model = None
        self.text_vocab = None
        self.label_vocab = None
        self.load_model()

        # Single color for all PII highlighting
        self.highlight_color = '#FF6B6B'

    def load_model(self):
        """Load the trained model and vocabularies from self.model_dir."""
        try:
            # Pickled Vocabulary instances for tokens and labels.
            vocab_path = os.path.join(self.model_dir, 'vocabularies.pkl')
            with open(vocab_path, 'rb') as f:
                vocabs = pickle.load(f)
            self.text_vocab = vocabs['text_vocab']
            self.label_vocab = vocabs['label_vocab']

            # Hyper-parameters saved at training time.
            config_path = os.path.join(self.model_dir, 'model_config.pkl')
            with open(config_path, 'rb') as f:
                model_config = pickle.load(f)

            # Rebuild the architecture and restore the weights.
            self.model = create_transformer_pii_model(**model_config)
            model_path = os.path.join(self.model_dir, 'pii_transformer_model.pt')
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            self.model.to(self.device)
            self.model.eval()

            print(f"Model loaded successfully from {self.model_dir}")
            print(f"Using device: {self.device}")

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def tokenize(self, text: str) -> List[str]:
        """Split text into word and punctuation tokens."""
        import re
        # \w+ keeps words/numbers whole; every other non-space char is its own token.
        tokens = re.findall(r'\w+|[^\w\s]', text)
        return tokens

    def predict(self, text: str) -> List[Tuple[str, str]]:
        """Return (token, BIO-label) pairs for `text`; empty list for blank input."""
        if not text.strip():
            return []

        tokens = self.tokenize(text)

        # Wrap with the same sentinel tokens used during training.
        tokens_with_special = ['<start>'] + tokens + ['<end>']
        token_ids = self.text_vocab.encode(tokens_with_special)

        input_tensor = torch.tensor([token_ids]).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            predictions = torch.argmax(outputs, dim=-1)

        # Map class ids back to labels, dropping the <start>/<end> positions.
        predicted_labels = []
        for idx in predictions[0][1:-1]:
            label = self.label_vocab.idx2word.get(idx.item(), 'O')
            predicted_labels.append(label.upper())

        return list(zip(tokens, predicted_labels))

    def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Render tokens as HTML, wrapping detected PII entities in <mark> tags.

        User-derived text is HTML-escaped before interpolation so input
        containing markup cannot inject HTML into the rendered page.
        """
        import html as html_lib

        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8; padding: 20px; background-color: white; border-radius: 8px; color: black;">']

        i = 0
        while i < len(token_label_pairs):
            token, label = token_label_pairs[i]

            # Check if this is the start of a PII entity
            if label != 'O':
                # Greedily collect this token plus its I- continuation tokens.
                entity_tokens = [token]
                entity_label = label
                j = i + 1

                while j < len(token_label_pairs):
                    next_token, next_label = token_label_pairs[j]
                    if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
                        entity_tokens.append(next_token)
                        j += 1
                    else:
                        break

                # Re-join the entity, keeping punctuation attached to the left.
                entity_text = ''
                for k, tok in enumerate(entity_tokens):
                    if k > 0 and tok not in '.,!?;:':
                        entity_text += ' '
                    entity_text += tok

                label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
                # Escape user text and the label before embedding them in markup.
                html_parts.append(
                    f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
                    f'border-radius: 3px; margin: 0 2px; font-weight: 500;" '
                    f'title="{html_lib.escape(label_display)}">{html_lib.escape(entity_text)}</mark>'
                )

                i = j
            else:
                # Insert a space before ordinary tokens, except after '(' and
                # before closing punctuation.
                if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i-1:
                    prev_token, _ = token_label_pairs[i-1]
                    if prev_token not in '(':
                        html_parts.append(' ')

                html_parts.append(f'<span style="color: black;">{html_lib.escape(token)}</span>')
                i += 1

        html_parts.append('</div>')

        return ''.join(html_parts)

    def get_statistics(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Build a Markdown summary of how many tokens were flagged as PII."""
        stats = {}
        total_tokens = len(token_label_pairs)
        pii_tokens = 0

        for _, label in token_label_pairs:
            if label != 'O':
                pii_tokens += 1
                # Clean up label for display
                label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
                stats[label_clean] = stats.get(label_clean, 0) + 1

        # Guard against division by zero on an empty token list (e.g. the
        # caller passed predictions for whitespace-only input).
        pii_pct = (pii_tokens / total_tokens * 100) if total_tokens else 0.0

        stats_text = f"### Detection Summary\n\n"
        stats_text += f"**Total tokens:** {total_tokens}\n\n"
        stats_text += f"**PII tokens:** {pii_tokens} ({pii_pct:.1f}%)\n\n"

        if not stats:
            stats_text += "*No PII detected in the text.*"

        return stats_text
+
326
# Initialize the detector once at import time so Gradio callbacks reuse it.
print("Initializing PII Detector...")
detector = PIIDetector()

def detect_pii(text):
    """Gradio callback: return (highlighted_html, stats_markdown) for `text`."""
    # Treat whitespace-only input the same as empty input; otherwise it would
    # reach the statistics path with zero tokens.
    if not text or not text.strip():
        return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."

    try:
        # Get predictions
        token_label_pairs = detector.predict(text)

        # Create highlighted HTML
        highlighted_html = detector.create_highlighted_html(token_label_pairs)

        # Get statistics
        stats = detector.get_statistics(token_label_pairs)

        return highlighted_html, stats

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        error_html = f'<div style="color: #dc3545; padding: 20px; background-color: #f8d7da; border-radius: 8px;">Error: {str(e)}</div>'
        error_stats = f"Error occurred: {str(e)}"
        return error_html, error_stats
+
352
# Sample inputs for quick manual testing of the detector.
examples = [
    "My name is John Smith and my email is john.smith@email.com. You can reach me at 555-123-4567.",
    "Student ID: 12345678. Please send the documents to 123 Main Street, Anytown, USA 12345.",
    "Contact Sarah Johnson at sarah_j_2023@gmail.com or visit her profile at linkedin.com/in/sarahjohnson",
    "The project was completed by student A1B2C3D4 who lives at 456 Oak Avenue.",
    "For verification, my phone number is (555) 987-6543 and my username is cool_user_99.",
    "Hi, I'm Emily Chen. My student number is STU-2023-98765 and I live at 789 Pine Street, Apt 4B.",
    "You can reach me at my personal website: www.johndoe.com or call me at +1-555-0123.",
]
+
363
# ---- Gradio UI: input box, detect/clear buttons, HTML + stats outputs ----
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πŸ”’ PII Detection System

        Enter or paste text below to analyze it for PII content.
        """
    )

    with gr.Column():
        text_box = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to analyze for PII...",
            lines=8,
            max_lines=20,
        )

        with gr.Row():
            detect_button = gr.Button("πŸ” Detect PII", variant="primary", scale=2)
            reset_button = gr.Button("πŸ—‘οΈ Clear", scale=1)

        html_view = gr.HTML(
            label="Highlighted Text",
            value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>",
        )

        stats_view = gr.Markdown(
            label="Detection Statistics",
            value="*Statistics will appear here...*",
        )

    # Wire buttons: run detection / restore the placeholder values.
    detect_button.click(
        fn=detect_pii,
        inputs=[text_box],
        outputs=[html_view, stats_view],
    )

    reset_button.click(
        fn=lambda: ("", "<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>", "*Statistics will appear here...*"),
        outputs=[text_box, html_view, stats_view],
    )

# Start the local server when executed as a script.
if __name__ == "__main__":
    print("\nLaunching Gradio interface...")
    demo.launch(
        share=False,
        server_name="127.0.0.1",
        server_port=7860,
        show_error=True,
    )
data_augmentation.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from faker import Faker
3
+ import pandas as pd
4
+ import numpy as np
5
+ from collections import Counter
6
+ import torch
7
+
8
class PIIDataAugmenter:
    """
    Generates synthetic PII examples to augment training data.

    Produces template-based sentences containing fake names, emails, phone
    numbers, addresses, IDs, URLs and usernames, labelled with BIO tags.
    """

    def __init__(self, seed=42):
        """Seed every randomness source so augmentation is reproducible."""
        random.seed(seed)
        np.random.seed(seed)
        self.fake = Faker()
        Faker.seed(seed)

        self._init_templates()
        self._init_context_phrases()
        self._init_generators()

    def _init_templates(self):
        """Build the sentence templates, ten per PII type."""
        self.templates = {
            'NAME_STUDENT': [
                "My name is {name}",
                "I am {name}",
                "This is {name} speaking",
                "Student: {name}",
                "{name} here",
                "Submitted by {name}",
                "Author: {name}",
                "Contact {name} for more information",
                "Please call {name}",
                "{name} is my name"
            ],
            'EMAIL': [
                "Email me at {email}",
                "My email is {email}",
                "Contact: {email}",
                "Send to {email}",
                "Reach me at {email}",
                "Email address: {email}",
                "You can email {email}",
                "Write to {email}",
                "My contact email is {email}",
                "Send your response to {email}"
            ],
            'PHONE_NUM': [
                "Call me at {phone}",
                "My phone number is {phone}",
                "Phone: {phone}",
                "Contact number: {phone}",
                "Reach me at {phone}",
                "My number is {phone}",
                "You can call {phone}",
                "Mobile: {phone}",
                "Tel: {phone}",
                "Phone contact: {phone}"
            ],
            'STREET_ADDRESS': [
                "I live at {address}",
                "My address is {address}",
                "Located at {address}",
                "Address: {address}",
                "Find me at {address}",
                "Residence: {address}",
                "Mail to {address}",
                "Home address: {address}",
                "Visit us at {address}",
                "Ship to {address}"
            ],
            'ID_NUM': [
                "ID: {id_num}",
                "Student ID: {id_num}",
                "ID number {id_num}",
                "Reference number: {id_num}",
                "Account: {id_num}",
                "Member ID: {id_num}",
                "Registration: {id_num}",
                "Code: {id_num}",
                "Identification: {id_num}",
                "Number: {id_num}"
            ],
            'URL_PERSONAL': [
                "Visit my website at {url}",
                "Check out {url}",
                "My portfolio: {url}",
                "Website: {url}",
                "Link: {url}",
                "Find me online at {url}",
                "Personal site: {url}",
                "URL: {url}",
                "Web: {url}",
                "Online at {url}"
            ],
            'USERNAME': [
                "Username: {username}",
                "User: {username}",
                "Handle: {username}",
                "My username is {username}",
                "Find me as {username}",
                "Account: {username}",
                "Login: {username}",
                "Profile: {username}",
                "Known as {username}",
                "Tag me @{username}"
            ]
        }

    def _init_context_phrases(self):
        """Build the optional lead-ins, sign-offs and connector phrases."""
        self.context_prefix = [
            "Hello everyone,",
            "Dear Sir/Madam,",
            "To whom it may concern,",
            "Please note that",
            "For your reference,",
            "As requested,",
            "I would like to inform you that",
            "This is to confirm that",
            "Please be advised that",
            "I am writing to tell you that"
        ]

        self.context_suffix = [
            "Thank you.",
            "Best regards.",
            "Please let me know if you need anything else.",
            "Looking forward to your response.",
            "Have a great day!",
            "Thanks for your attention.",
            "Feel free to contact me.",
            "I appreciate your help.",
            "Hope this helps.",
            "Let me know if you have questions."
        ]

        self.connectors = [
            " and ", " or ", ", ", ". Also, ", ". Additionally, "
        ]

    def _init_generators(self):
        """Map each PII tag to its value generator and template placeholder name."""
        spec = {
            'NAME_STUDENT': (self.generate_name, 'name'),
            'EMAIL': (self.generate_email, 'email'),
            'PHONE_NUM': (self.generate_phone, 'phone'),
            'STREET_ADDRESS': (self.generate_address, 'address'),
            'ID_NUM': (self.generate_id_num, 'id_num'),
            'URL_PERSONAL': (self.generate_url, 'url'),
            'USERNAME': (self.generate_username, 'username'),
        }
        self.generators = {tag: gen for tag, (gen, _) in spec.items()}
        self.format_keys = {tag: key for tag, (_, key) in spec.items()}

    # ========== PII value generators ==========

    def generate_name(self):
        """Random realistic full name."""
        return self.fake.name()

    def generate_email(self):
        """Random realistic email address."""
        return self.fake.email()

    def generate_phone(self):
        """Random phone number in one of several common formats."""
        formats = [
            "555-{:03d}-{:04d}",
            "(555) {:03d}-{:04d}",
            "555.{:03d}.{:04d}",
            "+1-555-{:03d}-{:04d}",
            "555{:03d}{:04d}"
        ]
        chosen = random.choice(formats)
        area = random.randint(100, 999)
        number = random.randint(1000, 9999)
        return chosen.format(area, number)

    def generate_address(self):
        """Random street address collapsed onto a single line."""
        return self.fake.address().replace('\n', ', ')

    def generate_id_num(self):
        """Random ID in one of several numeric / prefixed formats."""
        formats = [
            "{:06d}",         # 6-digit ID
            "{:08d}",         # 8-digit ID
            "ID{:05d}",       # ID prefix
            "STU{:06d}",      # Student ID
            "{:04d}-{:04d}",  # Hyphenated
            "A{:07d}",        # Letter prefix
        ]
        chosen = random.choice(formats)

        if '-' in chosen:
            return chosen.format(
                random.randint(1000, 9999),
                random.randint(1000, 9999)
            )
        return chosen.format(random.randint(10000, 9999999))

    def generate_url(self):
        """Random personal-profile style URL."""
        domains = ['github.com', 'linkedin.com', 'portfolio.com',
                   'personal.com', 'website.com']
        handle = self.fake.user_name()
        site = random.choice(domains)
        return f"https://{site}/{handle}"

    def generate_username(self):
        """Random username."""
        return self.fake.user_name()

    # ========== Example construction ==========

    def create_synthetic_example(self, pii_type, add_context=True):
        """
        Build one sentence containing a single PII entity, BIO-labelled.

        Args:
            pii_type: PII tag to generate (key of self.templates)
            add_context: randomly wrap the sentence in context phrases

        Returns:
            Tuple of (tokens, labels)
        """
        value = self.generators[pii_type]()

        template = random.choice(self.templates[pii_type])
        sentence = template.format(**{self.format_keys[pii_type]: value})

        if add_context and random.random() > 0.3:
            sentence = self._add_context(sentence)

        return self._tokenize_and_label(sentence, value, pii_type)

    def create_mixed_example(self, pii_types, num_pii=2):
        """
        Build a sentence containing several PII entities.

        Args:
            pii_types: candidate PII tags
            num_pii: how many entities to embed

        Returns:
            Tuple of (tokens, labels)
        """
        chosen = random.sample(pii_types, min(num_pii, len(pii_types)))

        tokens_out = []
        labels_out = []

        # Optional greeting / preamble.
        if random.random() > 0.3:
            prefix_words = random.choice(self.context_prefix).split()
            tokens_out += prefix_words
            labels_out += ['O'] * len(prefix_words)

        for position, tag in enumerate(chosen):
            # Optionally glue consecutive entities with a connector phrase.
            if position > 0 and random.random() > 0.5:
                joiner_words = random.choice(self.connectors).strip().split()
                tokens_out += joiner_words
                labels_out += ['O'] * len(joiner_words)

            entity_tokens, entity_labels = self.create_synthetic_example(tag, add_context=False)
            tokens_out += entity_tokens
            labels_out += entity_labels

        # Optional sign-off.
        if random.random() > 0.3:
            suffix_words = random.choice(self.context_suffix).split()
            tokens_out += suffix_words
            labels_out += ['O'] * len(suffix_words)

        return tokens_out, labels_out

    def _add_context(self, sentence):
        """Randomly prepend/append context phrases around `sentence`."""
        if random.random() > 0.5:
            sentence = random.choice(self.context_prefix) + " " + sentence
        if random.random() > 0.5:
            sentence = sentence + " " + random.choice(self.context_suffix)
        return sentence

    def _tokenize_and_label(self, sentence, pii_value, pii_type):
        """
        Whitespace-tokenize `sentence` and BIO-label the first occurrence of
        `pii_value` with B-/I-{pii_type}; every other token is labelled 'O'.

        Returns:
            Tuple of (tokens, labels)
        """
        tokens = sentence.split()
        labels = ['O'] * len(tokens)

        value_tokens = pii_value.split()
        span = len(value_tokens)

        for start in range(len(tokens) - span + 1):
            window = tokens[start:start + span]
            if window == value_tokens or ' '.join(window).lower() == pii_value.lower():
                labels[start] = f'B-{pii_type}'
                for offset in range(1, span):
                    labels[start + offset] = f'I-{pii_type}'
                break

        return tokens, labels

    # ========== Dataset-level augmentation ==========

    def augment_dataset(self, original_data, target_samples_per_class=1000, mix_ratio=0.3):
        """
        Top up every PII class with synthetic examples.

        Args:
            original_data: DataFrame with 'tokens' and 'labels' columns
            target_samples_per_class: desired entity count per PII class
            mix_ratio: fraction of synthetic examples with multiple entities

        Returns:
            Shuffled DataFrame of original plus synthetic rows
        """
        current = self._analyze_label_distribution(original_data)
        print("\nOriginal label distribution:")
        self._print_distribution(current)

        new_tokens, new_labels = self._generate_synthetic_data(
            current, target_samples_per_class, mix_ratio
        )

        new_tokens, new_labels = self._add_non_pii_examples(new_tokens, new_labels)

        combined = self._combine_and_shuffle(original_data, new_tokens, new_labels)

        print("\nAugmented label distribution:")
        self._print_distribution(self._analyze_label_distribution(combined))

        return combined

    def _analyze_label_distribution(self, data):
        """Count entity tokens per base class (B-/I- prefixes stripped)."""
        counts = Counter()
        for sentence_labels in data['labels']:
            for tag in sentence_labels:
                if tag == 'O':
                    continue
                counts[tag.split('-')[1] if '-' in tag else tag] += 1
        return counts

    def _print_distribution(self, label_counts):
        """Pretty-print class counts with percentages."""
        total = sum(label_counts.values())
        for tag, count in label_counts.most_common():
            share = (count / total * 100) if total > 0 else 0
            print(f" {tag:15} : {count:6,} ({share:5.2f}%)")

    def _generate_synthetic_data(self, label_counts, target_samples, mix_ratio):
        """Create enough synthetic rows to lift each class toward the target."""
        synthetic_tokens = []
        synthetic_labels = []

        for tag in self.templates.keys():
            shortfall = max(0, target_samples - label_counts.get(tag, 0))
            if shortfall == 0:
                continue

            print(f"\nGenerating {shortfall} synthetic examples for {tag}")

            # Single-entity sentences.
            for _ in range(int(shortfall * (1 - mix_ratio))):
                toks, labs = self.create_synthetic_example(tag)
                synthetic_tokens.append(toks)
                synthetic_labels.append(labs)

            # Multi-entity sentences that always include this class.
            for _ in range(int(shortfall * mix_ratio)):
                others = [t for t in self.templates.keys() if t != tag]
                combo = [tag] + random.sample(others, min(1, len(others)))
                toks, labs = self.create_mixed_example(combo, num_pii=2)
                synthetic_tokens.append(toks)
                synthetic_labels.append(labs)

        return synthetic_tokens, synthetic_labels

    def _add_non_pii_examples(self, synthetic_tokens, synthetic_labels):
        """Pad the synthetic set with ~10% PII-free sentences for balance."""
        for _ in range(int(len(synthetic_tokens) * 0.1)):
            words = self.fake.text(max_nb_chars=100).split()
            synthetic_tokens.append(words)
            synthetic_labels.append(['O'] * len(words))

        return synthetic_tokens, synthetic_labels

    def _combine_and_shuffle(self, original_data, synthetic_tokens, synthetic_labels):
        """Concatenate original and synthetic rows, then shuffle deterministically."""
        frame = pd.DataFrame({
            'tokens': original_data['tokens'].tolist() + synthetic_tokens,
            'labels': original_data['labels'].tolist() + synthetic_labels
        })

        # Fixed random_state keeps the shuffle reproducible across runs.
        frame = frame.sample(frac=1, random_state=42).reset_index(drop=True)

        print(f"\nTotal augmented samples: {len(frame):,}")

        return frame
+
474
def calculate_class_weights(data, label_vocab):
    """
    Inverse-frequency class weights for a balanced token-classification loss.

    Args:
        data: dataset with a 'labels' column of per-sentence label lists
        label_vocab: Vocabulary mapping lower-cased label strings to class ids

    Returns:
        1-D float tensor of length len(label_vocab); padding class weight is 0
    """
    # Count how often each class id occurs across the whole dataset.
    counts = Counter(
        label_vocab.word2idx.get(label.lower(), 0)
        for sentence_labels in data['labels']
        for label in sentence_labels
    )

    total = sum(counts.values())
    num_classes = len(label_vocab)

    weights = torch.zeros(num_classes)

    # Inverse frequency: rarer classes receive larger weights.
    for class_id, count in counts.items():
        if count > 0:
            weights[class_id] = total / (num_classes * count)

    # Rescale so the weights average to 1 across classes.
    weights = weights / weights.sum() * num_classes

    # Clamp extremes to keep the loss numerically stable.
    weights = torch.clamp(weights, min=0.1, max=10.0)

    # Never penalise the padding class.
    weights[0] = 0.0

    return weights
+
515
if __name__ == '__main__':
    # Demo: augment the training set and write the result next to it.
    print("Loading original training data...")
    source_df = pd.read_json('train.json')
    print(f"Original dataset size: {len(source_df):,}")

    augmenter = PIIDataAugmenter(seed=42)

    banner = "=" * 60
    print("\n" + banner)
    print("Starting data augmentation...")
    print(banner)

    augmented = augmenter.augment_dataset(
        source_df,
        target_samples_per_class=2000,
        mix_ratio=0.3,
    )

    output_path = './train_augmented.json'
    augmented.to_json(output_path, orient='records', lines=True)
    print(f"\nSaved augmented data to {output_path}")
lstm.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
5
+
6
class LSTMCell(nn.Module):
    """
    Single-timestep LSTM cell implemented from scratch.

    Keeps a separate input-to-hidden matrix, hidden-to-hidden matrix, and bias
    for each of the four gates: input (i), forget (f), candidate/"input node"
    (n), and output (o), following the standard LSTM update equations.
    """
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Register one (W_i*, W_h*, b_*) triple per gate. Registration order
        # and parameter names match the conventional layout (i, f, n, o) so
        # saved state_dicts stay compatible.
        for gate in ('i', 'f', 'n', 'o'):
            setattr(self, f'W_i{gate}', nn.Parameter(torch.Tensor(input_size, hidden_size)))
            setattr(self, f'W_h{gate}', nn.Parameter(torch.Tensor(hidden_size, hidden_size)))
            setattr(self, f'b_{gate}', nn.Parameter(torch.Tensor(hidden_size)))

        # Xavier-uniform for weight matrices, zeros for biases.
        for name, param in self.named_parameters():
            if 'W_' in name:
                nn.init.xavier_uniform_(param)
            elif 'b_' in name:
                nn.init.zeros_(param)

    def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run one LSTM time step.

        Args:
            input: current-step features, shape [batch_size, input_size].
            states: (hidden, cell) from the previous step, each of shape
                [batch_size, hidden_size].

        Returns:
            (new_hidden, new_cell), each [batch_size, hidden_size].
        """
        h_prev, c_prev = states

        def preact(W_x, W_h, b):
            # Affine combination of the current input and previous hidden state.
            return input @ W_x + h_prev @ W_h + b

        i_t = torch.sigmoid(preact(self.W_ii, self.W_hi, self.b_i))
        f_t = torch.sigmoid(preact(self.W_if, self.W_hf, self.b_f))
        n_t = torch.tanh(preact(self.W_in, self.W_hn, self.b_n))
        o_t = torch.sigmoid(preact(self.W_io, self.W_ho, self.b_o))

        # c_t = f_t * c_{t-1} + i_t * n_t ; h_t = o_t * tanh(c_t)
        c_t = f_t * c_prev + i_t * n_t
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t
76
+
77
class BidirectionalLSTM(nn.Module):
    """
    Multi-layer bidirectional LSTM built from custom LSTMCell instances.

    Accepts either a padded tensor or a PackedSequence. Output feature size is
    hidden_size * 2 (forward and backward directions concatenated).
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
                 batch_first: bool = True, dropout: float = 0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        # Dropout only applies between stacked layers, so a single-layer
        # network silently disables it (mirrors nn.LSTM semantics).
        self.dropout = dropout if num_layers > 1 else 0.0

        # Create forward and backward cells for each layer
        self.forward_cells = nn.ModuleList()
        self.backward_cells = nn.ModuleList()
        self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None

        for layer in range(num_layers):
            # Input size is input_size for first layer, hidden_size * 2 for others (bidirectional)
            layer_input_size = input_size if layer == 0 else hidden_size * 2

            self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
            self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))

            # One dropout module per inter-layer gap (num_layers - 1 total).
            if self.dropout > 0 and layer < num_layers - 1:
                self.dropout_layers.append(nn.Dropout(dropout))

    def forward(self, input, states=None, lengths=None):
        """
        Run the full stack over a sequence batch.

        Args:
            input: padded tensor ([batch, seq, feat] if batch_first) or a
                PackedSequence.
            states: optional (h0, c0), each [num_layers * 2, batch, hidden].
            lengths: optional per-sequence lengths (currently only recorded;
                see note in _forward_unpacked).

        Returns:
            (output, (h_n, c_n)); output is packed iff the input was packed.
        """
        # Handle PackedSequence input: unpack, run the padded path, re-pack.
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
            outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
            packed_out = pack_padded_sequence(
                outputs, lengths,
                batch_first=self.batch_first,
                enforce_sorted=False
            )
            return packed_out, (h_n, c_n)
        else:
            return self._forward_unpacked(input, states, lengths)

    def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
        """Core recurrence over a padded [batch, seq, feat] tensor.

        NOTE(review): `lengths` is accepted but never used below -- every
        timestep, including padding, is processed in both directions, so the
        final states for short sequences include padding steps. Confirm this
        is acceptable for the downstream token-level loss (padding positions
        appear to be masked there).
        """
        # Normalize to batch-first layout for the time loop.
        if not self.batch_first:
            input = input.transpose(0, 1)

        batch_size, seq_len, _ = input.size()

        # Initialize states if not provided (zeros on the input's device/dtype)
        if states is None:
            h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
                           for _ in range(self.num_layers)]
            c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
                           for _ in range(self.num_layers)]
            h_t_backward = [input.new_zeros(batch_size, self.hidden_size)
                            for _ in range(self.num_layers)]
            c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
                            for _ in range(self.num_layers)]
        else:
            h0, c0 = states
            # h0 and c0 are [num_layers * 2, batch_size, hidden_size]
            # laid out as (layer0-fwd, layer0-bwd, layer1-fwd, ...).
            h_t_forward = []
            c_t_forward = []
            h_t_backward = []
            c_t_backward = []

            for layer in range(self.num_layers):
                h_t_forward.append(h0[layer * 2])
                c_t_forward.append(c0[layer * 2])
                h_t_backward.append(h0[layer * 2 + 1])
                c_t_backward.append(c0[layer * 2 + 1])

        # Process through layers; each layer consumes the previous layer's
        # concatenated (forward + backward) output.
        layer_input = input
        for layer_idx in range(self.num_layers):
            # Forward direction: left-to-right over time.
            forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
            for t in range(seq_len):
                x = layer_input[:, t, :]
                h, c = self.forward_cells[layer_idx](x, (h_t_forward[layer_idx], c_t_forward[layer_idx]))
                h_t_forward[layer_idx] = h
                c_t_forward[layer_idx] = c
                forward_output[:, t, :] = h

            # Backward direction: right-to-left, written back at position t so
            # forward/backward features line up per timestep.
            backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
            for t in reversed(range(seq_len)):
                x = layer_input[:, t, :]
                h, c = self.backward_cells[layer_idx](x, (h_t_backward[layer_idx], c_t_backward[layer_idx]))
                h_t_backward[layer_idx] = h
                c_t_backward[layer_idx] = c
                backward_output[:, t, :] = h

            # Concatenate forward and backward -> [batch, seq, 2*hidden]
            layer_output = torch.cat([forward_output, backward_output], dim=2)

            # Apply dropout between layers (except last layer)
            if self.dropout > 0 and layer_idx < self.num_layers - 1:
                layer_output = self.dropout_layers[layer_idx](layer_output)

            layer_input = layer_output

        # Final output
        outputs = layer_output

        # Stack hidden and cell states into [num_layers * 2, batch, hidden],
        # same (fwd, bwd) interleaving expected for incoming states.
        h_n = []
        c_n = []
        for layer in range(self.num_layers):
            h_n.extend([h_t_forward[layer], h_t_backward[layer]])
            c_n.extend([c_t_forward[layer], c_t_backward[layer]])
        h_n = torch.stack(h_n, dim=0)
        c_n = torch.stack(c_n, dim=0)

        # Restore the caller's time-major layout if requested.
        if not self.batch_first:
            outputs = outputs.transpose(0, 1)

        return outputs, (h_n, c_n)
196
+
197
class LSTM(nn.Module):
    """
    Bidirectional LSTM tagger for PII detection (token-level sequence labeling).

    Pipeline: token-id embedding -> dropout -> multi-layer bidirectional LSTM
    -> dropout -> linear projection to per-token class logits.
    """
    def __init__(self, vocab_size: int, num_classes: int, embed_size: int = 128,
                 hidden_size: int = 256, num_layers: int = 2, dropout: float = 0.1,
                 max_len: int = 512):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Token embeddings; index 0 is the padding token.
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.embed_dropout = nn.Dropout(dropout)

        # Custom multi-layer bidirectional recurrent encoder.
        self.lstm = BidirectionalLSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )

        # Encoder concatenates both directions, hence the doubled input width.
        # NOTE(review): max_len is accepted but unused here -- presumably kept
        # for interface parity with the transformer variant.
        self.fc = nn.Linear(hidden_size * 2, num_classes)
        self.output_dropout = nn.Dropout(dropout)

    def forward(self, input_ids, lengths=None):
        """
        Compute per-token class logits.

        Args:
            input_ids: token ids, shape [batch_size, seq_len].
            lengths: optional true sequence lengths; when given, sequences are
                packed so the recurrent encoder can skip trailing padding.

        Returns:
            logits of shape [batch_size, seq_len, num_classes].
        """
        embedded = self.embed_dropout(self.embedding(input_ids))

        if lengths is None:
            encoded, _ = self.lstm(embedded)
        else:
            # pack_padded_sequence requires lengths on the CPU.
            packed = pack_padded_sequence(
                embedded, lengths.cpu(),
                batch_first=True,
                enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            encoded, _ = pad_packed_sequence(packed_out, batch_first=True)

        return self.fc(self.output_dropout(encoded))
260
+
261
def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
                          num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
                          dropout: float = 0.1, max_len: int = 512):
    """
    Factory for the bidirectional LSTM PII-detection model.

    The signature deliberately mirrors the transformer factory, so
    ``num_heads`` and ``d_ff`` are accepted but ignored.

    Args:
        vocab_size: size of the token vocabulary.
        num_classes: number of output classes (PII tags).
        d_model: hidden dimension; the embedding uses half of it.
        num_heads: ignored (interface compatibility).
        d_ff: ignored (interface compatibility).
        num_layers: number of stacked LSTM layers.
        dropout: dropout rate.
        max_len: maximum sequence length.

    Returns:
        A configured LSTM model instance.
    """
    config = dict(
        vocab_size=vocab_size,
        num_classes=num_classes,
        embed_size=d_model // 2,  # embedding is half the recurrent width
        hidden_size=d_model,
        num_layers=num_layers,
        dropout=dropout,
        max_len=max_len,
    )
    return LSTM(**config)
lstm_training.ipynb ADDED
@@ -0,0 +1,1350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6bce68a8",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2025-08-03T18:03:08.438040Z",
10
+ "iopub.status.busy": "2025-08-03T18:03:08.437435Z",
11
+ "iopub.status.idle": "2025-08-03T18:03:15.190888Z",
12
+ "shell.execute_reply": "2025-08-03T18:03:15.190285Z"
13
+ },
14
+ "papermill": {
15
+ "duration": 6.758353,
16
+ "end_time": "2025-08-03T18:03:15.192202",
17
+ "exception": false,
18
+ "start_time": "2025-08-03T18:03:08.433849",
19
+ "status": "completed"
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "import torch\n",
26
+ "import torch.nn as nn\n",
27
+ "import torch.optim as optim\n",
28
+ "from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler\n",
29
+ "from torch.nn.utils.rnn import pad_sequence\n",
30
+ "import pandas as pd\n",
31
+ "import numpy as np\n",
32
+ "from sklearn.model_selection import train_test_split\n",
33
+ "from collections import Counter\n",
34
+ "import pickle\n",
35
+ "from tqdm import tqdm\n",
36
+ "import matplotlib.pyplot as plt\n",
37
+ "import os\n",
38
+ "from datetime import datetime\n",
39
+ "from lstm import create_lstm_pii_model\n",
40
+ "from data_augmentation import calculate_class_weights"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 4,
46
+ "id": "1207cd93",
47
+ "metadata": {
48
+ "execution": {
49
+ "iopub.execute_input": "2025-08-03T18:03:15.199050Z",
50
+ "iopub.status.busy": "2025-08-03T18:03:15.198726Z",
51
+ "iopub.status.idle": "2025-08-03T18:03:15.205267Z",
52
+ "shell.execute_reply": "2025-08-03T18:03:15.204584Z"
53
+ },
54
+ "papermill": {
55
+ "duration": 0.010986,
56
+ "end_time": "2025-08-03T18:03:15.206321",
57
+ "exception": false,
58
+ "start_time": "2025-08-03T18:03:15.195335",
59
+ "status": "completed"
60
+ },
61
+ "tags": []
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "class Vocabulary:\n",
66
+ " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
67
+ " def __init__(self, max_size=100000):\n",
68
+ " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
69
+ " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
70
+ " self.word_count = Counter()\n",
71
+ " self.max_size = max_size\n",
72
+ " \n",
73
+ " def add_sentence(self, sentence):\n",
74
+ " for word in sentence:\n",
75
+ " self.word_count[word.lower()] += 1\n",
76
+ " \n",
77
+ " def build(self):\n",
78
+ " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
79
+ " for word, _ in most_common:\n",
80
+ " if word not in self.word2idx:\n",
81
+ " idx = len(self.word2idx)\n",
82
+ " self.word2idx[word] = idx\n",
83
+ " self.idx2word[idx] = word\n",
84
+ " \n",
85
+ " def __len__(self):\n",
86
+ " return len(self.word2idx)\n",
87
+ " \n",
88
+ " def encode(self, sentence):\n",
89
+ " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
90
+ " \n",
91
+ " def decode(self, indices):\n",
92
+ " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 5,
98
+ "id": "f4056292",
99
+ "metadata": {
100
+ "execution": {
101
+ "iopub.execute_input": "2025-08-03T18:03:15.212478Z",
102
+ "iopub.status.busy": "2025-08-03T18:03:15.212272Z",
103
+ "iopub.status.idle": "2025-08-03T18:03:15.217352Z",
104
+ "shell.execute_reply": "2025-08-03T18:03:15.216675Z"
105
+ },
106
+ "papermill": {
107
+ "duration": 0.009321,
108
+ "end_time": "2025-08-03T18:03:15.218370",
109
+ "exception": false,
110
+ "start_time": "2025-08-03T18:03:15.209049",
111
+ "status": "completed"
112
+ },
113
+ "tags": []
114
+ },
115
+ "outputs": [],
116
+ "source": [
117
+ "class PIIDataset(Dataset):\n",
118
+ " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
119
+ " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
120
+ " self.tokens = tokens\n",
121
+ " self.labels = labels\n",
122
+ " self.text_vocab = text_vocab\n",
123
+ " self.label_vocab = label_vocab\n",
124
+ " self.max_len = max_len\n",
125
+ " \n",
126
+ " def __len__(self):\n",
127
+ " return len(self.tokens)\n",
128
+ " \n",
129
+ " def __getitem__(self, idx):\n",
130
+ " # Add start and end tokens\n",
131
+ " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
132
+ " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
133
+ " \n",
134
+ " # Truncate if too long\n",
135
+ " if len(tokens) > self.max_len:\n",
136
+ " tokens = tokens[:self.max_len-1] + ['<end>']\n",
137
+ " labels = labels[:self.max_len-1] + ['<end>']\n",
138
+ " \n",
139
+ " # Encode\n",
140
+ " token_ids = self.text_vocab.encode(tokens)\n",
141
+ " label_ids = self.label_vocab.encode(labels)\n",
142
+ " \n",
143
+ " return torch.tensor(token_ids), torch.tensor(label_ids)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 6,
149
+ "id": "499deba2",
150
+ "metadata": {
151
+ "execution": {
152
+ "iopub.execute_input": "2025-08-03T18:03:15.224549Z",
153
+ "iopub.status.busy": "2025-08-03T18:03:15.224344Z",
154
+ "iopub.status.idle": "2025-08-03T18:03:15.227931Z",
155
+ "shell.execute_reply": "2025-08-03T18:03:15.227258Z"
156
+ },
157
+ "papermill": {
158
+ "duration": 0.00789,
159
+ "end_time": "2025-08-03T18:03:15.229026",
160
+ "exception": false,
161
+ "start_time": "2025-08-03T18:03:15.221136",
162
+ "status": "completed"
163
+ },
164
+ "tags": []
165
+ },
166
+ "outputs": [],
167
+ "source": [
168
+ "def collate_fn(batch):\n",
169
+ " \"\"\"Custom collate function for padding sequences\"\"\"\n",
170
+ " tokens, labels = zip(*batch)\n",
171
+ " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
172
+ " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
173
+ " return tokens_padded, labels_padded"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 7,
179
+ "id": "7ade0505",
180
+ "metadata": {
181
+ "execution": {
182
+ "iopub.execute_input": "2025-08-03T18:03:15.237394Z",
183
+ "iopub.status.busy": "2025-08-03T18:03:15.236977Z",
184
+ "iopub.status.idle": "2025-08-03T18:03:15.250346Z",
185
+ "shell.execute_reply": "2025-08-03T18:03:15.249624Z"
186
+ },
187
+ "papermill": {
188
+ "duration": 0.018587,
189
+ "end_time": "2025-08-03T18:03:15.251405",
190
+ "exception": false,
191
+ "start_time": "2025-08-03T18:03:15.232818",
192
+ "status": "completed"
193
+ },
194
+ "tags": []
195
+ },
196
+ "outputs": [],
197
+ "source": [
198
+ "class F1ScoreMetric:\n",
199
+ " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
200
+ " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
201
+ " self.beta = beta\n",
202
+ " self.num_classes = num_classes\n",
203
+ " self.ignore_index = ignore_index\n",
204
+ " self.label_vocab = label_vocab\n",
205
+ " self.reset()\n",
206
+ " \n",
207
+ " def reset(self):\n",
208
+ " self.true_positives = 0\n",
209
+ " self.false_positives = 0\n",
210
+ " self.false_negatives = 0\n",
211
+ " self.class_metrics = {}\n",
212
+ " \n",
213
+ " def update(self, predictions, targets):\n",
214
+ " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
215
+ " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
216
+ " \n",
217
+ " for class_id in range(1, self.num_classes):\n",
218
+ " if class_id == o_idx:\n",
219
+ " continue\n",
220
+ " \n",
221
+ " pred_mask = (predictions == class_id) & mask\n",
222
+ " true_mask = (targets == class_id) & mask\n",
223
+ " \n",
224
+ " tp = ((pred_mask) & (true_mask)).sum().item()\n",
225
+ " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
226
+ " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
227
+ " \n",
228
+ " self.true_positives += tp\n",
229
+ " self.false_positives += fp\n",
230
+ " self.false_negatives += fn\n",
231
+ " \n",
232
+ " if class_id not in self.class_metrics:\n",
233
+ " self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
234
+ " self.class_metrics[class_id]['tp'] += tp\n",
235
+ " self.class_metrics[class_id]['fp'] += fp\n",
236
+ " self.class_metrics[class_id]['fn'] += fn\n",
237
+ " \n",
238
+ " def compute(self):\n",
239
+ " beta_squared = self.beta ** 2\n",
240
+ " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
241
+ " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
242
+ " f1 = (1 + beta_squared) * precision * recall / (beta_squared * precision + recall + 1e-8)\n",
243
+ " return f1\n",
244
+ " \n",
245
+ " def get_class_metrics(self):\n",
246
+ " results = {}\n",
247
+ " for class_id, metrics in self.class_metrics.items():\n",
248
+ " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
249
+ " class_name = self.label_vocab.idx2word[class_id]\n",
250
+ " precision = metrics['tp'] / (metrics['tp'] + metrics['fp'] + 1e-8)\n",
251
+ " recall = metrics['tp'] / (metrics['tp'] + metrics['fn'] + 1e-8)\n",
252
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
253
+ " results[class_name] = {\n",
254
+ " 'precision': precision,\n",
255
+ " 'recall': recall,\n",
256
+ " 'f1': f1,\n",
257
+ " 'support': metrics['tp'] + metrics['fn']\n",
258
+ " }\n",
259
+ " return results"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 8,
265
+ "id": "361b5505",
266
+ "metadata": {
267
+ "execution": {
268
+ "iopub.execute_input": "2025-08-03T18:03:15.258002Z",
269
+ "iopub.status.busy": "2025-08-03T18:03:15.257703Z",
270
+ "iopub.status.idle": "2025-08-03T18:03:15.265171Z",
271
+ "shell.execute_reply": "2025-08-03T18:03:15.264658Z"
272
+ },
273
+ "papermill": {
274
+ "duration": 0.011955,
275
+ "end_time": "2025-08-03T18:03:15.266159",
276
+ "exception": false,
277
+ "start_time": "2025-08-03T18:03:15.254204",
278
+ "status": "completed"
279
+ },
280
+ "tags": []
281
+ },
282
+ "outputs": [],
283
+ "source": [
284
+ "class FocalLoss(nn.Module):\n",
285
+ " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
286
+ " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
287
+ " super(FocalLoss, self).__init__()\n",
288
+ " self.alpha = alpha\n",
289
+ " self.gamma = gamma\n",
290
+ " self.reduction = reduction\n",
291
+ " self.ignore_index = ignore_index\n",
292
+ " \n",
293
+ " def forward(self, inputs, targets):\n",
294
+ " ce_loss = nn.functional.cross_entropy(\n",
295
+ " inputs, targets, \n",
296
+ " weight=self.alpha, \n",
297
+ " reduction='none',\n",
298
+ " ignore_index=self.ignore_index\n",
299
+ " )\n",
300
+ " \n",
301
+ " pt = torch.exp(-ce_loss)\n",
302
+ " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
303
+ " \n",
304
+ " if self.reduction == 'mean':\n",
305
+ " return focal_loss.mean()\n",
306
+ " elif self.reduction == 'sum':\n",
307
+ " return focal_loss.sum()\n",
308
+ " else:\n",
309
+ " return focal_loss"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 9,
315
+ "id": "1de646e9",
316
+ "metadata": {
317
+ "execution": {
318
+ "iopub.execute_input": "2025-08-03T18:03:15.272639Z",
319
+ "iopub.status.busy": "2025-08-03T18:03:15.272459Z",
320
+ "iopub.status.idle": "2025-08-03T18:03:15.277673Z",
321
+ "shell.execute_reply": "2025-08-03T18:03:15.277165Z"
322
+ },
323
+ "papermill": {
324
+ "duration": 0.009528,
325
+ "end_time": "2025-08-03T18:03:15.278705",
326
+ "exception": false,
327
+ "start_time": "2025-08-03T18:03:15.269177",
328
+ "status": "completed"
329
+ },
330
+ "tags": []
331
+ },
332
+ "outputs": [],
333
+ "source": [
334
+ "def train_epoch(model, dataloader, optimizer, criterion, device, f1_metric):\n",
335
+ " \"\"\"Train for one epoch\"\"\"\n",
336
+ " model.train()\n",
337
+ " total_loss = 0\n",
338
+ " f1_metric.reset()\n",
339
+ " \n",
340
+ " progress_bar = tqdm(dataloader, desc='Training')\n",
341
+ " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
342
+ " tokens = tokens.to(device)\n",
343
+ " labels = labels.to(device)\n",
344
+ " \n",
345
+ " # Forward pass\n",
346
+ " optimizer.zero_grad()\n",
347
+ " outputs = model(tokens)\n",
348
+ " \n",
349
+ " # Reshape for loss calculation\n",
350
+ " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
351
+ " labels_flat = labels.view(-1)\n",
352
+ " \n",
353
+ " # Calculate loss and backward pass\n",
354
+ " loss = criterion(outputs_flat, labels_flat)\n",
355
+ " loss.backward()\n",
356
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
357
+ " optimizer.step()\n",
358
+ " \n",
359
+ " # Update metrics\n",
360
+ " total_loss += loss.item()\n",
361
+ " predictions = torch.argmax(outputs, dim=-1)\n",
362
+ " f1_metric.update(predictions, labels)\n",
363
+ " \n",
364
+ " # Update progress bar\n",
365
+ " progress_bar.set_postfix({\n",
366
+ " 'loss': f\"{loss.item():.4f}\",\n",
367
+ " 'f1': f\"{f1_metric.compute():.4f}\"\n",
368
+ " })\n",
369
+ " \n",
370
+ " return total_loss / len(dataloader), f1_metric.compute()"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": 10,
376
+ "id": "d1ce3b0f",
377
+ "metadata": {
378
+ "execution": {
379
+ "iopub.execute_input": "2025-08-03T18:03:15.284917Z",
380
+ "iopub.status.busy": "2025-08-03T18:03:15.284718Z",
381
+ "iopub.status.idle": "2025-08-03T18:03:15.289392Z",
382
+ "shell.execute_reply": "2025-08-03T18:03:15.288854Z"
383
+ },
384
+ "papermill": {
385
+ "duration": 0.008891,
386
+ "end_time": "2025-08-03T18:03:15.290379",
387
+ "exception": false,
388
+ "start_time": "2025-08-03T18:03:15.281488",
389
+ "status": "completed"
390
+ },
391
+ "tags": []
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "def evaluate(model, dataloader, criterion, device, f1_metric):\n",
396
+ " \"\"\"Evaluate model on validation/test set\"\"\"\n",
397
+ " model.eval()\n",
398
+ " total_loss = 0\n",
399
+ " f1_metric.reset()\n",
400
+ " \n",
401
+ " with torch.no_grad():\n",
402
+ " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
403
+ " tokens = tokens.to(device)\n",
404
+ " labels = labels.to(device)\n",
405
+ " \n",
406
+ " # Forward pass\n",
407
+ " outputs = model(tokens)\n",
408
+ " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
409
+ " labels_flat = labels.view(-1)\n",
410
+ " \n",
411
+ " # Calculate loss\n",
412
+ " loss = criterion(outputs_flat, labels_flat)\n",
413
+ " total_loss += loss.item()\n",
414
+ " \n",
415
+ " # Update metrics\n",
416
+ " predictions = torch.argmax(outputs, dim=-1)\n",
417
+ " f1_metric.update(predictions, labels)\n",
418
+ " \n",
419
+ " return total_loss / len(dataloader), f1_metric.compute()"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": 11,
425
+ "id": "da3ff80c",
426
+ "metadata": {
427
+ "execution": {
428
+ "iopub.execute_input": "2025-08-03T18:03:15.296567Z",
429
+ "iopub.status.busy": "2025-08-03T18:03:15.296378Z",
430
+ "iopub.status.idle": "2025-08-03T18:03:15.300725Z",
431
+ "shell.execute_reply": "2025-08-03T18:03:15.300185Z"
432
+ },
433
+ "papermill": {
434
+ "duration": 0.008576,
435
+ "end_time": "2025-08-03T18:03:15.301673",
436
+ "exception": false,
437
+ "start_time": "2025-08-03T18:03:15.293097",
438
+ "status": "completed"
439
+ },
440
+ "tags": []
441
+ },
442
+ "outputs": [],
443
+ "source": [
444
+ "def create_balanced_sampler(dataset, label_vocab):\n",
445
+ " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
446
+ " sample_weights = []\n",
447
+ " \n",
448
+ " for idx in range(len(dataset)):\n",
449
+ " _, labels = dataset[idx]\n",
450
+ " \n",
451
+ " # Give higher weight to samples with rare PII\n",
452
+ " min_weight = 1.0\n",
453
+ " for label_id in labels:\n",
454
+ " if label_id > 3: # Skip special tokens\n",
455
+ " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
456
+ " if label_name != 'o' and 'B-' in label_name:\n",
457
+ " min_weight = 10.0\n",
458
+ " break\n",
459
+ " \n",
460
+ " sample_weights.append(min_weight)\n",
461
+ " \n",
462
+ " sampler = WeightedRandomSampler(\n",
463
+ " weights=sample_weights,\n",
464
+ " num_samples=len(sample_weights),\n",
465
+ " replacement=True\n",
466
+ " )\n",
467
+ " \n",
468
+ " return sampler"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": 12,
474
+ "id": "69b37e68",
475
+ "metadata": {
476
+ "execution": {
477
+ "iopub.execute_input": "2025-08-03T18:03:15.307761Z",
478
+ "iopub.status.busy": "2025-08-03T18:03:15.307589Z",
479
+ "iopub.status.idle": "2025-08-03T18:03:15.311849Z",
480
+ "shell.execute_reply": "2025-08-03T18:03:15.311334Z"
481
+ },
482
+ "papermill": {
483
+ "duration": 0.008327,
484
+ "end_time": "2025-08-03T18:03:15.312778",
485
+ "exception": false,
486
+ "start_time": "2025-08-03T18:03:15.304451",
487
+ "status": "completed"
488
+ },
489
+ "tags": []
490
+ },
491
+ "outputs": [],
492
+ "source": [
493
+ "def print_label_distribution(data, title=\"Label Distribution\"):\n",
494
+ " \"\"\"Print label distribution statistics\"\"\"\n",
495
+ " label_counts = Counter()\n",
496
+ " for label_seq in data.labels:\n",
497
+ " for label in label_seq:\n",
498
+ " if label not in ['<pad>', '<start>', '<end>']:\n",
499
+ " label_counts[label] += 1\n",
500
+ " \n",
501
+ " print(f\"\\n{title}:\")\n",
502
+ " print(\"-\" * 50)\n",
503
+ " total = sum(label_counts.values())\n",
504
+ " for label, count in label_counts.most_common():\n",
505
+ " percentage = (count / total) * 100\n",
506
+ " print(f\" {label:20} : {count:8,} ({percentage:5.2f}%)\")\n",
507
+ " print(\"-\" * 50)\n",
508
+ " print(f\" {'Total':20} : {total:8,}\")"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 13,
514
+ "id": "4b1b4f86",
515
+ "metadata": {
516
+ "execution": {
517
+ "iopub.execute_input": "2025-08-03T18:03:15.319812Z",
518
+ "iopub.status.busy": "2025-08-03T18:03:15.319647Z",
519
+ "iopub.status.idle": "2025-08-03T18:03:15.323992Z",
520
+ "shell.execute_reply": "2025-08-03T18:03:15.323517Z"
521
+ },
522
+ "papermill": {
523
+ "duration": 0.00942,
524
+ "end_time": "2025-08-03T18:03:15.325043",
525
+ "exception": false,
526
+ "start_time": "2025-08-03T18:03:15.315623",
527
+ "status": "completed"
528
+ },
529
+ "tags": []
530
+ },
531
+ "outputs": [],
532
+ "source": [
533
+ "def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
534
+ " \"\"\"Save model and all necessary components for deployment\"\"\"\n",
535
+ " os.makedirs(save_dir, exist_ok=True)\n",
536
+ " \n",
537
+ " # Save model state\n",
538
+ " model_path = os.path.join(save_dir, 'pii_lstm_model.pt')\n",
539
+ " torch.save(model.state_dict(), model_path)\n",
540
+ " \n",
541
+ " # Save vocabularies\n",
542
+ " vocab_path = os.path.join(save_dir, 'vocabularies.pkl')\n",
543
+ " with open(vocab_path, 'wb') as f:\n",
544
+ " pickle.dump({\n",
545
+ " 'text_vocab': text_vocab,\n",
546
+ " 'label_vocab': label_vocab\n",
547
+ " }, f)\n",
548
+ " \n",
549
+ " # Save model configuration\n",
550
+ " config_path = os.path.join(save_dir, 'model_config.pkl')\n",
551
+ " with open(config_path, 'wb') as f:\n",
552
+ " pickle.dump(config, f)\n",
553
+ " \n",
554
+ " print(f\"\\nModel saved for deployment in '{save_dir}/' directory\")\n",
555
+ " print(\"Files saved:\")\n",
556
+ " print(f\" - {model_path}\")\n",
557
+ " print(f\" - {vocab_path}\")\n",
558
+ " print(f\" - {config_path}\")"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": 14,
564
+ "id": "31d2f1b1",
565
+ "metadata": {
566
+ "execution": {
567
+ "iopub.execute_input": "2025-08-03T18:03:15.331818Z",
568
+ "iopub.status.busy": "2025-08-03T18:03:15.331643Z",
569
+ "iopub.status.idle": "2025-08-03T18:03:15.347264Z",
570
+ "shell.execute_reply": "2025-08-03T18:03:15.346735Z"
571
+ },
572
+ "papermill": {
573
+ "duration": 0.020356,
574
+ "end_time": "2025-08-03T18:03:15.348292",
575
+ "exception": false,
576
+ "start_time": "2025-08-03T18:03:15.327936",
577
+ "status": "completed"
578
+ },
579
+ "tags": []
580
+ },
581
+ "outputs": [],
582
+ "source": [
583
+ "def train_lstm_pii_model(\n",
584
+ " data_path,\n",
585
+ " num_epochs=30,\n",
586
+ " batch_size=32,\n",
587
+ " learning_rate=3e-4,\n",
588
+ " use_focal_loss=True,\n",
589
+ " focal_gamma=2.0,\n",
590
+ " device='cuda',\n",
591
+ "):\n",
592
+ " \"\"\"Main training function for LSTM model\"\"\"\n",
593
+ " \n",
594
+ " # Load data\n",
595
+ " print(\"Loading augmented data...\")\n",
596
+ " data = pd.read_json(data_path, lines=True)\n",
597
+ " print(f\"Total samples: {len(data)}\")\n",
598
+ " \n",
599
+ " # Print initial label distribution\n",
600
+ " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
601
+ " \n",
602
+ " # Build vocabularies\n",
603
+ " print(\"\\nBuilding vocabularies...\")\n",
604
+ " text_vocab = Vocabulary(max_size=100000)\n",
605
+ " label_vocab = Vocabulary(max_size=50)\n",
606
+ " \n",
607
+ " for tokens in data.tokens:\n",
608
+ " text_vocab.add_sentence(tokens)\n",
609
+ " for labels in data.labels:\n",
610
+ " label_vocab.add_sentence(labels)\n",
611
+ " \n",
612
+ " text_vocab.build()\n",
613
+ " label_vocab.build()\n",
614
+ " \n",
615
+ " print(f\"\\nVocabulary sizes:\")\n",
616
+ " print(f\" - Text vocabulary: {len(text_vocab):,}\")\n",
617
+ " print(f\" - Label vocabulary: {len(label_vocab)}\")\n",
618
+ " \n",
619
+ " # Calculate class weights\n",
620
+ " class_weights = calculate_class_weights(data, label_vocab)\n",
621
+ " class_weights = class_weights.to(device)\n",
622
+ " \n",
623
+ " # Split data\n",
624
+ " X_train, X_val, y_train, y_val = train_test_split(\n",
625
+ " data.tokens.tolist(),\n",
626
+ " data.labels.tolist(),\n",
627
+ " test_size=0.2,\n",
628
+ " random_state=42\n",
629
+ " )\n",
630
+ " \n",
631
+ " print(f\"\\nData split:\")\n",
632
+ " print(f\" - Train samples: {len(X_train):,}\")\n",
633
+ " print(f\" - Validation samples: {len(X_val):,}\")\n",
634
+ " \n",
635
+ " # Create datasets and dataloaders\n",
636
+ " max_seq_len = 512\n",
637
+ " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
638
+ " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
639
+ " \n",
640
+ " # Use balanced sampler for training\n",
641
+ " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
642
+ " \n",
643
+ " train_loader = DataLoader(\n",
644
+ " train_dataset, \n",
645
+ " batch_size=batch_size,\n",
646
+ " sampler=train_sampler,\n",
647
+ " collate_fn=collate_fn,\n",
648
+ " num_workers=0\n",
649
+ " )\n",
650
+ " \n",
651
+ " val_loader = DataLoader(\n",
652
+ " val_dataset, \n",
653
+ " batch_size=batch_size,\n",
654
+ " shuffle=False, \n",
655
+ " collate_fn=collate_fn,\n",
656
+ " num_workers=0\n",
657
+ " )\n",
658
+ " \n",
659
+ " # Model configuration\n",
660
+ " model_config = {\n",
661
+ " 'vocab_size': len(text_vocab),\n",
662
+ " 'num_classes': len(label_vocab),\n",
663
+ " 'd_model': 256,\n",
664
+ " 'num_heads': 8, # Not used by LSTM, kept for compatibility\n",
665
+ " 'd_ff': 512, # Not used by LSTM, kept for compatibility\n",
666
+ " 'num_layers': 2, # Number of LSTM layers\n",
667
+ " 'dropout': 0.1,\n",
668
+ " 'max_len': max_seq_len\n",
669
+ " }\n",
670
+ " \n",
671
+ " # Create model\n",
672
+ " print(\"\\nCreating LSTM model...\")\n",
673
+ " model = create_lstm_pii_model(**model_config).to(device)\n",
674
+ " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
675
+ " \n",
676
+ " # Print model architecture\n",
677
+ " print(\"\\nModel Architecture:\")\n",
678
+ " print(f\" - Embedding: {model_config['vocab_size']} -> {model_config['d_model'] // 2}\")\n",
679
+ " print(f\" - Bidirectional LSTM: {model_config['num_layers']} layers, hidden size: {model_config['d_model']}\")\n",
680
+ " print(f\" - Output: {model_config['d_model'] * 2} -> {model_config['num_classes']}\")\n",
681
+ " \n",
682
+ " # Setup loss function\n",
683
+ " if use_focal_loss:\n",
684
+ " criterion = FocalLoss(\n",
685
+ " alpha=class_weights,\n",
686
+ " gamma=focal_gamma,\n",
687
+ " ignore_index=0\n",
688
+ " )\n",
689
+ " print(f\"\\nUsing Focal Loss with gamma={focal_gamma}\")\n",
690
+ " else:\n",
691
+ " criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
692
+ " print(\"\\nUsing Cross Entropy Loss\")\n",
693
+ " \n",
694
+ " # Setup optimizer and scheduler\n",
695
+ " optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
696
+ " scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
697
+ " optimizer, \n",
698
+ " mode='min',\n",
699
+ " patience=3, \n",
700
+ " factor=0.5,\n",
701
+ " min_lr=1e-6\n",
702
+ " )\n",
703
+ " \n",
704
+ " # Metrics\n",
705
+ " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
706
+ " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
707
+ " \n",
708
+ " # Training loop\n",
709
+ " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
710
+ " best_val_f1 = 0\n",
711
+ " patience = 7\n",
712
+ " patience_counter = 0\n",
713
+ " \n",
714
+ " print(\"\\nStarting training...\")\n",
715
+ " print(\"=\" * 60)\n",
716
+ " \n",
717
+ " for epoch in range(num_epochs):\n",
718
+ " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
719
+ " \n",
720
+ " # Train and validate\n",
721
+ " train_loss, train_f1 = train_epoch(\n",
722
+ " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
723
+ " )\n",
724
+ " val_loss, val_f1 = evaluate(\n",
725
+ " model, val_loader, criterion, device, f1_metric_val\n",
726
+ " )\n",
727
+ " \n",
728
+ " # Step scheduler based on validation loss\n",
729
+ " scheduler.step(val_loss)\n",
730
+ " \n",
731
+ " # Store metrics\n",
732
+ " train_losses.append(train_loss)\n",
733
+ " train_f1s.append(train_f1)\n",
734
+ " val_losses.append(val_loss)\n",
735
+ " val_f1s.append(val_f1)\n",
736
+ " \n",
737
+ " # Print epoch results\n",
738
+ " print(f\"Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}\")\n",
739
+ " print(f\"Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}\")\n",
740
+ " print(f\"Learning rate: {optimizer.param_groups[0]['lr']:.6f}\")\n",
741
+ " \n",
742
+ " # Save best model\n",
743
+ " if val_f1 > best_val_f1:\n",
744
+ " best_val_f1 = val_f1\n",
745
+ " patience_counter = 0\n",
746
+ " \n",
747
+ " # Save complete checkpoint\n",
748
+ " checkpoint = {\n",
749
+ " 'epoch': epoch,\n",
750
+ " 'model_state_dict': model.state_dict(),\n",
751
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
752
+ " 'scheduler_state_dict': scheduler.state_dict(),\n",
753
+ " 'train_loss': train_loss,\n",
754
+ " 'val_loss': val_loss,\n",
755
+ " 'train_f1': train_f1,\n",
756
+ " 'val_f1': val_f1,\n",
757
+ " 'text_vocab': text_vocab,\n",
758
+ " 'label_vocab': label_vocab,\n",
759
+ " 'model_config': model_config\n",
760
+ " }\n",
761
+ " torch.save(checkpoint, 'best_lstm_checkpoint.pt')\n",
762
+ " \n",
763
+ " print(f\"βœ“ Saved best model with F1: {val_f1:.4f}\")\n",
764
+ " else:\n",
765
+ " patience_counter += 1\n",
766
+ " \n",
767
+ " # Early stopping\n",
768
+ " if patience_counter >= patience and epoch > 10:\n",
769
+ " print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
770
+ " break\n",
771
+ " \n",
772
+ " # Plot training curves\n",
773
+ " plt.figure(figsize=(12, 5))\n",
774
+ " \n",
775
+ " plt.subplot(1, 2, 1)\n",
776
+ " plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
777
+ " plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
778
+ " plt.xlabel('Epoch')\n",
779
+ " plt.ylabel('Loss')\n",
780
+ " plt.title('Training and Validation Loss')\n",
781
+ " plt.legend()\n",
782
+ " plt.grid(True, alpha=0.3)\n",
783
+ " \n",
784
+ " plt.subplot(1, 2, 2)\n",
785
+ " plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
786
+ " plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
787
+ " plt.axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')\n",
788
+ " plt.xlabel('Epoch')\n",
789
+ " plt.ylabel('F1 Score')\n",
790
+ " plt.title('Training and Validation F1 Score')\n",
791
+ " plt.legend()\n",
792
+ " plt.grid(True, alpha=0.3)\n",
793
+ " \n",
794
+ " plt.tight_layout()\n",
795
+ " plt.savefig('lstm_training_curves.png', dpi=300, bbox_inches='tight')\n",
796
+ " plt.close()\n",
797
+ " \n",
798
+ " print(f\"\\n{'='*60}\")\n",
799
+ " print(f\"Training completed!\")\n",
800
+ " print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
801
+ " print(f\"Training curves saved to: lstm_training_curves.png\")\n",
802
+ " \n",
803
+ " # Save model for deployment\n",
804
+ " save_model(model, text_vocab, label_vocab, model_config, 'saved_lstm_model')\n",
805
+ " \n",
806
+ " return model, text_vocab, label_vocab"
807
+ ]
808
+ },
809
+ {
810
+ "cell_type": "code",
811
+ "execution_count": null,
812
+ "id": "fcb2b401",
813
+ "metadata": {
814
+ "execution": {
815
+ "iopub.execute_input": "2025-08-03T18:03:15.354835Z",
816
+ "iopub.status.busy": "2025-08-03T18:03:15.354423Z",
817
+ "iopub.status.idle": "2025-08-04T04:06:32.402286Z",
818
+ "shell.execute_reply": "2025-08-04T04:06:32.401401Z"
819
+ },
820
+ "papermill": {
821
+ "duration": 36197.052354,
822
+ "end_time": "2025-08-04T04:06:32.403447",
823
+ "exception": false,
824
+ "start_time": "2025-08-03T18:03:15.351093",
825
+ "status": "completed"
826
+ },
827
+ "tags": []
828
+ },
829
+ "outputs": [
830
+ {
831
+ "name": "stdout",
832
+ "output_type": "stream",
833
+ "text": [
834
+ "Using device: cuda\n",
835
+ "Loading augmented data...\n",
836
+ "Total samples: 19694\n",
837
+ "\n",
838
+ "Label Distribution in Augmented Data:\n",
839
+ "--------------------------------------------------\n",
840
+ " O : 5,082,150 (99.33%)\n",
841
+ " I-STREET_ADDRESS : 15,650 ( 0.31%)\n",
842
+ " B-ID_NUM : 2,505 ( 0.05%)\n",
843
+ " B-EMAIL : 2,488 ( 0.05%)\n",
844
+ " B-URL_PERSONAL : 2,478 ( 0.05%)\n",
845
+ " B-STREET_ADDRESS : 2,452 ( 0.05%)\n",
846
+ " B-PHONE_NUM : 2,450 ( 0.05%)\n",
847
+ " B-USERNAME : 2,210 ( 0.04%)\n",
848
+ " B-NAME_STUDENT : 1,968 ( 0.04%)\n",
849
+ " I-NAME_STUDENT : 1,735 ( 0.03%)\n",
850
+ " I-PHONE_NUM : 500 ( 0.01%)\n",
851
+ " I-URL_PERSONAL : 1 ( 0.00%)\n",
852
+ " I-ID_NUM : 1 ( 0.00%)\n",
853
+ "--------------------------------------------------\n",
854
+ " Total : 5,116,588\n",
855
+ "\n",
856
+ "Building vocabularies...\n",
857
+ "\n",
858
+ "Vocabulary sizes:\n",
859
+ " - Text vocabulary: 65,295\n",
860
+ " - Label vocabulary: 17\n",
861
+ "\n",
862
+ "Data split:\n",
863
+ " - Train samples: 15,755\n",
864
+ " - Validation samples: 3,939\n",
865
+ "\n",
866
+ "Creating LSTM model...\n",
867
+ "Model parameters: 10,729,873\n",
868
+ "\n",
869
+ "Model Architecture:\n",
870
+ " - Embedding: 65295 -> 128\n",
871
+ " - Bidirectional LSTM: 2 layers, hidden size: 256\n",
872
+ " - Output: 512 -> 17\n",
873
+ "\n",
874
+ "Using Focal Loss with gamma=2.0\n",
875
+ "\n",
876
+ "Starting training...\n",
877
+ "============================================================\n",
878
+ "\n",
879
+ "Epoch 1/20\n"
880
+ ]
881
+ },
882
+ {
883
+ "name": "stderr",
884
+ "output_type": "stream",
885
+ "text": [
886
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:41<00:00, 3.49s/it, loss=0.0000, f1=0.1535]\n",
887
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:27<00:00, 1.41it/s]\n"
888
+ ]
889
+ },
890
+ {
891
+ "name": "stdout",
892
+ "output_type": "stream",
893
+ "text": [
894
+ "Train Loss: 0.0002, Train F1: 0.1535\n",
895
+ "Val Loss: 0.0001, Val F1: 0.4344\n",
896
+ "Learning rate: 0.000300\n",
897
+ "βœ“ Saved best model with F1: 0.4344\n",
898
+ "\n",
899
+ "Epoch 2/20\n"
900
+ ]
901
+ },
902
+ {
903
+ "name": "stderr",
904
+ "output_type": "stream",
905
+ "text": [
906
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:39<00:00, 3.49s/it, loss=0.0000, f1=0.5546]\n",
907
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
908
+ ]
909
+ },
910
+ {
911
+ "name": "stdout",
912
+ "output_type": "stream",
913
+ "text": [
914
+ "Train Loss: 0.0000, Train F1: 0.5546\n",
915
+ "Val Loss: 0.0000, Val F1: 0.6417\n",
916
+ "Learning rate: 0.000300\n",
917
+ "βœ“ Saved best model with F1: 0.6417\n",
918
+ "\n",
919
+ "Epoch 3/20\n"
920
+ ]
921
+ },
922
+ {
923
+ "name": "stderr",
924
+ "output_type": "stream",
925
+ "text": [
926
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:39<00:00, 3.49s/it, loss=0.0000, f1=0.7183]\n",
927
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
928
+ ]
929
+ },
930
+ {
931
+ "name": "stdout",
932
+ "output_type": "stream",
933
+ "text": [
934
+ "Train Loss: 0.0000, Train F1: 0.7183\n",
935
+ "Val Loss: 0.0000, Val F1: 0.7736\n",
936
+ "Learning rate: 0.000300\n",
937
+ "βœ“ Saved best model with F1: 0.7736\n",
938
+ "\n",
939
+ "Epoch 4/20\n"
940
+ ]
941
+ },
942
+ {
943
+ "name": "stderr",
944
+ "output_type": "stream",
945
+ "text": [
946
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:39<00:00, 3.49s/it, loss=0.0000, f1=0.8117]\n",
947
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
948
+ ]
949
+ },
950
+ {
951
+ "name": "stdout",
952
+ "output_type": "stream",
953
+ "text": [
954
+ "Train Loss: 0.0000, Train F1: 0.8117\n",
955
+ "Val Loss: 0.0000, Val F1: 0.8568\n",
956
+ "Learning rate: 0.000300\n",
957
+ "βœ“ Saved best model with F1: 0.8568\n",
958
+ "\n",
959
+ "Epoch 5/20\n"
960
+ ]
961
+ },
962
+ {
963
+ "name": "stderr",
964
+ "output_type": "stream",
965
+ "text": [
966
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:36<00:00, 3.48s/it, loss=0.0000, f1=0.8686]\n",
967
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.41it/s]\n"
968
+ ]
969
+ },
970
+ {
971
+ "name": "stdout",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "Train Loss: 0.0000, Train F1: 0.8686\n",
975
+ "Val Loss: 0.0000, Val F1: 0.8847\n",
976
+ "Learning rate: 0.000300\n",
977
+ "βœ“ Saved best model with F1: 0.8847\n",
978
+ "\n",
979
+ "Epoch 6/20\n"
980
+ ]
981
+ },
982
+ {
983
+ "name": "stderr",
984
+ "output_type": "stream",
985
+ "text": [
986
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:38<00:00, 3.49s/it, loss=0.0000, f1=0.8942]\n",
987
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:30<00:00, 1.38it/s]\n"
988
+ ]
989
+ },
990
+ {
991
+ "name": "stdout",
992
+ "output_type": "stream",
993
+ "text": [
994
+ "Train Loss: 0.0000, Train F1: 0.8942\n",
995
+ "Val Loss: 0.0000, Val F1: 0.8983\n",
996
+ "Learning rate: 0.000300\n",
997
+ "βœ“ Saved best model with F1: 0.8983\n",
998
+ "\n",
999
+ "Epoch 7/20\n"
1000
+ ]
1001
+ },
1002
+ {
1003
+ "name": "stderr",
1004
+ "output_type": "stream",
1005
+ "text": [
1006
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:36<00:00, 3.48s/it, loss=0.0000, f1=0.9097]\n",
1007
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
1008
+ ]
1009
+ },
1010
+ {
1011
+ "name": "stdout",
1012
+ "output_type": "stream",
1013
+ "text": [
1014
+ "Train Loss: 0.0000, Train F1: 0.9097\n",
1015
+ "Val Loss: 0.0000, Val F1: 0.9147\n",
1016
+ "Learning rate: 0.000300\n",
1017
+ "βœ“ Saved best model with F1: 0.9147\n",
1018
+ "\n",
1019
+ "Epoch 8/20\n"
1020
+ ]
1021
+ },
1022
+ {
1023
+ "name": "stderr",
1024
+ "output_type": "stream",
1025
+ "text": [
1026
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:36<00:00, 3.48s/it, loss=0.0000, f1=0.9271]\n",
1027
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1028
+ ]
1029
+ },
1030
+ {
1031
+ "name": "stdout",
1032
+ "output_type": "stream",
1033
+ "text": [
1034
+ "Train Loss: 0.0000, Train F1: 0.9271\n",
1035
+ "Val Loss: 0.0000, Val F1: 0.9386\n",
1036
+ "Learning rate: 0.000300\n",
1037
+ "βœ“ Saved best model with F1: 0.9386\n",
1038
+ "\n",
1039
+ "Epoch 9/20\n"
1040
+ ]
1041
+ },
1042
+ {
1043
+ "name": "stderr",
1044
+ "output_type": "stream",
1045
+ "text": [
1046
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:36<00:00, 3.48s/it, loss=0.0000, f1=0.9362]\n",
1047
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
1048
+ ]
1049
+ },
1050
+ {
1051
+ "name": "stdout",
1052
+ "output_type": "stream",
1053
+ "text": [
1054
+ "Train Loss: 0.0000, Train F1: 0.9362\n",
1055
+ "Val Loss: 0.0000, Val F1: 0.9371\n",
1056
+ "Learning rate: 0.000300\n",
1057
+ "\n",
1058
+ "Epoch 10/20\n"
1059
+ ]
1060
+ },
1061
+ {
1062
+ "name": "stderr",
1063
+ "output_type": "stream",
1064
+ "text": [
1065
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:37<00:00, 3.48s/it, loss=0.0000, f1=0.9457]\n",
1066
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
1067
+ ]
1068
+ },
1069
+ {
1070
+ "name": "stdout",
1071
+ "output_type": "stream",
1072
+ "text": [
1073
+ "Train Loss: 0.0000, Train F1: 0.9457\n",
1074
+ "Val Loss: 0.0000, Val F1: 0.9418\n",
1075
+ "Learning rate: 0.000150\n",
1076
+ "βœ“ Saved best model with F1: 0.9418\n",
1077
+ "\n",
1078
+ "Epoch 11/20\n"
1079
+ ]
1080
+ },
1081
+ {
1082
+ "name": "stderr",
1083
+ "output_type": "stream",
1084
+ "text": [
1085
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:41<00:00, 3.49s/it, loss=0.0000, f1=0.9561]\n",
1086
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1087
+ ]
1088
+ },
1089
+ {
1090
+ "name": "stdout",
1091
+ "output_type": "stream",
1092
+ "text": [
1093
+ "Train Loss: 0.0000, Train F1: 0.9561\n",
1094
+ "Val Loss: 0.0000, Val F1: 0.9471\n",
1095
+ "Learning rate: 0.000150\n",
1096
+ "βœ“ Saved best model with F1: 0.9471\n",
1097
+ "\n",
1098
+ "Epoch 12/20\n"
1099
+ ]
1100
+ },
1101
+ {
1102
+ "name": "stderr",
1103
+ "output_type": "stream",
1104
+ "text": [
1105
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:38<00:00, 3.49s/it, loss=0.0000, f1=0.9579]\n",
1106
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
1107
+ ]
1108
+ },
1109
+ {
1110
+ "name": "stdout",
1111
+ "output_type": "stream",
1112
+ "text": [
1113
+ "Train Loss: 0.0000, Train F1: 0.9579\n",
1114
+ "Val Loss: 0.0000, Val F1: 0.9463\n",
1115
+ "Learning rate: 0.000150\n",
1116
+ "\n",
1117
+ "Epoch 13/20\n"
1118
+ ]
1119
+ },
1120
+ {
1121
+ "name": "stderr",
1122
+ "output_type": "stream",
1123
+ "text": [
1124
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:37<00:00, 3.48s/it, loss=0.0000, f1=0.9590]\n",
1125
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:28<00:00, 1.40it/s]\n"
1126
+ ]
1127
+ },
1128
+ {
1129
+ "name": "stdout",
1130
+ "output_type": "stream",
1131
+ "text": [
1132
+ "Train Loss: 0.0000, Train F1: 0.9590\n",
1133
+ "Val Loss: 0.0000, Val F1: 0.9526\n",
1134
+ "Learning rate: 0.000150\n",
1135
+ "βœ“ Saved best model with F1: 0.9526\n",
1136
+ "\n",
1137
+ "Epoch 14/20\n"
1138
+ ]
1139
+ },
1140
+ {
1141
+ "name": "stderr",
1142
+ "output_type": "stream",
1143
+ "text": [
1144
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:37<00:00, 3.48s/it, loss=0.0000, f1=0.9665]\n",
1145
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1146
+ ]
1147
+ },
1148
+ {
1149
+ "name": "stdout",
1150
+ "output_type": "stream",
1151
+ "text": [
1152
+ "Train Loss: 0.0000, Train F1: 0.9665\n",
1153
+ "Val Loss: 0.0000, Val F1: 0.9499\n",
1154
+ "Learning rate: 0.000075\n",
1155
+ "\n",
1156
+ "Epoch 15/20\n"
1157
+ ]
1158
+ },
1159
+ {
1160
+ "name": "stderr",
1161
+ "output_type": "stream",
1162
+ "text": [
1163
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:38<00:00, 3.49s/it, loss=0.0000, f1=0.9674]\n",
1164
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1165
+ ]
1166
+ },
1167
+ {
1168
+ "name": "stdout",
1169
+ "output_type": "stream",
1170
+ "text": [
1171
+ "Train Loss: 0.0000, Train F1: 0.9674\n",
1172
+ "Val Loss: 0.0000, Val F1: 0.9518\n",
1173
+ "Learning rate: 0.000075\n",
1174
+ "\n",
1175
+ "Epoch 16/20\n"
1176
+ ]
1177
+ },
1178
+ {
1179
+ "name": "stderr",
1180
+ "output_type": "stream",
1181
+ "text": [
1182
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:41<00:00, 3.49s/it, loss=0.0000, f1=0.9679]\n",
1183
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:30<00:00, 1.37it/s]\n"
1184
+ ]
1185
+ },
1186
+ {
1187
+ "name": "stdout",
1188
+ "output_type": "stream",
1189
+ "text": [
1190
+ "Train Loss: 0.0000, Train F1: 0.9679\n",
1191
+ "Val Loss: 0.0000, Val F1: 0.9509\n",
1192
+ "Learning rate: 0.000075\n",
1193
+ "\n",
1194
+ "Epoch 17/20\n"
1195
+ ]
1196
+ },
1197
+ {
1198
+ "name": "stderr",
1199
+ "output_type": "stream",
1200
+ "text": [
1201
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:40<00:00, 3.49s/it, loss=0.0000, f1=0.9706]\n",
1202
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "name": "stdout",
1207
+ "output_type": "stream",
1208
+ "text": [
1209
+ "Train Loss: 0.0000, Train F1: 0.9706\n",
1210
+ "Val Loss: 0.0000, Val F1: 0.9525\n",
1211
+ "Learning rate: 0.000075\n",
1212
+ "\n",
1213
+ "Epoch 18/20\n"
1214
+ ]
1215
+ },
1216
+ {
1217
+ "name": "stderr",
1218
+ "output_type": "stream",
1219
+ "text": [
1220
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:37<00:00, 3.48s/it, loss=0.0000, f1=0.9738]\n",
1221
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1222
+ ]
1223
+ },
1224
+ {
1225
+ "name": "stdout",
1226
+ "output_type": "stream",
1227
+ "text": [
1228
+ "Train Loss: 0.0000, Train F1: 0.9738\n",
1229
+ "Val Loss: 0.0000, Val F1: 0.9509\n",
1230
+ "Learning rate: 0.000037\n",
1231
+ "\n",
1232
+ "Epoch 19/20\n"
1233
+ ]
1234
+ },
1235
+ {
1236
+ "name": "stderr",
1237
+ "output_type": "stream",
1238
+ "text": [
1239
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:39<00:00, 3.49s/it, loss=0.0000, f1=0.9722]\n",
1240
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:29<00:00, 1.38it/s]\n"
1241
+ ]
1242
+ },
1243
+ {
1244
+ "name": "stdout",
1245
+ "output_type": "stream",
1246
+ "text": [
1247
+ "Train Loss: 0.0000, Train F1: 0.9722\n",
1248
+ "Val Loss: 0.0000, Val F1: 0.9524\n",
1249
+ "Learning rate: 0.000037\n",
1250
+ "\n",
1251
+ "Epoch 20/20\n"
1252
+ ]
1253
+ },
1254
+ {
1255
+ "name": "stderr",
1256
+ "output_type": "stream",
1257
+ "text": [
1258
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [28:37<00:00, 3.48s/it, loss=0.0000, f1=0.9747]\n",
1259
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [01:30<00:00, 1.38it/s]\n"
1260
+ ]
1261
+ },
1262
+ {
1263
+ "name": "stdout",
1264
+ "output_type": "stream",
1265
+ "text": [
1266
+ "Train Loss: 0.0000, Train F1: 0.9747\n",
1267
+ "Val Loss: 0.0000, Val F1: 0.9535\n",
1268
+ "Learning rate: 0.000037\n",
1269
+ "βœ“ Saved best model with F1: 0.9535\n",
1270
+ "\n",
1271
+ "============================================================\n",
1272
+ "Training completed!\n",
1273
+ "Best validation F1: 0.9535\n",
1274
+ "Training curves saved to: lstm_training_curves.png\n",
1275
+ "\n",
1276
+ "Model saved for deployment in 'saved_lstm_model/' directory\n",
1277
+ "Files saved:\n",
1278
+ " - saved_lstm_model/pii_lstm_model.pt\n",
1279
+ " - saved_lstm_model/vocabularies.pkl\n",
1280
+ " - saved_lstm_model/model_config.pkl\n"
1281
+ ]
1282
+ }
1283
+ ],
1284
+ "source": [
1285
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1286
+ "print(f\"Using device: {device}\")\n",
1287
+ "\n",
1288
+ "model, text_vocab, label_vocab = train_lstm_pii_model(\n",
1289
+ " data_path='train_augmented.json',\n",
1290
+ " num_epochs=20,\n",
1291
+ " batch_size=32,\n",
1292
+ " learning_rate=3e-4,\n",
1293
+ " use_focal_loss=True,\n",
1294
+ " focal_gamma=2.0,\n",
1295
+ " device=device\n",
1296
+ ")"
1297
+ ]
1298
+ }
1299
+ ],
1300
+ "metadata": {
1301
+ "kaggle": {
1302
+ "accelerator": "nvidiaTeslaT4",
1303
+ "dataSources": [
1304
+ {
1305
+ "isSourceIdPinned": true,
1306
+ "modelId": 419045,
1307
+ "modelInstanceId": 400879,
1308
+ "sourceId": 504813,
1309
+ "sourceType": "modelInstanceVersion"
1310
+ }
1311
+ ],
1312
+ "dockerImageVersionId": 31090,
1313
+ "isGpuEnabled": true,
1314
+ "isInternetEnabled": true,
1315
+ "language": "python",
1316
+ "sourceType": "notebook"
1317
+ },
1318
+ "kernelspec": {
1319
+ "display_name": "py310-torch",
1320
+ "language": "python",
1321
+ "name": "python3"
1322
+ },
1323
+ "language_info": {
1324
+ "codemirror_mode": {
1325
+ "name": "ipython",
1326
+ "version": 3
1327
+ },
1328
+ "file_extension": ".py",
1329
+ "mimetype": "text/x-python",
1330
+ "name": "python",
1331
+ "nbconvert_exporter": "python",
1332
+ "pygments_lexer": "ipython3",
1333
+ "version": "3.10.18"
1334
+ },
1335
+ "papermill": {
1336
+ "default_parameters": {},
1337
+ "duration": 36216.685618,
1338
+ "end_time": "2025-08-04T04:06:35.164363",
1339
+ "environment_variables": {},
1340
+ "exception": null,
1341
+ "input_path": "__notebook__.ipynb",
1342
+ "output_path": "__notebook__.ipynb",
1343
+ "parameters": {},
1344
+ "start_time": "2025-08-03T18:02:58.478745",
1345
+ "version": "2.6.0"
1346
+ }
1347
+ },
1348
+ "nbformat": 4,
1349
+ "nbformat_minor": 5
1350
+ }
requirements.txt ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ asttokens==3.0.0
5
+ Brotli==1.1.0
6
+ brotlicffi==1.0.9.2
7
+ certifi==2025.7.14
8
+ cffi==1.17.1
9
+ charset-normalizer==3.3.2
10
+ click==8.2.1
11
+ colorama==0.4.6
12
+ comm==0.2.3
13
+ contourpy==1.3.2
14
+ cycler==0.12.1
15
+ debugpy==1.8.15
16
+ decorator==5.2.1
17
+ exceptiongroup==1.3.0
18
+ executing==2.2.0
19
+ Faker==37.5.3
20
+ fastapi==0.116.1
21
+ ffmpy==0.6.1
22
+ filelock==3.17.0
23
+ fonttools==4.59.0
24
+ fsspec==2025.7.0
25
+ gmpy2==2.2.1
26
+ gradio==5.39.0
27
+ gradio_client==1.11.0
28
+ groovy==0.1.2
29
+ h11==0.16.0
30
+ httpcore==1.0.9
31
+ httpx==0.28.1
32
+ huggingface-hub==0.34.3
33
+ idna==3.7
34
+ importlib_metadata==8.7.0
35
+ ipykernel==6.30.0
36
+ ipython==8.37.0
37
+ jedi==0.19.2
38
+ Jinja2==3.1.6
39
+ joblib==1.5.1
40
+ jupyter_client==8.6.3
41
+ jupyter_core==5.8.1
42
+ kiwisolver==1.4.8
43
+ markdown-it-py==3.0.0
44
+ MarkupSafe==3.0.2
45
+ matplotlib==3.10.5
46
+ matplotlib-inline==0.1.7
47
+ mdurl==0.1.2
48
+ mkl_fft==1.3.11
49
+ mkl_random==1.2.8
50
+ mkl-service==2.4.0
51
+ mpmath==1.3.0
52
+ nest_asyncio==1.6.0
53
+ networkx==3.4.2
54
+ numpy==2.0.1
55
+ orjson==3.11.1
56
+ packaging==25.0
57
+ pandas==2.3.1
58
+ parso==0.8.4
59
+ pickleshare==0.7.5
60
+ pillow==11.1.0
61
+ pip==25.1
62
+ platformdirs==4.3.8
63
+ prompt_toolkit==3.0.51
64
+ psutil==7.0.0
65
+ pure_eval==0.2.3
66
+ pycparser==2.21
67
+ pydantic==2.11.7
68
+ pydantic_core==2.33.2
69
+ pydub==0.25.1
70
+ Pygments==2.19.2
71
+ pyparsing==3.2.3
72
+ PySocks==1.7.1
73
+ python-dateutil==2.9.0.post0
74
+ python-multipart==0.0.20
75
+ pytz==2025.2
76
+ pywin32==311
77
+ PyYAML==6.0.2
78
+ pyzmq==27.0.0
79
+ requests==2.32.4
80
+ rich==14.1.0
81
+ ruff==0.12.7
82
+ safehttpx==0.1.6
83
+ scikit-learn==1.7.1
84
+ scipy==1.15.3
85
+ semantic-version==2.10.0
86
+ setuptools==78.1.1
87
+ shellingham==1.5.4
88
+ six==1.17.0
89
+ sniffio==1.3.1
90
+ stack_data==0.6.3
91
+ starlette==0.47.2
92
+ sympy==1.13.1
93
+ threadpoolctl==3.6.0
94
+ tomlkit==0.13.3
95
+ torch==2.5.1
96
+ torchaudio==2.5.1
97
+ torchvision==0.20.1
98
+ tornado==6.5.1
99
+ tqdm==4.67.1
100
+ traitlets==5.14.3
101
+ typer==0.16.0
102
+ typing_extensions==4.12.2
103
+ typing-inspection==0.4.1
104
+ tzdata==2025.2
105
+ urllib3==2.5.0
106
+ uvicorn==0.35.0
107
+ wcwidth==0.2.13
108
+ websockets==15.0.1
109
+ wheel==0.45.1
110
+ win-inet-pton==1.1.0
111
+ zipp==3.23.0
saved_lstm/best_lstm_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:506bc8e4cf77c01844014f0b0f4b2a89235ba256be99cc221bd654722cbe1511
3
+ size 131276970
saved_lstm/lstm_training_curves.png ADDED

Git LFS Details

  • SHA256: 7541582e6aaa9ab04826f862fe4c2eb178a7cc0c428aecc6609ed6a131817339
  • Pointer size: 131 Bytes
  • Size of remote file: 206 kB
saved_lstm/model_config.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ca75bdd6914a57731e2aaa0d62a94e31263e92108d1cc4357f701a2bb92a7e7
3
+ size 132
saved_lstm/pii_lstm_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7abbce5ad1109cfbade6c4d12d2ae3fc4247e187287d8ca270603e4767613ae
3
+ size 42936702
saved_lstm/vocabularies.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9242868bc0aaefd68419706becdee2cf7336799c047886e1e1639af8e1726978
3
+ size 1996397
saved_transformer/best_transformer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:921e9615e48200b1970f359b8e8343d310f3125d7f070e37b666d38b687e0778
3
+ size 229015868
saved_transformer/model_config.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1caf1f3dd5bcc8ca70bff3443223a0a68636d3f3da5c32897170998e2ca0bc83
3
+ size 132
saved_transformer/pii_transformer_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b31abf1d79f3cd02870dfd47eb3deef0d4456d20023b906bd55c0124cd374b3c
3
+ size 75867152
saved_transformer/transformer_training_curves.png ADDED

Git LFS Details

  • SHA256: 0e4aa3d3f521646f300a70f2b189abc66c3d5c0c2bae7eb920d4ce57f3a24b50
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
saved_transformer/vocabularies.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9242868bc0aaefd68419706becdee2cf7336799c047886e1e1639af8e1726978
3
+ size 1996397
train.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8276cd44f3b2eb357dfb405b3c5d8e9f821388e984cbf66e92e7df03f1b13117
3
+ size 109496478
train_augmented.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f195815ab3d4b50ec302f6fc4ab07770c440608054fab9a83136229c0b723e8
3
+ size 59487171
transformer.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.

    Args:
        q: queries, shape (batch, heads, len_q, d_k).
        k: keys, shape (batch, heads, len_k, d_k).
        v: values, shape (batch, heads, len_k, d_v).
        mask: optional tensor broadcastable to the score matrix; positions
            where it equals 0 are excluded from attention.
        dropout: optional nn.Dropout module applied to the attention weights.

    Returns:
        Tuple of (attended values, attention weights), shaped
        (batch, heads, len_q, d_v) and (batch, heads, len_q, len_k).
    """
    scale = math.sqrt(q.size(-1))

    # Similarity of every query against every key, scaled to keep
    # softmax gradients well-behaved for large d_k.
    scores = q.matmul(k.transpose(-2, -1)) / scale

    if mask is not None:
        # -inf scores receive zero probability mass after softmax.
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return weights.matmul(v), weights
41
+
42
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects the inputs into `num_heads` subspaces of size d_model/num_heads,
    attends in each head independently, then concatenates and re-projects.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        # Learned projections for queries, keys, values, and output.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_len):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, len_q, d_model)
            key:   (batch, len_k, d_model)
            value: (batch, len_k, d_model)
            mask:  broadcastable attention mask (0 = ignore) or None

        Returns:
            output: (batch, len_q, d_model)
            attention_weights: (batch, heads, len_q, len_k)
        """
        batch_size = query.size(0)

        q = self._split_heads(self.w_q(query), batch_size, query.size(1))
        k = self._split_heads(self.w_k(key), batch_size, key.size(1))
        v = self._split_heads(self.w_v(value), batch_size, value.size(1))

        # Scaled dot-product attention, computed per head in one batch.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        attended = torch.matmul(attention_weights, v)

        # Merge heads back: (batch, heads, len_q, d_k) -> (batch, len_q, d_model)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.w_o(attended), attention_weights
99
+
100
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at each sequence position."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # expand to inner dimension
        self.w_2 = nn.Linear(d_ff, d_model)   # project back to model dimension
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """FFN(x) = W2(dropout(relu(W1 x))).

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            Tensor of the same shape (batch, seq_len, d_model).
        """
        hidden = self.activation(self.w_1(x))
        return self.w_2(self.dropout(hidden))
119
+
120
class EncoderLayer(nn.Module):
    """One post-norm transformer encoder block.

    Self-attention and a position-wise feed-forward network, each wrapped
    in a residual connection followed by LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: broadcastable attention mask (0 = ignore) or None

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        # Sub-layer 1: self-attention + residual + norm.
        attended, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward + residual + norm.
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x
157
+
158
class TransformerEncoder(nn.Module):
    """A stack of identical EncoderLayers followed by a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the input through every encoder layer, then normalize.

        Args:
            x: (batch, seq_len, d_model)
            mask: broadcastable attention mask (0 = ignore) or None

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        for encoder_layer in self.layers:
            x = encoder_layer(x, mask)
        return self.norm(x)
184
+
185
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    Even feature indices carry sin(pos / 10000^(2i/d_model)); odd indices
    carry the matching cosine. The table is precomputed up to max_len.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(positions * frequencies)
        # An odd d_model has one fewer cosine column than sine column.
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(positions * frequencies[:-1])
        else:
            pe[:, 1::2] = torch.cos(positions * frequencies)

        # Buffer: moves with .to(device) and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the first seq_len positional vectors to x, then apply dropout.

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        return self.dropout(x + self.pe[:, :x.size(1), :])
224
+
225
class TransformerPII(nn.Module):
    """
    Encoder-only transformer for token-level PII classification.

    Every input token receives a logit vector over `num_classes` labels.
    Built from scratch (embedding + sinusoidal positions + encoder stack
    + linear classification head).
    """

    def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                 d_ff=512, num_layers=4, dropout=0.1, max_len=512, pad_idx=0):
        super(TransformerPII, self).__init__()

        self.d_model = d_model
        self.pad_idx = pad_idx

        # NOTE: attribute names/creation order are kept stable so existing
        # state_dict checkpoints keep loading.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Scaled-normal embeddings (padding row zeroed), Xavier classifier."""
        nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
        if self.pad_idx is not None:
            nn.init.constant_(self.embedding.weight[self.pad_idx], 0)

        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.constant_(self.classifier.bias, 0)

    def create_padding_mask(self, x):
        """Build an attention mask from padding positions.

        Args:
            x: (batch, seq_len) token indices.

        Returns:
            Float mask of shape (batch, 1, 1, seq_len): 1 for real tokens,
            0 for padding.
        """
        return (x != self.pad_idx).unsqueeze(1).unsqueeze(2).float()

    def forward(self, x, mask=None):
        """Compute per-token classification logits.

        Args:
            x: (batch, seq_len) token indices.
            mask: optional custom attention mask; derived from padding
                positions when omitted.

        Returns:
            Logits of shape (batch, seq_len, num_classes).

        Raises:
            ValueError: if x is not a 2-D tensor.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")

        if mask is None:
            mask = self.create_padding_mask(x)

        # Scale embeddings by sqrt(d_model), per the original transformer.
        embedded = self.embedding(x) * math.sqrt(self.d_model)
        encoded = self.encoder(self.positional_encoding(embedded), mask)
        return self.classifier(self.dropout(encoded))

    def predict(self, x):
        """Argmax class index per token, in eval mode without gradients.

        Args:
            x: (batch, seq_len) token indices.

        Returns:
            Predicted class indices of shape (batch, seq_len).
        """
        self.eval()
        with torch.no_grad():
            return torch.argmax(self.forward(x), dim=-1)
335
+
336
class TransformerPIIWithCRF(TransformerPII):
    """
    TransformerPII variant reserved for an optional CRF decoding layer.

    The CRF is not wired in (it would require the pytorch-crf package);
    until then this behaves exactly like TransformerPII and returns the
    raw emission logits.
    """

    def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                 d_ff=512, num_layers=4, dropout=0.1, max_len=512, pad_idx=0):
        super(TransformerPIIWithCRF, self).__init__(
            vocab_size, num_classes, d_model, num_heads,
            d_ff, num_layers, dropout, max_len, pad_idx
        )
        # Placeholder for the CRF layer:
        # from torchcrf import CRF
        # self.crf = CRF(num_classes, batch_first=True)

    def forward(self, x, labels=None):
        """Return emission logits; CRF loss/decoding paths are stubs."""
        emissions = super().forward(x)

        if labels is not None:
            # CRF training path (stub):
            # mask = (x != self.pad_idx)
            # return -self.crf(emissions, labels, mask=mask)
            pass
        else:
            # CRF decoding path (stub):
            # mask = (x != self.pad_idx)
            # return self.crf.decode(emissions, mask=mask)
            pass

        return emissions
372
+
373
def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
                                 d_ff=512, num_layers=4, dropout=0.1, max_len=512):
    """
    Build a TransformerPII configured for PII token classification.

    Args:
        vocab_size: vocabulary size for the embedding table.
        num_classes: number of PII label classes.
        d_model: hidden/model dimension.
        num_heads: attention heads per encoder layer.
        d_ff: feed-forward inner dimension.
        num_layers: number of encoder layers.
        dropout: dropout probability.
        max_len: maximum supported sequence length.

    Returns:
        A freshly initialized TransformerPII instance (padding index
        fixed at 0, matching the vocabulary's <pad> token).
    """
    return TransformerPII(
        vocab_size=vocab_size,
        num_classes=num_classes,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        dropout=dropout,
        max_len=max_len,
        pad_idx=0,
    )
transformer_training.ipynb ADDED
@@ -0,0 +1,1319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "216181fb",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2025-08-03T16:54:32.135992Z",
10
+ "iopub.status.busy": "2025-08-03T16:54:32.135203Z",
11
+ "iopub.status.idle": "2025-08-03T16:54:44.757081Z",
12
+ "shell.execute_reply": "2025-08-03T16:54:44.756283Z"
13
+ },
14
+ "papermill": {
15
+ "duration": 12.627911,
16
+ "end_time": "2025-08-03T16:54:44.758473",
17
+ "exception": false,
18
+ "start_time": "2025-08-03T16:54:32.130562",
19
+ "status": "completed"
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "import torch\n",
26
+ "import torch.nn as nn\n",
27
+ "import torch.optim as optim\n",
28
+ "from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler\n",
29
+ "from torch.nn.utils.rnn import pad_sequence\n",
30
+ "import pandas as pd\n",
31
+ "import numpy as np\n",
32
+ "from sklearn.model_selection import train_test_split\n",
33
+ "from collections import Counter\n",
34
+ "import pickle\n",
35
+ "from tqdm import tqdm\n",
36
+ "import matplotlib.pyplot as plt\n",
37
+ "import os\n",
38
+ "from datetime import datetime\n",
39
+ "from transformer import create_transformer_pii_model\n",
40
+ "from data_augmentation import calculate_class_weights"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 4,
46
+ "id": "ff1782dd",
47
+ "metadata": {
48
+ "execution": {
49
+ "iopub.execute_input": "2025-08-03T16:54:44.767637Z",
50
+ "iopub.status.busy": "2025-08-03T16:54:44.766805Z",
51
+ "iopub.status.idle": "2025-08-03T16:54:44.774888Z",
52
+ "shell.execute_reply": "2025-08-03T16:54:44.774045Z"
53
+ },
54
+ "papermill": {
55
+ "duration": 0.013734,
56
+ "end_time": "2025-08-03T16:54:44.776187",
57
+ "exception": false,
58
+ "start_time": "2025-08-03T16:54:44.762453",
59
+ "status": "completed"
60
+ },
61
+ "tags": []
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "class Vocabulary:\n",
66
+ " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
67
+ " def __init__(self, max_size=100000):\n",
68
+ " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
69
+ " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
70
+ " self.word_count = Counter()\n",
71
+ " self.max_size = max_size\n",
72
+ " \n",
73
+ " def add_sentence(self, sentence):\n",
74
+ " for word in sentence:\n",
75
+ " self.word_count[word.lower()] += 1\n",
76
+ " \n",
77
+ " def build(self):\n",
78
+ " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
79
+ " for word, _ in most_common:\n",
80
+ " if word not in self.word2idx:\n",
81
+ " idx = len(self.word2idx)\n",
82
+ " self.word2idx[word] = idx\n",
83
+ " self.idx2word[idx] = word\n",
84
+ " \n",
85
+ " def __len__(self):\n",
86
+ " return len(self.word2idx)\n",
87
+ " \n",
88
+ " def encode(self, sentence):\n",
89
+ " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
90
+ " \n",
91
+ " def decode(self, indices):\n",
92
+ " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 5,
98
+ "id": "5b2b46d6",
99
+ "metadata": {
100
+ "execution": {
101
+ "iopub.execute_input": "2025-08-03T16:54:44.785061Z",
102
+ "iopub.status.busy": "2025-08-03T16:54:44.784479Z",
103
+ "iopub.status.idle": "2025-08-03T16:54:44.790645Z",
104
+ "shell.execute_reply": "2025-08-03T16:54:44.790095Z"
105
+ },
106
+ "papermill": {
107
+ "duration": 0.011749,
108
+ "end_time": "2025-08-03T16:54:44.791761",
109
+ "exception": false,
110
+ "start_time": "2025-08-03T16:54:44.780012",
111
+ "status": "completed"
112
+ },
113
+ "tags": []
114
+ },
115
+ "outputs": [],
116
+ "source": [
117
+ "class PIIDataset(Dataset):\n",
118
+ " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
119
+ " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
120
+ " self.tokens = tokens\n",
121
+ " self.labels = labels\n",
122
+ " self.text_vocab = text_vocab\n",
123
+ " self.label_vocab = label_vocab\n",
124
+ " self.max_len = max_len\n",
125
+ " \n",
126
+ " def __len__(self):\n",
127
+ " return len(self.tokens)\n",
128
+ " \n",
129
+ " def __getitem__(self, idx):\n",
130
+ " # Add start and end tokens\n",
131
+ " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
132
+ " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
133
+ " \n",
134
+ " # Truncate if too long\n",
135
+ " if len(tokens) > self.max_len:\n",
136
+ " tokens = tokens[:self.max_len-1] + ['<end>']\n",
137
+ " labels = labels[:self.max_len-1] + ['<end>']\n",
138
+ " \n",
139
+ " # Encode\n",
140
+ " token_ids = self.text_vocab.encode(tokens)\n",
141
+ " label_ids = self.label_vocab.encode(labels)\n",
142
+ " \n",
143
+ " return torch.tensor(token_ids), torch.tensor(label_ids)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 6,
149
+ "id": "e7ca8f8f",
150
+ "metadata": {
151
+ "execution": {
152
+ "iopub.execute_input": "2025-08-03T16:54:44.799705Z",
153
+ "iopub.status.busy": "2025-08-03T16:54:44.799475Z",
154
+ "iopub.status.idle": "2025-08-03T16:54:44.803433Z",
155
+ "shell.execute_reply": "2025-08-03T16:54:44.802870Z"
156
+ },
157
+ "papermill": {
158
+ "duration": 0.009288,
159
+ "end_time": "2025-08-03T16:54:44.804692",
160
+ "exception": false,
161
+ "start_time": "2025-08-03T16:54:44.795404",
162
+ "status": "completed"
163
+ },
164
+ "tags": []
165
+ },
166
+ "outputs": [],
167
+ "source": [
168
+ "def collate_fn(batch):\n",
169
+ " \"\"\"Custom collate function for padding sequences\"\"\"\n",
170
+ " tokens, labels = zip(*batch)\n",
171
+ " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
172
+ " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
173
+ " return tokens_padded, labels_padded"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 7,
179
+ "id": "85b32e21",
180
+ "metadata": {
181
+ "execution": {
182
+ "iopub.execute_input": "2025-08-03T16:54:44.813147Z",
183
+ "iopub.status.busy": "2025-08-03T16:54:44.812906Z",
184
+ "iopub.status.idle": "2025-08-03T16:54:44.823227Z",
185
+ "shell.execute_reply": "2025-08-03T16:54:44.822443Z"
186
+ },
187
+ "papermill": {
188
+ "duration": 0.016244,
189
+ "end_time": "2025-08-03T16:54:44.824490",
190
+ "exception": false,
191
+ "start_time": "2025-08-03T16:54:44.808246",
192
+ "status": "completed"
193
+ },
194
+ "tags": []
195
+ },
196
+ "outputs": [],
197
+ "source": [
198
+ "class F1ScoreMetric:\n",
199
+ " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
200
+ " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
201
+ " self.beta = beta\n",
202
+ " self.num_classes = num_classes\n",
203
+ " self.ignore_index = ignore_index\n",
204
+ " self.label_vocab = label_vocab\n",
205
+ " self.reset()\n",
206
+ " \n",
207
+ " def reset(self):\n",
208
+ " self.true_positives = 0\n",
209
+ " self.false_positives = 0\n",
210
+ " self.false_negatives = 0\n",
211
+ " self.class_metrics = {}\n",
212
+ " \n",
213
+ " def update(self, predictions, targets):\n",
214
+ " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
215
+ " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
216
+ " \n",
217
+ " for class_id in range(1, self.num_classes):\n",
218
+ " if class_id == o_idx:\n",
219
+ " continue\n",
220
+ " \n",
221
+ " pred_mask = (predictions == class_id) & mask\n",
222
+ " true_mask = (targets == class_id) & mask\n",
223
+ " \n",
224
+ " tp = ((pred_mask) & (true_mask)).sum().item()\n",
225
+ " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
226
+ " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
227
+ " \n",
228
+ " self.true_positives += tp\n",
229
+ " self.false_positives += fp\n",
230
+ " self.false_negatives += fn\n",
231
+ " \n",
232
+ " if class_id not in self.class_metrics:\n",
233
+ " self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
234
+ " self.class_metrics[class_id]['tp'] += tp\n",
235
+ " self.class_metrics[class_id]['fp'] += fp\n",
236
+ " self.class_metrics[class_id]['fn'] += fn\n",
237
+ " \n",
238
+ " def compute(self):\n",
239
+ " beta_squared = self.beta ** 2\n",
240
+ " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
241
+ " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
242
+ " f1 = (1 + beta_squared) * precision * recall / (beta_squared * precision + recall + 1e-8)\n",
243
+ " return f1\n",
244
+ " \n",
245
+ " def get_class_metrics(self):\n",
246
+ " results = {}\n",
247
+ " for class_id, metrics in self.class_metrics.items():\n",
248
+ " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
249
+ " class_name = self.label_vocab.idx2word[class_id]\n",
250
+ " precision = metrics['tp'] / (metrics['tp'] + metrics['fp'] + 1e-8)\n",
251
+ " recall = metrics['tp'] / (metrics['tp'] + metrics['fn'] + 1e-8)\n",
252
+ " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
253
+ " results[class_name] = {\n",
254
+ " 'precision': precision,\n",
255
+ " 'recall': recall,\n",
256
+ " 'f1': f1,\n",
257
+ " 'support': metrics['tp'] + metrics['fn']\n",
258
+ " }\n",
259
+ " return results"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 8,
265
+ "id": "60cf16eb",
266
+ "metadata": {
267
+ "execution": {
268
+ "iopub.execute_input": "2025-08-03T16:54:44.832210Z",
269
+ "iopub.status.busy": "2025-08-03T16:54:44.831970Z",
270
+ "iopub.status.idle": "2025-08-03T16:54:44.837466Z",
271
+ "shell.execute_reply": "2025-08-03T16:54:44.836871Z"
272
+ },
273
+ "papermill": {
274
+ "duration": 0.01072,
275
+ "end_time": "2025-08-03T16:54:44.838672",
276
+ "exception": false,
277
+ "start_time": "2025-08-03T16:54:44.827952",
278
+ "status": "completed"
279
+ },
280
+ "tags": []
281
+ },
282
+ "outputs": [],
283
+ "source": [
284
+ "class FocalLoss(nn.Module):\n",
285
+ " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
286
+ " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
287
+ " super(FocalLoss, self).__init__()\n",
288
+ " self.alpha = alpha\n",
289
+ " self.gamma = gamma\n",
290
+ " self.reduction = reduction\n",
291
+ " self.ignore_index = ignore_index\n",
292
+ " \n",
293
+ " def forward(self, inputs, targets):\n",
294
+ " ce_loss = nn.functional.cross_entropy(\n",
295
+ " inputs, targets, \n",
296
+ " weight=self.alpha, \n",
297
+ " reduction='none',\n",
298
+ " ignore_index=self.ignore_index\n",
299
+ " )\n",
300
+ " \n",
301
+ " pt = torch.exp(-ce_loss)\n",
302
+ " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
303
+ " \n",
304
+ " if self.reduction == 'mean':\n",
305
+ " return focal_loss.mean()\n",
306
+ " elif self.reduction == 'sum':\n",
307
+ " return focal_loss.sum()\n",
308
+ " else:\n",
309
+ " return focal_loss"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 9,
315
+ "id": "4e56747c",
316
+ "metadata": {
317
+ "execution": {
318
+ "iopub.execute_input": "2025-08-03T16:54:44.846907Z",
319
+ "iopub.status.busy": "2025-08-03T16:54:44.846289Z",
320
+ "iopub.status.idle": "2025-08-03T16:54:44.852363Z",
321
+ "shell.execute_reply": "2025-08-03T16:54:44.851772Z"
322
+ },
323
+ "papermill": {
324
+ "duration": 0.011242,
325
+ "end_time": "2025-08-03T16:54:44.853481",
326
+ "exception": false,
327
+ "start_time": "2025-08-03T16:54:44.842239",
328
+ "status": "completed"
329
+ },
330
+ "tags": []
331
+ },
332
+ "outputs": [],
333
+ "source": [
334
+ "def train_epoch(model, dataloader, optimizer, criterion, device, f1_metric):\n",
335
+ " \"\"\"Train for one epoch\"\"\"\n",
336
+ " model.train()\n",
337
+ " total_loss = 0\n",
338
+ " f1_metric.reset()\n",
339
+ " \n",
340
+ " progress_bar = tqdm(dataloader, desc='Training')\n",
341
+ " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
342
+ " tokens = tokens.to(device)\n",
343
+ " labels = labels.to(device)\n",
344
+ " \n",
345
+ " # Forward pass\n",
346
+ " optimizer.zero_grad()\n",
347
+ " outputs = model(tokens)\n",
348
+ " \n",
349
+ " # Reshape for loss calculation\n",
350
+ " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
351
+ " labels_flat = labels.view(-1)\n",
352
+ " \n",
353
+ " # Calculate loss and backward pass\n",
354
+ " loss = criterion(outputs_flat, labels_flat)\n",
355
+ " loss.backward()\n",
356
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
357
+ " optimizer.step()\n",
358
+ " \n",
359
+ " # Update metrics\n",
360
+ " total_loss += loss.item()\n",
361
+ " predictions = torch.argmax(outputs, dim=-1)\n",
362
+ " f1_metric.update(predictions, labels)\n",
363
+ " \n",
364
+ " # Update progress bar\n",
365
+ " progress_bar.set_postfix({\n",
366
+ " 'loss': f\"{loss.item():.4f}\",\n",
367
+ " 'f1': f\"{f1_metric.compute():.4f}\"\n",
368
+ " })\n",
369
+ " \n",
370
+ " return total_loss / len(dataloader), f1_metric.compute()"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": 10,
376
+ "id": "8a2e8d19",
377
+ "metadata": {
378
+ "execution": {
379
+ "iopub.execute_input": "2025-08-03T16:54:44.860755Z",
380
+ "iopub.status.busy": "2025-08-03T16:54:44.860552Z",
381
+ "iopub.status.idle": "2025-08-03T16:54:44.865987Z",
382
+ "shell.execute_reply": "2025-08-03T16:54:44.865175Z"
383
+ },
384
+ "papermill": {
385
+ "duration": 0.010585,
386
+ "end_time": "2025-08-03T16:54:44.867309",
387
+ "exception": false,
388
+ "start_time": "2025-08-03T16:54:44.856724",
389
+ "status": "completed"
390
+ },
391
+ "tags": []
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "def evaluate(model, dataloader, criterion, device, f1_metric):\n",
396
+ " \"\"\"Evaluate model on validation/test set\"\"\"\n",
397
+ " model.eval()\n",
398
+ " total_loss = 0\n",
399
+ " f1_metric.reset()\n",
400
+ " \n",
401
+ " with torch.no_grad():\n",
402
+ " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
403
+ " tokens = tokens.to(device)\n",
404
+ " labels = labels.to(device)\n",
405
+ " \n",
406
+ " # Forward pass\n",
407
+ " outputs = model(tokens)\n",
408
+ " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
409
+ " labels_flat = labels.view(-1)\n",
410
+ " \n",
411
+ " # Calculate loss\n",
412
+ " loss = criterion(outputs_flat, labels_flat)\n",
413
+ " total_loss += loss.item()\n",
414
+ " \n",
415
+ " # Update metrics\n",
416
+ " predictions = torch.argmax(outputs, dim=-1)\n",
417
+ " f1_metric.update(predictions, labels)\n",
418
+ " \n",
419
+ " return total_loss / len(dataloader), f1_metric.compute()"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": 11,
425
+ "id": "6e292ace",
426
+ "metadata": {
427
+ "execution": {
428
+ "iopub.execute_input": "2025-08-03T16:54:44.876030Z",
429
+ "iopub.status.busy": "2025-08-03T16:54:44.875513Z",
430
+ "iopub.status.idle": "2025-08-03T16:54:44.880655Z",
431
+ "shell.execute_reply": "2025-08-03T16:54:44.879870Z"
432
+ },
433
+ "papermill": {
434
+ "duration": 0.010355,
435
+ "end_time": "2025-08-03T16:54:44.881962",
436
+ "exception": false,
437
+ "start_time": "2025-08-03T16:54:44.871607",
438
+ "status": "completed"
439
+ },
440
+ "tags": []
441
+ },
442
+ "outputs": [],
443
+ "source": [
444
+ "def create_balanced_sampler(dataset, label_vocab):\n",
445
+ " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
446
+ " sample_weights = []\n",
447
+ " \n",
448
+ " for idx in range(len(dataset)):\n",
449
+ " _, labels = dataset[idx]\n",
450
+ " \n",
451
+ " # Give higher weight to samples with rare PII\n",
452
+ " min_weight = 1.0\n",
453
+ " for label_id in labels:\n",
454
+ " if label_id > 3: # Skip special tokens\n",
455
+ " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
456
+ " if label_name != 'o' and 'B-' in label_name:\n",
457
+ " min_weight = 10.0\n",
458
+ " break\n",
459
+ " \n",
460
+ " sample_weights.append(min_weight)\n",
461
+ " \n",
462
+ " sampler = WeightedRandomSampler(\n",
463
+ " weights=sample_weights,\n",
464
+ " num_samples=len(sample_weights),\n",
465
+ " replacement=True\n",
466
+ " )\n",
467
+ " \n",
468
+ " return sampler\n"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": 12,
474
+ "id": "857335cb",
475
+ "metadata": {
476
+ "execution": {
477
+ "iopub.execute_input": "2025-08-03T16:54:44.889690Z",
478
+ "iopub.status.busy": "2025-08-03T16:54:44.889472Z",
479
+ "iopub.status.idle": "2025-08-03T16:54:44.894459Z",
480
+ "shell.execute_reply": "2025-08-03T16:54:44.893888Z"
481
+ },
482
+ "papermill": {
483
+ "duration": 0.010295,
484
+ "end_time": "2025-08-03T16:54:44.895625",
485
+ "exception": false,
486
+ "start_time": "2025-08-03T16:54:44.885330",
487
+ "status": "completed"
488
+ },
489
+ "tags": []
490
+ },
491
+ "outputs": [],
492
+ "source": [
493
+ "def print_label_distribution(data, title=\"Label Distribution\"):\n",
494
+ " \"\"\"Print label distribution statistics\"\"\"\n",
495
+ " label_counts = Counter()\n",
496
+ " for label_seq in data.labels:\n",
497
+ " for label in label_seq:\n",
498
+ " if label not in ['<pad>', '<start>', '<end>']:\n",
499
+ " label_counts[label] += 1\n",
500
+ " \n",
501
+ " print(f\"\\n{title}:\")\n",
502
+ " print(\"-\" * 50)\n",
503
+ " total = sum(label_counts.values())\n",
504
+ " for label, count in label_counts.most_common():\n",
505
+ " percentage = (count / total) * 100\n",
506
+ " print(f\" {label:20} : {count:8,} ({percentage:5.2f}%)\")\n",
507
+ " print(\"-\" * 50)\n",
508
+ " print(f\" {'Total':20} : {total:8,}\")"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 13,
514
+ "id": "1738f8a9",
515
+ "metadata": {
516
+ "execution": {
517
+ "iopub.execute_input": "2025-08-03T16:54:44.903649Z",
518
+ "iopub.status.busy": "2025-08-03T16:54:44.903207Z",
519
+ "iopub.status.idle": "2025-08-03T16:54:44.908673Z",
520
+ "shell.execute_reply": "2025-08-03T16:54:44.908076Z"
521
+ },
522
+ "papermill": {
523
+ "duration": 0.010714,
524
+ "end_time": "2025-08-03T16:54:44.909864",
525
+ "exception": false,
526
+ "start_time": "2025-08-03T16:54:44.899150",
527
+ "status": "completed"
528
+ },
529
+ "tags": []
530
+ },
531
+ "outputs": [],
532
+ "source": [
533
+ "def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
534
+ " \"\"\"Save model and all necessary components for Flask deployment\"\"\"\n",
535
+ " os.makedirs(save_dir, exist_ok=True)\n",
536
+ " \n",
537
+ " # Save model state\n",
538
+ " model_path = os.path.join(save_dir, 'pii_transformer_model.pt')\n",
539
+ " torch.save(model.state_dict(), model_path)\n",
540
+ " \n",
541
+ " # Save vocabularies\n",
542
+ " vocab_path = os.path.join(save_dir, 'vocabularies.pkl')\n",
543
+ " with open(vocab_path, 'wb') as f:\n",
544
+ " pickle.dump({\n",
545
+ " 'text_vocab': text_vocab,\n",
546
+ " 'label_vocab': label_vocab\n",
547
+ " }, f)\n",
548
+ " \n",
549
+ " # Save model configuration\n",
550
+ " config_path = os.path.join(save_dir, 'model_config.pkl')\n",
551
+ " with open(config_path, 'wb') as f:\n",
552
+ " pickle.dump(config, f)\n",
553
+ " \n",
554
+ " print(f\"\\nModel saved for deployment in '{save_dir}/' directory\")\n",
555
+ " print(\"Files saved:\")\n",
556
+ " print(f\" - {model_path}\")\n",
557
+ " print(f\" - {vocab_path}\")\n",
558
+ " print(f\" - {config_path}\")"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": 14,
564
+ "id": "d93e7c25",
565
+ "metadata": {
566
+ "execution": {
567
+ "iopub.execute_input": "2025-08-03T16:54:44.917693Z",
568
+ "iopub.status.busy": "2025-08-03T16:54:44.917438Z",
569
+ "iopub.status.idle": "2025-08-03T16:54:44.933820Z",
570
+ "shell.execute_reply": "2025-08-03T16:54:44.933284Z"
571
+ },
572
+ "papermill": {
573
+ "duration": 0.021776,
574
+ "end_time": "2025-08-03T16:54:44.935035",
575
+ "exception": false,
576
+ "start_time": "2025-08-03T16:54:44.913259",
577
+ "status": "completed"
578
+ },
579
+ "tags": []
580
+ },
581
+ "outputs": [],
582
+ "source": [
583
+ "def train_transformer_pii_model(\n",
584
+ " data_path,\n",
585
+ " num_epochs=30,\n",
586
+ " batch_size=32,\n",
587
+ " learning_rate=2e-4,\n",
588
+ " use_focal_loss=True,\n",
589
+ " focal_gamma=2.0,\n",
590
+ " device='cuda',\n",
591
+ "):\n",
592
+ " \"\"\"Main training function\"\"\"\n",
593
+ " \n",
594
+ " # Load data\n",
595
+ " print(\"Loading augmented data...\")\n",
596
+ " data = pd.read_json(data_path, lines=True)\n",
597
+ " print(f\"Total samples: {len(data)}\")\n",
598
+ " \n",
599
+ " # Print initial label distribution\n",
600
+ " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
601
+ " \n",
602
+ " # Build vocabularies\n",
603
+ " print(\"\\nBuilding vocabularies...\")\n",
604
+ " text_vocab = Vocabulary(max_size=100000)\n",
605
+ " label_vocab = Vocabulary(max_size=50)\n",
606
+ " \n",
607
+ " for tokens in data.tokens:\n",
608
+ " text_vocab.add_sentence(tokens)\n",
609
+ " for labels in data.labels:\n",
610
+ " label_vocab.add_sentence(labels)\n",
611
+ " \n",
612
+ " text_vocab.build()\n",
613
+ " label_vocab.build()\n",
614
+ " \n",
615
+ " # Calculate class weights\n",
616
+ " class_weights = calculate_class_weights(data, label_vocab)\n",
617
+ " class_weights = class_weights.to(device)\n",
618
+ " \n",
619
+ " # Split data\n",
620
+ " X_train, X_val, y_train, y_val = train_test_split(\n",
621
+ " data.tokens.tolist(),\n",
622
+ " data.labels.tolist(),\n",
623
+ " test_size=0.2,\n",
624
+ " random_state=42\n",
625
+ " )\n",
626
+ " \n",
627
+ " print(f\"\\nData split:\")\n",
628
+ " print(f\" - Train samples: {len(X_train):,}\")\n",
629
+ " print(f\" - Validation samples: {len(X_val):,}\")\n",
630
+ " \n",
631
+ " # Create datasets and dataloaders\n",
632
+ " max_seq_len = 512\n",
633
+ " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
634
+ " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
635
+ " \n",
636
+ " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
637
+ " \n",
638
+ " train_loader = DataLoader(\n",
639
+ " train_dataset, \n",
640
+ " batch_size=batch_size,\n",
641
+ " sampler=train_sampler,\n",
642
+ " collate_fn=collate_fn,\n",
643
+ " num_workers=0\n",
644
+ " )\n",
645
+ " \n",
646
+ " val_loader = DataLoader(\n",
647
+ " val_dataset, \n",
648
+ " batch_size=batch_size,\n",
649
+ " shuffle=False, \n",
650
+ " collate_fn=collate_fn,\n",
651
+ " num_workers=0\n",
652
+ " )\n",
653
+ " \n",
654
+ " # Model configuration\n",
655
+ " model_config = {\n",
656
+ " 'vocab_size': len(text_vocab),\n",
657
+ " 'num_classes': len(label_vocab),\n",
658
+ " 'd_model': 256,\n",
659
+ " 'num_heads': 8,\n",
660
+ " 'd_ff': 512,\n",
661
+ " 'num_layers': 4,\n",
662
+ " 'dropout': 0.1,\n",
663
+ " 'max_len': max_seq_len\n",
664
+ " }\n",
665
+ " \n",
666
+ " # Create model\n",
667
+ " print(\"\\nCreating model...\")\n",
668
+ " model = create_transformer_pii_model(**model_config).to(device)\n",
669
+ " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
670
+ " \n",
671
+ " # Setup loss function\n",
672
+ " if use_focal_loss:\n",
673
+ " criterion = FocalLoss(\n",
674
+ " alpha=class_weights,\n",
675
+ " gamma=focal_gamma,\n",
676
+ " ignore_index=0\n",
677
+ " )\n",
678
+ " else:\n",
679
+ " criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
680
+ " \n",
681
+ " # Setup optimizer and scheduler\n",
682
+ " optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
683
+ " scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)\n",
684
+ " \n",
685
+ " # Metrics\n",
686
+ " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
687
+ " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
688
+ " \n",
689
+ " # Training loop\n",
690
+ " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
691
+ " best_val_f1 = 0\n",
692
+ " patience = 5\n",
693
+ " patience_counter = 0\n",
694
+ " \n",
695
+ " print(\"\\nStarting training...\")\n",
696
+ " print(\"=\" * 60)\n",
697
+ " \n",
698
+ " for epoch in range(num_epochs):\n",
699
+ " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
700
+ " \n",
701
+ " # Train and validate\n",
702
+ " train_loss, train_f1 = train_epoch(\n",
703
+ " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
704
+ " )\n",
705
+ " val_loss, val_f1 = evaluate(\n",
706
+ " model, val_loader, criterion, device, f1_metric_val\n",
707
+ " )\n",
708
+ " \n",
709
+ " # Step scheduler\n",
710
+ " scheduler.step(val_loss)\n",
711
+ " \n",
712
+ " # Store metrics\n",
713
+ " train_losses.append(train_loss)\n",
714
+ " train_f1s.append(train_f1)\n",
715
+ " val_losses.append(val_loss)\n",
716
+ " val_f1s.append(val_f1)\n",
717
+ " \n",
718
+ " # Print epoch results\n",
719
+ " print(f\"Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}\")\n",
720
+ " print(f\"Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}\")\n",
721
+ " print(f\"Learning rate: {optimizer.param_groups[0]['lr']:.6f}\")\n",
722
+ " \n",
723
+ " # Save best model\n",
724
+ " if val_f1 > best_val_f1:\n",
725
+ " best_val_f1 = val_f1\n",
726
+ " patience_counter = 0\n",
727
+ " \n",
728
+ " # Save complete checkpoint\n",
729
+ " torch.save({\n",
730
+ " 'epoch': epoch,\n",
731
+ " 'model_state_dict': model.state_dict(),\n",
732
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
733
+ " 'train_loss': train_loss,\n",
734
+ " 'val_loss': val_loss,\n",
735
+ " 'train_f1': train_f1,\n",
736
+ " 'val_f1': val_f1,\n",
737
+ " 'text_vocab': text_vocab,\n",
738
+ " 'label_vocab': label_vocab,\n",
739
+ " 'model_config': model_config\n",
740
+ " }, 'best_transformer_checkpoint.pt')\n",
741
+ " \n",
742
+ " print(f\"Saved best model with F1: {val_f1:.4f}\")\n",
743
+ " else:\n",
744
+ " patience_counter += 1\n",
745
+ " \n",
746
+ " # Early stopping\n",
747
+ " if patience_counter >= patience and epoch > 10:\n",
748
+ " print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
749
+ " break\n",
750
+ " \n",
751
+ " # Plot training curves\n",
752
+ " plt.figure(figsize=(12, 5))\n",
753
+ " \n",
754
+ " plt.subplot(1, 2, 1)\n",
755
+ " plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
756
+ " plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
757
+ " plt.xlabel('Epoch')\n",
758
+ " plt.ylabel('Loss')\n",
759
+ " plt.title('Training and Validation Loss')\n",
760
+ " plt.legend()\n",
761
+ " plt.grid(True, alpha=0.3)\n",
762
+ " \n",
763
+ " plt.subplot(1, 2, 2)\n",
764
+ " plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
765
+ " plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
766
+ " plt.xlabel('Epoch')\n",
767
+ " plt.ylabel('F1 Score')\n",
768
+ " plt.title('Training and Validation F1 Score')\n",
769
+ " plt.legend()\n",
770
+ " plt.grid(True, alpha=0.3)\n",
771
+ " \n",
772
+ " plt.tight_layout()\n",
773
+ " plt.savefig('transformer_training_curves.png', dpi=300, bbox_inches='tight')\n",
774
+ " plt.close()\n",
775
+ " \n",
776
+ " print(f\"\\n{'='*60}\")\n",
777
+ " print(f\"Training completed!\")\n",
778
+ " print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
779
+ " \n",
780
+ " save_model(model, text_vocab, label_vocab, model_config, 'saved_transformer_model')\n",
781
+ " \n",
782
+ " return model, text_vocab, label_vocab"
783
+ ]
784
+ },
785
+ {
786
+ "cell_type": "code",
787
+ "execution_count": null,
788
+ "id": "dbf345da",
789
+ "metadata": {
790
+ "execution": {
791
+ "iopub.execute_input": "2025-08-03T16:54:44.942669Z",
792
+ "iopub.status.busy": "2025-08-03T16:54:44.942460Z",
793
+ "iopub.status.idle": "2025-08-03T17:39:36.443370Z",
794
+ "shell.execute_reply": "2025-08-03T17:39:36.442507Z"
795
+ },
796
+ "papermill": {
797
+ "duration": 2691.506418,
798
+ "end_time": "2025-08-03T17:39:36.444814",
799
+ "exception": false,
800
+ "start_time": "2025-08-03T16:54:44.938396",
801
+ "status": "completed"
802
+ },
803
+ "tags": []
804
+ },
805
+ "outputs": [
806
+ {
807
+ "name": "stdout",
808
+ "output_type": "stream",
809
+ "text": [
810
+ "Using device: cuda\n",
811
+ "Loading augmented data...\n",
812
+ "Total samples: 19694\n",
813
+ "\n",
814
+ "Label Distribution in Augmented Data:\n",
815
+ "--------------------------------------------------\n",
816
+ " O : 5,082,150 (99.33%)\n",
817
+ " I-STREET_ADDRESS : 15,650 ( 0.31%)\n",
818
+ " B-ID_NUM : 2,505 ( 0.05%)\n",
819
+ " B-EMAIL : 2,488 ( 0.05%)\n",
820
+ " B-URL_PERSONAL : 2,478 ( 0.05%)\n",
821
+ " B-STREET_ADDRESS : 2,452 ( 0.05%)\n",
822
+ " B-PHONE_NUM : 2,450 ( 0.05%)\n",
823
+ " B-USERNAME : 2,210 ( 0.04%)\n",
824
+ " B-NAME_STUDENT : 1,968 ( 0.04%)\n",
825
+ " I-NAME_STUDENT : 1,735 ( 0.03%)\n",
826
+ " I-PHONE_NUM : 500 ( 0.01%)\n",
827
+ " I-URL_PERSONAL : 1 ( 0.00%)\n",
828
+ " I-ID_NUM : 1 ( 0.00%)\n",
829
+ "--------------------------------------------------\n",
830
+ " Total : 5,116,588\n",
831
+ "\n",
832
+ "Building vocabularies...\n",
833
+ "\n",
834
+ "Data split:\n",
835
+ " - Train samples: 15,755\n",
836
+ " - Validation samples: 3,939\n",
837
+ "\n",
838
+ "Creating model...\n",
839
+ "Model parameters: 18,828,817\n",
840
+ "\n",
841
+ "Starting training...\n",
842
+ "============================================================\n",
843
+ "\n",
844
+ "Epoch 1/20\n"
845
+ ]
846
+ },
847
+ {
848
+ "name": "stderr",
849
+ "output_type": "stream",
850
+ "text": [
851
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.2908]\n",
852
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.63it/s]\n"
853
+ ]
854
+ },
855
+ {
856
+ "name": "stdout",
857
+ "output_type": "stream",
858
+ "text": [
859
+ "Train Loss: 0.0001, Train F1: 0.2908\n",
860
+ "Val Loss: 0.0001, Val F1: 0.5855\n",
861
+ "Learning rate: 0.000200\n",
862
+ "Saved best model with F1: 0.5855\n",
863
+ "\n",
864
+ "Epoch 2/20\n"
865
+ ]
866
+ },
867
+ {
868
+ "name": "stderr",
869
+ "output_type": "stream",
870
+ "text": [
871
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.6256]\n",
872
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.56it/s]\n"
873
+ ]
874
+ },
875
+ {
876
+ "name": "stdout",
877
+ "output_type": "stream",
878
+ "text": [
879
+ "Train Loss: 0.0000, Train F1: 0.6256\n",
880
+ "Val Loss: 0.0000, Val F1: 0.7335\n",
881
+ "Learning rate: 0.000200\n",
882
+ "Saved best model with F1: 0.7335\n",
883
+ "\n",
884
+ "Epoch 3/20\n"
885
+ ]
886
+ },
887
+ {
888
+ "name": "stderr",
889
+ "output_type": "stream",
890
+ "text": [
891
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.7573]\n",
892
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.55it/s]\n"
893
+ ]
894
+ },
895
+ {
896
+ "name": "stdout",
897
+ "output_type": "stream",
898
+ "text": [
899
+ "Train Loss: 0.0000, Train F1: 0.7573\n",
900
+ "Val Loss: 0.0000, Val F1: 0.7576\n",
901
+ "Learning rate: 0.000200\n",
902
+ "Saved best model with F1: 0.7576\n",
903
+ "\n",
904
+ "Epoch 4/20\n"
905
+ ]
906
+ },
907
+ {
908
+ "name": "stderr",
909
+ "output_type": "stream",
910
+ "text": [
911
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.8054]\n",
912
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.58it/s]\n"
913
+ ]
914
+ },
915
+ {
916
+ "name": "stdout",
917
+ "output_type": "stream",
918
+ "text": [
919
+ "Train Loss: 0.0000, Train F1: 0.8054\n",
920
+ "Val Loss: 0.0000, Val F1: 0.7756\n",
921
+ "Learning rate: 0.000200\n",
922
+ "Saved best model with F1: 0.7756\n",
923
+ "\n",
924
+ "Epoch 5/20\n"
925
+ ]
926
+ },
927
+ {
928
+ "name": "stderr",
929
+ "output_type": "stream",
930
+ "text": [
931
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.8403]\n",
932
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.59it/s]\n"
933
+ ]
934
+ },
935
+ {
936
+ "name": "stdout",
937
+ "output_type": "stream",
938
+ "text": [
939
+ "Train Loss: 0.0000, Train F1: 0.8403\n",
940
+ "Val Loss: 0.0000, Val F1: 0.7872\n",
941
+ "Learning rate: 0.000200\n",
942
+ "Saved best model with F1: 0.7872\n",
943
+ "\n",
944
+ "Epoch 6/20\n"
945
+ ]
946
+ },
947
+ {
948
+ "name": "stderr",
949
+ "output_type": "stream",
950
+ "text": [
951
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0001, f1=0.8743]\n",
952
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
953
+ ]
954
+ },
955
+ {
956
+ "name": "stdout",
957
+ "output_type": "stream",
958
+ "text": [
959
+ "Train Loss: 0.0000, Train F1: 0.8743\n",
960
+ "Val Loss: 0.0000, Val F1: 0.7695\n",
961
+ "Learning rate: 0.000200\n",
962
+ "\n",
963
+ "Epoch 7/20\n"
964
+ ]
965
+ },
966
+ {
967
+ "name": "stderr",
968
+ "output_type": "stream",
969
+ "text": [
970
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.8976]\n",
971
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
972
+ ]
973
+ },
974
+ {
975
+ "name": "stdout",
976
+ "output_type": "stream",
977
+ "text": [
978
+ "Train Loss: 0.0000, Train F1: 0.8976\n",
979
+ "Val Loss: 0.0000, Val F1: 0.8148\n",
980
+ "Learning rate: 0.000200\n",
981
+ "Saved best model with F1: 0.8148\n",
982
+ "\n",
983
+ "Epoch 8/20\n"
984
+ ]
985
+ },
986
+ {
987
+ "name": "stderr",
988
+ "output_type": "stream",
989
+ "text": [
990
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9231]\n",
991
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.62it/s]\n"
992
+ ]
993
+ },
994
+ {
995
+ "name": "stdout",
996
+ "output_type": "stream",
997
+ "text": [
998
+ "Train Loss: 0.0000, Train F1: 0.9231\n",
999
+ "Val Loss: 0.0000, Val F1: 0.8247\n",
1000
+ "Learning rate: 0.000100\n",
1001
+ "Saved best model with F1: 0.8247\n",
1002
+ "\n",
1003
+ "Epoch 9/20\n"
1004
+ ]
1005
+ },
1006
+ {
1007
+ "name": "stderr",
1008
+ "output_type": "stream",
1009
+ "text": [
1010
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9384]\n",
1011
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.60it/s]\n"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "name": "stdout",
1016
+ "output_type": "stream",
1017
+ "text": [
1018
+ "Train Loss: 0.0000, Train F1: 0.9384\n",
1019
+ "Val Loss: 0.0000, Val F1: 0.8289\n",
1020
+ "Learning rate: 0.000100\n",
1021
+ "Saved best model with F1: 0.8289\n",
1022
+ "\n",
1023
+ "Epoch 10/20\n"
1024
+ ]
1025
+ },
1026
+ {
1027
+ "name": "stderr",
1028
+ "output_type": "stream",
1029
+ "text": [
1030
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9508]\n",
1031
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
1032
+ ]
1033
+ },
1034
+ {
1035
+ "name": "stdout",
1036
+ "output_type": "stream",
1037
+ "text": [
1038
+ "Train Loss: 0.0000, Train F1: 0.9508\n",
1039
+ "Val Loss: 0.0000, Val F1: 0.8402\n",
1040
+ "Learning rate: 0.000100\n",
1041
+ "Saved best model with F1: 0.8402\n",
1042
+ "\n",
1043
+ "Epoch 11/20\n"
1044
+ ]
1045
+ },
1046
+ {
1047
+ "name": "stderr",
1048
+ "output_type": "stream",
1049
+ "text": [
1050
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9544]\n",
1051
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.60it/s]\n"
1052
+ ]
1053
+ },
1054
+ {
1055
+ "name": "stdout",
1056
+ "output_type": "stream",
1057
+ "text": [
1058
+ "Train Loss: 0.0000, Train F1: 0.9544\n",
1059
+ "Val Loss: 0.0000, Val F1: 0.8414\n",
1060
+ "Learning rate: 0.000100\n",
1061
+ "Saved best model with F1: 0.8414\n",
1062
+ "\n",
1063
+ "Epoch 12/20\n"
1064
+ ]
1065
+ },
1066
+ {
1067
+ "name": "stderr",
1068
+ "output_type": "stream",
1069
+ "text": [
1070
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9617]\n",
1071
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.63it/s]\n"
1072
+ ]
1073
+ },
1074
+ {
1075
+ "name": "stdout",
1076
+ "output_type": "stream",
1077
+ "text": [
1078
+ "Train Loss: 0.0000, Train F1: 0.9617\n",
1079
+ "Val Loss: 0.0001, Val F1: 0.8420\n",
1080
+ "Learning rate: 0.000050\n",
1081
+ "Saved best model with F1: 0.8420\n",
1082
+ "\n",
1083
+ "Epoch 13/20\n"
1084
+ ]
1085
+ },
1086
+ {
1087
+ "name": "stderr",
1088
+ "output_type": "stream",
1089
+ "text": [
1090
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9672]\n",
1091
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
1092
+ ]
1093
+ },
1094
+ {
1095
+ "name": "stdout",
1096
+ "output_type": "stream",
1097
+ "text": [
1098
+ "Train Loss: 0.0000, Train F1: 0.9672\n",
1099
+ "Val Loss: 0.0000, Val F1: 0.8435\n",
1100
+ "Learning rate: 0.000050\n",
1101
+ "Saved best model with F1: 0.8435\n",
1102
+ "\n",
1103
+ "Epoch 14/20\n"
1104
+ ]
1105
+ },
1106
+ {
1107
+ "name": "stderr",
1108
+ "output_type": "stream",
1109
+ "text": [
1110
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9656]\n",
1111
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.59it/s]\n"
1112
+ ]
1113
+ },
1114
+ {
1115
+ "name": "stdout",
1116
+ "output_type": "stream",
1117
+ "text": [
1118
+ "Train Loss: 0.0000, Train F1: 0.9656\n",
1119
+ "Val Loss: 0.0000, Val F1: 0.8481\n",
1120
+ "Learning rate: 0.000050\n",
1121
+ "Saved best model with F1: 0.8481\n",
1122
+ "\n",
1123
+ "Epoch 15/20\n"
1124
+ ]
1125
+ },
1126
+ {
1127
+ "name": "stderr",
1128
+ "output_type": "stream",
1129
+ "text": [
1130
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9683]\n",
1131
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.59it/s]\n"
1132
+ ]
1133
+ },
1134
+ {
1135
+ "name": "stdout",
1136
+ "output_type": "stream",
1137
+ "text": [
1138
+ "Train Loss: 0.0000, Train F1: 0.9683\n",
1139
+ "Val Loss: 0.0001, Val F1: 0.8483\n",
1140
+ "Learning rate: 0.000050\n",
1141
+ "Saved best model with F1: 0.8483\n",
1142
+ "\n",
1143
+ "Epoch 16/20\n"
1144
+ ]
1145
+ },
1146
+ {
1147
+ "name": "stderr",
1148
+ "output_type": "stream",
1149
+ "text": [
1150
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9719]\n",
1151
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.59it/s]\n"
1152
+ ]
1153
+ },
1154
+ {
1155
+ "name": "stdout",
1156
+ "output_type": "stream",
1157
+ "text": [
1158
+ "Train Loss: 0.0000, Train F1: 0.9719\n",
1159
+ "Val Loss: 0.0001, Val F1: 0.8503\n",
1160
+ "Learning rate: 0.000025\n",
1161
+ "Saved best model with F1: 0.8503\n",
1162
+ "\n",
1163
+ "Epoch 17/20\n"
1164
+ ]
1165
+ },
1166
+ {
1167
+ "name": "stderr",
1168
+ "output_type": "stream",
1169
+ "text": [
1170
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9745]\n",
1171
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
1172
+ ]
1173
+ },
1174
+ {
1175
+ "name": "stdout",
1176
+ "output_type": "stream",
1177
+ "text": [
1178
+ "Train Loss: 0.0000, Train F1: 0.9745\n",
1179
+ "Val Loss: 0.0001, Val F1: 0.8525\n",
1180
+ "Learning rate: 0.000025\n",
1181
+ "Saved best model with F1: 0.8525\n",
1182
+ "\n",
1183
+ "Epoch 18/20\n"
1184
+ ]
1185
+ },
1186
+ {
1187
+ "name": "stderr",
1188
+ "output_type": "stream",
1189
+ "text": [
1190
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9757]\n",
1191
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.61it/s]\n"
1192
+ ]
1193
+ },
1194
+ {
1195
+ "name": "stdout",
1196
+ "output_type": "stream",
1197
+ "text": [
1198
+ "Train Loss: 0.0000, Train F1: 0.9757\n",
1199
+ "Val Loss: 0.0001, Val F1: 0.8500\n",
1200
+ "Learning rate: 0.000025\n",
1201
+ "\n",
1202
+ "Epoch 19/20\n"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "name": "stderr",
1207
+ "output_type": "stream",
1208
+ "text": [
1209
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.07it/s, loss=0.0000, f1=0.9780]\n",
1210
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.59it/s]\n"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "name": "stdout",
1215
+ "output_type": "stream",
1216
+ "text": [
1217
+ "Train Loss: 0.0000, Train F1: 0.9780\n",
1218
+ "Val Loss: 0.0001, Val F1: 0.8508\n",
1219
+ "Learning rate: 0.000025\n",
1220
+ "\n",
1221
+ "Epoch 20/20\n"
1222
+ ]
1223
+ },
1224
+ {
1225
+ "name": "stderr",
1226
+ "output_type": "stream",
1227
+ "text": [
1228
+ "Training: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 493/493 [02:01<00:00, 4.06it/s, loss=0.0000, f1=0.9770]\n",
1229
+ "Evaluating: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 124/124 [00:10<00:00, 11.58it/s]\n"
1230
+ ]
1231
+ },
1232
+ {
1233
+ "name": "stdout",
1234
+ "output_type": "stream",
1235
+ "text": [
1236
+ "Train Loss: 0.0000, Train F1: 0.9770\n",
1237
+ "Val Loss: 0.0001, Val F1: 0.8538\n",
1238
+ "Learning rate: 0.000013\n",
1239
+ "Saved best model with F1: 0.8538\n",
1240
+ "\n",
1241
+ "============================================================\n",
1242
+ "Training completed!\n",
1243
+ "Best validation F1: 0.8538\n",
1244
+ "\n",
1245
+ "Model saved for deployment in 'saved_transformer_model/' directory\n",
1246
+ "Files saved:\n",
1247
+ " - saved_transformer_model/pii_transformer_model.pt\n",
1248
+ " - saved_transformer_model/vocabularies.pkl\n",
1249
+ " - saved_transformer_model/model_config.pkl\n"
1250
+ ]
1251
+ }
1252
+ ],
1253
+ "source": [
1254
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1255
+ "print(f\"Using device: {device}\")\n",
1256
+ "\n",
1257
+ "model, text_vocab, label_vocab = train_transformer_pii_model(\n",
1258
+ " data_path='train_augmented.json',\n",
1259
+ " num_epochs=20,\n",
1260
+ " batch_size=32,\n",
1261
+ " learning_rate=2e-4,\n",
1262
+ " use_focal_loss=True,\n",
1263
+ " focal_gamma=2.0,\n",
1264
+ " device=device\n",
1265
+ ")"
1266
+ ]
1267
+ }
1268
+ ],
1269
+ "metadata": {
1270
+ "kaggle": {
1271
+ "accelerator": "nvidiaTeslaT4",
1272
+ "dataSources": [
1273
+ {
1274
+ "isSourceIdPinned": true,
1275
+ "modelId": 419045,
1276
+ "modelInstanceId": 400879,
1277
+ "sourceId": 504813,
1278
+ "sourceType": "modelInstanceVersion"
1279
+ }
1280
+ ],
1281
+ "dockerImageVersionId": 31090,
1282
+ "isGpuEnabled": true,
1283
+ "isInternetEnabled": true,
1284
+ "language": "python",
1285
+ "sourceType": "notebook"
1286
+ },
1287
+ "kernelspec": {
1288
+ "display_name": "Python 3",
1289
+ "language": "python",
1290
+ "name": "python3"
1291
+ },
1292
+ "language_info": {
1293
+ "codemirror_mode": {
1294
+ "name": "ipython",
1295
+ "version": 3
1296
+ },
1297
+ "file_extension": ".py",
1298
+ "mimetype": "text/x-python",
1299
+ "name": "python",
1300
+ "nbconvert_exporter": "python",
1301
+ "pygments_lexer": "ipython3",
1302
+ "version": "3.11.13"
1303
+ },
1304
+ "papermill": {
1305
+ "default_parameters": {},
1306
+ "duration": 2723.9142,
1307
+ "end_time": "2025-08-03T17:39:40.959986",
1308
+ "environment_variables": {},
1309
+ "exception": null,
1310
+ "input_path": "__notebook__.ipynb",
1311
+ "output_path": "__notebook__.ipynb",
1312
+ "parameters": {},
1313
+ "start_time": "2025-08-03T16:54:17.045786",
1314
+ "version": "2.6.0"
1315
+ }
1316
+ },
1317
+ "nbformat": 4,
1318
+ "nbformat_minor": 5
1319
+ }