AmitHirpara committed on
Commit
f53fac9
·
1 Parent(s): 3dea7de

add comments

Browse files
Files changed (6) hide show
  1. app.py +67 -47
  2. data_augmentation.py +64 -95
  3. lstm.py +48 -61
  4. lstm_training.ipynb +65 -26
  5. transformer.py +56 -172
  6. transformer_training.ipynb +76 -29
app.py CHANGED
@@ -10,20 +10,23 @@ from collections import Counter
10
  import warnings
11
  warnings.filterwarnings('ignore')
12
 
13
- # Define the Vocabulary class (needed for unpickling)
14
  class Vocabulary:
15
  """Vocabulary class for encoding/decoding text and labels"""
16
  def __init__(self, max_size=100000):
 
17
  self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
18
  self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}
19
  self.word_count = Counter()
20
  self.max_size = max_size
21
 
22
  def add_sentence(self, sentence):
 
23
  for word in sentence:
24
  self.word_count[word.lower()] += 1
25
 
26
  def build(self):
 
27
  most_common = self.word_count.most_common(self.max_size - len(self.word2idx))
28
  for word, _ in most_common:
29
  if word not in self.word2idx:
@@ -35,12 +38,14 @@ class Vocabulary:
35
  return len(self.word2idx)
36
 
37
  def encode(self, sentence):
 
38
  return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]
39
 
40
  def decode(self, indices):
 
41
  return [self.idx2word.get(idx, '<unk>') for idx in indices]
42
 
43
- # Custom Transformer components to match the saved model
44
  class MultiHeadAttention(nn.Module):
45
  def __init__(self, d_model, num_heads, dropout=0.1):
46
  super().__init__()
@@ -49,6 +54,7 @@ class MultiHeadAttention(nn.Module):
49
  self.num_heads = num_heads
50
  self.d_k = d_model // num_heads
51
 
 
52
  self.w_q = nn.Linear(d_model, d_model)
53
  self.w_k = nn.Linear(d_model, d_model)
54
  self.w_v = nn.Linear(d_model, d_model)
@@ -59,24 +65,27 @@ class MultiHeadAttention(nn.Module):
59
  def forward(self, query, key, value, mask=None):
60
  batch_size = query.size(0)
61
 
62
- # Linear transformations and split into heads
63
  Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
64
  K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
65
  V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
66
 
67
- # Attention
68
  scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
69
 
 
70
  if mask is not None:
71
  mask = mask.unsqueeze(1).unsqueeze(1)
72
  scores = scores.masked_fill(mask, -1e9)
73
 
 
74
  attention = F.softmax(scores, dim=-1)
75
  attention = self.dropout(attention)
76
 
 
77
  context = torch.matmul(attention, V)
78
 
79
- # Concatenate heads
80
  context = context.transpose(1, 2).contiguous().view(
81
  batch_size, -1, self.d_model
82
  )
@@ -84,6 +93,7 @@ class MultiHeadAttention(nn.Module):
84
  output = self.w_o(context)
85
  return output
86
 
 
87
  class FeedForward(nn.Module):
88
  def __init__(self, d_model, d_ff, dropout=0.1):
89
  super().__init__()
@@ -92,8 +102,10 @@ class FeedForward(nn.Module):
92
  self.dropout = nn.Dropout(dropout)
93
 
94
  def forward(self, x):
 
95
  return self.w_2(self.dropout(F.gelu(self.w_1(x))))
96
 
 
97
  class EncoderLayer(nn.Module):
98
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
99
  super().__init__()
@@ -104,19 +116,21 @@ class EncoderLayer(nn.Module):
104
  self.dropout = nn.Dropout(dropout)
105
 
106
  def forward(self, x, mask=None):
107
- # Self-attention with residual connection and layer norm
108
  attn_output = self.self_attention(x, x, x, mask)
109
  x = self.norm1(x + self.dropout(attn_output))
110
 
111
- # Feed forward with residual connection and layer norm
112
  ff_output = self.feed_forward(x)
113
  x = self.norm2(x + self.dropout(ff_output))
114
 
115
  return x
116
 
 
117
  class TransformerEncoder(nn.Module):
118
  def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
119
  super().__init__()
 
120
  self.layers = nn.ModuleList([
121
  EncoderLayer(d_model, num_heads, d_ff, dropout)
122
  for _ in range(num_layers)
@@ -124,49 +138,63 @@ class TransformerEncoder(nn.Module):
124
  self.norm = nn.LayerNorm(d_model)
125
 
126
  def forward(self, x, mask=None):
 
127
  for layer in self.layers:
128
  x = layer(x, mask)
129
  return self.norm(x)
130
 
 
131
  class PositionalEncoding(nn.Module):
132
  def __init__(self, d_model, max_len=5000):
133
  super().__init__()
134
  self.d_model = d_model
 
 
135
  pe = torch.zeros(max_len, d_model)
136
  position = torch.arange(0, max_len).unsqueeze(1).float()
137
  div_term = torch.exp(torch.arange(0, d_model, 2).float() *
138
  -(torch.log(torch.tensor(10000.0)) / d_model))
 
 
139
  pe[:, 0::2] = torch.sin(position * div_term)
140
  pe[:, 1::2] = torch.cos(position * div_term)
141
  self.register_buffer('pe', pe.unsqueeze(0))
142
 
143
  def forward(self, x):
 
144
  return x * torch.sqrt(torch.tensor(self.d_model, dtype=x.dtype)) + self.pe[:, :x.size(1)]
145
 
 
146
  class TransformerPIIDetector(nn.Module):
147
  def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
148
  d_ff=512, num_layers=4, dropout=0.1, max_len=512):
149
  super().__init__()
150
 
 
151
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
152
- self.positional_encoding = PositionalEncoding(d_model, max_len) # Changed name to match saved model
153
  self.dropout = nn.Dropout(dropout)
154
-
155
- # Custom encoder to match saved model structure
156
  self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
157
  self.classifier = nn.Linear(d_model, num_classes)
158
 
159
  def forward(self, x):
 
160
  padding_mask = (x == 0)
 
 
161
  x = self.embedding(x)
162
  x = self.positional_encoding(x)
163
  x = self.dropout(x)
 
 
164
  x = self.encoder(x, padding_mask)
165
  return self.classifier(x)
166
 
167
  def create_transformer_pii_model(**kwargs):
 
168
  return TransformerPIIDetector(**kwargs)
169
 
 
170
  class PIIDetector:
171
  def __init__(self, model_dir='saved_transformer'):
172
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -176,13 +204,13 @@ class PIIDetector:
176
  self.label_vocab = None
177
  self.load_model()
178
 
179
- # Single color for all PII highlighting
180
  self.highlight_color = '#FF6B6B'
181
 
182
  def load_model(self):
183
  """Load the trained model and vocabularies"""
184
  try:
185
- # Load vocabularies
186
  vocab_path = os.path.join(self.model_dir, 'vocabularies.pkl')
187
  with open(vocab_path, 'rb') as f:
188
  vocabs = pickle.load(f)
@@ -194,7 +222,7 @@ class PIIDetector:
194
  with open(config_path, 'rb') as f:
195
  model_config = pickle.load(f)
196
 
197
- # Create and load model
198
  self.model = create_transformer_pii_model(**model_config)
199
  model_path = os.path.join(self.model_dir, 'pii_transformer_model.pt')
200
  self.model.load_state_dict(torch.load(model_path, map_location=self.device))
@@ -211,7 +239,7 @@ class PIIDetector:
211
  def tokenize(self, text: str) -> List[str]:
212
  """Simple tokenization by splitting on spaces and punctuation"""
213
  import re
214
- # Split on whitespace and keep punctuation as separate tokens
215
  tokens = re.findall(r'\w+|[^\w\s]', text)
216
  return tokens
217
 
@@ -220,30 +248,30 @@ class PIIDetector:
220
  if not text.strip():
221
  return []
222
 
223
- # Tokenize
224
  tokens = self.tokenize(text)
225
 
226
- # Add start and end tokens
227
  tokens_with_special = ['<start>'] + tokens + ['<end>']
228
 
229
- # Encode tokens
230
  token_ids = self.text_vocab.encode(tokens_with_special)
231
 
232
- # Convert to tensor and add batch dimension
233
  input_tensor = torch.tensor([token_ids]).to(self.device)
234
 
235
- # Predict
236
  with torch.no_grad():
237
  outputs = self.model(input_tensor)
238
  predictions = torch.argmax(outputs, dim=-1)
239
 
240
- # Decode predictions (skip start and end tokens)
241
  predicted_labels = []
242
- for idx in predictions[0][1:-1]: # Skip <start> and <end>
243
  label = self.label_vocab.idx2word.get(idx.item(), 'O')
244
  predicted_labels.append(label.upper())
245
 
246
- # Pair tokens with their labels
247
  return list(zip(tokens, predicted_labels))
248
 
249
  def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
@@ -254,14 +282,14 @@ class PIIDetector:
254
  while i < len(token_label_pairs):
255
  token, label = token_label_pairs[i]
256
 
257
- # Check if this is the start of a PII entity
258
  if label != 'O':
259
  # Collect all tokens for this entity
260
  entity_tokens = [token]
261
  entity_label = label
262
  j = i + 1
263
 
264
- # Look for continuation tokens (I- tags)
265
  while j < len(token_label_pairs):
266
  next_token, next_label = token_label_pairs[j]
267
  if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
@@ -270,14 +298,14 @@ class PIIDetector:
270
  else:
271
  break
272
 
273
- # Join tokens with appropriate spacing
274
  entity_text = ''
275
  for k, tok in enumerate(entity_tokens):
276
  if k > 0 and tok not in '.,!?;:':
277
  entity_text += ' '
278
  entity_text += tok
279
 
280
- # Add highlighted entity
281
  label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
282
  html_parts.append(
283
  f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
@@ -287,7 +315,7 @@ class PIIDetector:
287
 
288
  i = j
289
  else:
290
- # Add space before token if needed
291
  if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i-1:
292
  prev_token, _ = token_label_pairs[i-1]
293
  if prev_token not in '(':
@@ -306,14 +334,14 @@ class PIIDetector:
306
  total_tokens = len(token_label_pairs)
307
  pii_tokens = 0
308
 
 
309
  for _, label in token_label_pairs:
310
  if label != 'O':
311
  pii_tokens += 1
312
- # Clean up label for display
313
  label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
314
  stats[label_clean] = stats.get(label_clean, 0) + 1
315
 
316
- # Create statistics text
317
  stats_text = f"### Detection Summary\n\n"
318
  stats_text += f"**Total tokens:** {total_tokens}\n\n"
319
  stats_text += f"**PII tokens:** {pii_tokens} ({pii_tokens/total_tokens*100:.1f}%)\n\n"
@@ -323,7 +351,7 @@ class PIIDetector:
323
 
324
  return stats_text
325
 
326
- # Initialize the detector
327
  print("Initializing PII Detector...")
328
  detector = PIIDetector()
329
 
@@ -333,13 +361,13 @@ def detect_pii(text):
333
  return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
334
 
335
  try:
336
- # Get predictions
337
  token_label_pairs = detector.predict(text)
338
 
339
- # Create highlighted HTML
340
  highlighted_html = detector.create_highlighted_html(token_label_pairs)
341
 
342
- # Get statistics
343
  stats = detector.get_statistics(token_label_pairs)
344
 
345
  return highlighted_html, stats
@@ -349,18 +377,7 @@ def detect_pii(text):
349
  error_stats = f"Error occurred: {str(e)}"
350
  return error_html, error_stats
351
 
352
- # Example texts
353
- examples = [
354
- "My name is John Smith and my email is john.smith@email.com. You can reach me at 555-123-4567.",
355
- "Student ID: 12345678. Please send the documents to 123 Main Street, Anytown, USA 12345.",
356
- "Contact Sarah Johnson at sarah_j_2023@gmail.com or visit her profile at linkedin.com/in/sarahjohnson",
357
- "The project was completed by student A1B2C3D4 who lives at 456 Oak Avenue.",
358
- "For verification, my phone number is (555) 987-6543 and my username is cool_user_99.",
359
- "Hi, I'm Emily Chen. My student number is STU-2023-98765 and I live at 789 Pine Street, Apt 4B.",
360
- "You can reach me at my personal website: www.johndoe.com or call me at +1-555-0123.",
361
- ]
362
-
363
- # Create Gradio interface
364
  with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
365
  gr.Markdown(
366
  """
@@ -371,6 +388,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
371
  )
372
 
373
  with gr.Column():
 
374
  input_text = gr.Textbox(
375
  label="Input Text",
376
  placeholder="Enter text to analyze for PII...",
@@ -378,10 +396,12 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
378
  max_lines=20
379
  )
380
 
 
381
  with gr.Row():
382
  analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
383
  clear_btn = gr.Button("🗑️ Clear", scale=1)
384
 
 
385
  highlighted_output = gr.HTML(
386
  label="Highlighted Text",
387
  value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
@@ -392,7 +412,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
392
  value="*Statistics will appear here...*"
393
  )
394
 
395
- # Set up event handlers
396
  analyze_btn.click(
397
  fn=detect_pii,
398
  inputs=[input_text],
@@ -404,7 +424,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
404
  outputs=[input_text, highlighted_output, stats_output]
405
  )
406
 
407
- # Launch the app
408
  if __name__ == "__main__":
409
  print("\nLaunching Gradio interface...")
410
  demo.launch()
 
10
  import warnings
11
  warnings.filterwarnings('ignore')
12
 
13
+ # Vocabulary class for handling text encoding and decoding
14
  class Vocabulary:
15
  """Vocabulary class for encoding/decoding text and labels"""
16
  def __init__(self, max_size=100000):
17
+ # Initialize special tokens
18
  self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
19
  self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}
20
  self.word_count = Counter()
21
  self.max_size = max_size
22
 
23
  def add_sentence(self, sentence):
24
+ # Count word frequencies in the sentence
25
  for word in sentence:
26
  self.word_count[word.lower()] += 1
27
 
28
  def build(self):
29
+ # Build vocabulary from most common words
30
  most_common = self.word_count.most_common(self.max_size - len(self.word2idx))
31
  for word, _ in most_common:
32
  if word not in self.word2idx:
 
38
  return len(self.word2idx)
39
 
40
  def encode(self, sentence):
41
+ # Convert words to indices
42
  return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]
43
 
44
  def decode(self, indices):
45
+ # Convert indices back to words
46
  return [self.idx2word.get(idx, '<unk>') for idx in indices]
47
 
48
+ # Multi-head attention mechanism for the transformer
49
  class MultiHeadAttention(nn.Module):
50
  def __init__(self, d_model, num_heads, dropout=0.1):
51
  super().__init__()
 
54
  self.num_heads = num_heads
55
  self.d_k = d_model // num_heads
56
 
57
+ # Linear layers for query, key, value, and output
58
  self.w_q = nn.Linear(d_model, d_model)
59
  self.w_k = nn.Linear(d_model, d_model)
60
  self.w_v = nn.Linear(d_model, d_model)
 
65
  def forward(self, query, key, value, mask=None):
66
  batch_size = query.size(0)
67
 
68
+ # Transform and reshape for multi-head attention
69
  Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
70
  K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
71
  V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
72
 
73
+ # Calculate attention scores
74
  scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
75
 
76
+ # Apply mask if provided
77
  if mask is not None:
78
  mask = mask.unsqueeze(1).unsqueeze(1)
79
  scores = scores.masked_fill(mask, -1e9)
80
 
81
+ # Apply softmax and dropout
82
  attention = F.softmax(scores, dim=-1)
83
  attention = self.dropout(attention)
84
 
85
+ # Apply attention to values
86
  context = torch.matmul(attention, V)
87
 
88
+ # Reshape back to original dimensions
89
  context = context.transpose(1, 2).contiguous().view(
90
  batch_size, -1, self.d_model
91
  )
 
93
  output = self.w_o(context)
94
  return output
95
 
96
+ # Feed-forward network component
97
  class FeedForward(nn.Module):
98
  def __init__(self, d_model, d_ff, dropout=0.1):
99
  super().__init__()
 
102
  self.dropout = nn.Dropout(dropout)
103
 
104
  def forward(self, x):
105
+ # Two linear layers with GELU activation
106
  return self.w_2(self.dropout(F.gelu(self.w_1(x))))
107
 
108
+ # Single encoder layer combining attention and feed-forward
109
  class EncoderLayer(nn.Module):
110
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
111
  super().__init__()
 
116
  self.dropout = nn.Dropout(dropout)
117
 
118
  def forward(self, x, mask=None):
119
+ # Apply self-attention with residual connection
120
  attn_output = self.self_attention(x, x, x, mask)
121
  x = self.norm1(x + self.dropout(attn_output))
122
 
123
+ # Apply feed-forward with residual connection
124
  ff_output = self.feed_forward(x)
125
  x = self.norm2(x + self.dropout(ff_output))
126
 
127
  return x
128
 
129
+ # Stack of encoder layers
130
  class TransformerEncoder(nn.Module):
131
  def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
132
  super().__init__()
133
+ # Create multiple encoder layers
134
  self.layers = nn.ModuleList([
135
  EncoderLayer(d_model, num_heads, d_ff, dropout)
136
  for _ in range(num_layers)
 
138
  self.norm = nn.LayerNorm(d_model)
139
 
140
  def forward(self, x, mask=None):
141
+ # Pass through each encoder layer
142
  for layer in self.layers:
143
  x = layer(x, mask)
144
  return self.norm(x)
145
 
146
+ # Positional encoding to add position information to embeddings
147
  class PositionalEncoding(nn.Module):
148
  def __init__(self, d_model, max_len=5000):
149
  super().__init__()
150
  self.d_model = d_model
151
+
152
+ # Create positional encoding matrix
153
  pe = torch.zeros(max_len, d_model)
154
  position = torch.arange(0, max_len).unsqueeze(1).float()
155
  div_term = torch.exp(torch.arange(0, d_model, 2).float() *
156
  -(torch.log(torch.tensor(10000.0)) / d_model))
157
+
158
+ # Apply sine and cosine functions
159
  pe[:, 0::2] = torch.sin(position * div_term)
160
  pe[:, 1::2] = torch.cos(position * div_term)
161
  self.register_buffer('pe', pe.unsqueeze(0))
162
 
163
  def forward(self, x):
164
+ # Scale embeddings and add positional encoding
165
  return x * torch.sqrt(torch.tensor(self.d_model, dtype=x.dtype)) + self.pe[:, :x.size(1)]
166
 
167
+ # Main transformer model for PII detection
168
  class TransformerPIIDetector(nn.Module):
169
  def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
170
  d_ff=512, num_layers=4, dropout=0.1, max_len=512):
171
  super().__init__()
172
 
173
+ # Model components
174
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
175
+ self.positional_encoding = PositionalEncoding(d_model, max_len)
176
  self.dropout = nn.Dropout(dropout)
 
 
177
  self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
178
  self.classifier = nn.Linear(d_model, num_classes)
179
 
180
  def forward(self, x):
181
+ # Create padding mask
182
  padding_mask = (x == 0)
183
+
184
+ # Pass through embedding and positional encoding
185
  x = self.embedding(x)
186
  x = self.positional_encoding(x)
187
  x = self.dropout(x)
188
+
189
+ # Encode and classify
190
  x = self.encoder(x, padding_mask)
191
  return self.classifier(x)
192
 
193
  def create_transformer_pii_model(**kwargs):
194
+ # Factory function to create the model
195
  return TransformerPIIDetector(**kwargs)
196
 
197
+ # Main PII detection class
198
  class PIIDetector:
199
  def __init__(self, model_dir='saved_transformer'):
200
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
204
  self.label_vocab = None
205
  self.load_model()
206
 
207
+ # Color for highlighting PII entities
208
  self.highlight_color = '#FF6B6B'
209
 
210
  def load_model(self):
211
  """Load the trained model and vocabularies"""
212
  try:
213
+ # Load saved vocabularies
214
  vocab_path = os.path.join(self.model_dir, 'vocabularies.pkl')
215
  with open(vocab_path, 'rb') as f:
216
  vocabs = pickle.load(f)
 
222
  with open(config_path, 'rb') as f:
223
  model_config = pickle.load(f)
224
 
225
+ # Initialize and load model weights
226
  self.model = create_transformer_pii_model(**model_config)
227
  model_path = os.path.join(self.model_dir, 'pii_transformer_model.pt')
228
  self.model.load_state_dict(torch.load(model_path, map_location=self.device))
 
239
  def tokenize(self, text: str) -> List[str]:
240
  """Simple tokenization by splitting on spaces and punctuation"""
241
  import re
242
+ # Split text into words and punctuation marks
243
  tokens = re.findall(r'\w+|[^\w\s]', text)
244
  return tokens
245
 
 
248
  if not text.strip():
249
  return []
250
 
251
+ # Tokenize input text
252
  tokens = self.tokenize(text)
253
 
254
+ # Add special tokens
255
  tokens_with_special = ['<start>'] + tokens + ['<end>']
256
 
257
+ # Convert tokens to indices
258
  token_ids = self.text_vocab.encode(tokens_with_special)
259
 
260
+ # Prepare tensor for model
261
  input_tensor = torch.tensor([token_ids]).to(self.device)
262
 
263
+ # Get predictions
264
  with torch.no_grad():
265
  outputs = self.model(input_tensor)
266
  predictions = torch.argmax(outputs, dim=-1)
267
 
268
+ # Convert predictions to labels
269
  predicted_labels = []
270
+ for idx in predictions[0][1:-1]: # Skip special tokens
271
  label = self.label_vocab.idx2word.get(idx.item(), 'O')
272
  predicted_labels.append(label.upper())
273
 
274
+ # Return token-label pairs
275
  return list(zip(tokens, predicted_labels))
276
 
277
  def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
 
282
  while i < len(token_label_pairs):
283
  token, label = token_label_pairs[i]
284
 
285
+ # Check if token is part of PII entity
286
  if label != 'O':
287
  # Collect all tokens for this entity
288
  entity_tokens = [token]
289
  entity_label = label
290
  j = i + 1
291
 
292
+ # Find continuation tokens
293
  while j < len(token_label_pairs):
294
  next_token, next_label = token_label_pairs[j]
295
  if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
 
298
  else:
299
  break
300
 
301
+ # Join entity tokens with proper spacing
302
  entity_text = ''
303
  for k, tok in enumerate(entity_tokens):
304
  if k > 0 and tok not in '.,!?;:':
305
  entity_text += ' '
306
  entity_text += tok
307
 
308
+ # Create highlighted HTML for entity
309
  label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
310
  html_parts.append(
311
  f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
 
315
 
316
  i = j
317
  else:
318
+ # Add non-PII token with proper spacing
319
  if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i-1:
320
  prev_token, _ = token_label_pairs[i-1]
321
  if prev_token not in '(':
 
334
  total_tokens = len(token_label_pairs)
335
  pii_tokens = 0
336
 
337
+ # Count PII tokens by type
338
  for _, label in token_label_pairs:
339
  if label != 'O':
340
  pii_tokens += 1
 
341
  label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
342
  stats[label_clean] = stats.get(label_clean, 0) + 1
343
 
344
+ # Format statistics text
345
  stats_text = f"### Detection Summary\n\n"
346
  stats_text += f"**Total tokens:** {total_tokens}\n\n"
347
  stats_text += f"**PII tokens:** {pii_tokens} ({pii_tokens/total_tokens*100:.1f}%)\n\n"
 
351
 
352
  return stats_text
353
 
354
+ # Initialize the detector when the script runs
355
  print("Initializing PII Detector...")
356
  detector = PIIDetector()
357
 
 
361
  return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
362
 
363
  try:
364
+ # Run PII detection
365
  token_label_pairs = detector.predict(text)
366
 
367
+ # Generate highlighted output
368
  highlighted_html = detector.create_highlighted_html(token_label_pairs)
369
 
370
+ # Generate statistics
371
  stats = detector.get_statistics(token_label_pairs)
372
 
373
  return highlighted_html, stats
 
377
  error_stats = f"Error occurred: {str(e)}"
378
  return error_html, error_stats
379
 
380
+ # Create the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
381
  with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
382
  gr.Markdown(
383
  """
 
388
  )
389
 
390
  with gr.Column():
391
+ # Input text area
392
  input_text = gr.Textbox(
393
  label="Input Text",
394
  placeholder="Enter text to analyze for PII...",
 
396
  max_lines=20
397
  )
398
 
399
+ # Control buttons
400
  with gr.Row():
401
  analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
402
  clear_btn = gr.Button("🗑️ Clear", scale=1)
403
 
404
+ # Output areas
405
  highlighted_output = gr.HTML(
406
  label="Highlighted Text",
407
  value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
 
412
  value="*Statistics will appear here...*"
413
  )
414
 
415
+ # Connect buttons to functions
416
  analyze_btn.click(
417
  fn=detect_pii,
418
  inputs=[input_text],
 
424
  outputs=[input_text, highlighted_output, stats_output]
425
  )
426
 
427
+ # Launch the application
428
  if __name__ == "__main__":
429
  print("\nLaunching Gradio interface...")
430
  demo.launch()
data_augmentation.py CHANGED
@@ -15,17 +15,20 @@ class PIIDataAugmenter:
15
 
16
  def __init__(self, seed=42):
17
  """Initialize the augmenter with random seeds for reproducibility."""
 
18
  random.seed(seed)
19
  np.random.seed(seed)
20
  self.fake = Faker()
21
  Faker.seed(seed)
22
 
 
23
  self._init_templates()
24
  self._init_context_phrases()
25
  self._init_generators()
26
 
27
  def _init_templates(self):
28
  """Initialize templates for different PII types."""
 
29
  self.templates = {
30
  'NAME_STUDENT': [
31
  "My name is {name}",
@@ -115,6 +118,7 @@ class PIIDataAugmenter:
115
 
116
  def _init_context_phrases(self):
117
  """Initialize context phrases for more natural text generation."""
 
118
  self.context_prefix = [
119
  "Hello everyone,",
120
  "Dear Sir/Madam,",
@@ -128,6 +132,7 @@ class PIIDataAugmenter:
128
  "I am writing to tell you that"
129
  ]
130
 
 
131
  self.context_suffix = [
132
  "Thank you.",
133
  "Best regards.",
@@ -141,12 +146,14 @@ class PIIDataAugmenter:
141
  "Let me know if you have questions."
142
  ]
143
 
 
144
  self.connectors = [
145
  " and ", " or ", ", ", ". Also, ", ". Additionally, "
146
  ]
147
 
148
  def _init_generators(self):
149
  """Initialize PII generators mapping."""
 
150
  self.generators = {
151
  'NAME_STUDENT': self.generate_name,
152
  'EMAIL': self.generate_email,
@@ -157,6 +164,7 @@ class PIIDataAugmenter:
157
  'USERNAME': self.generate_username
158
  }
159
 
 
160
  self.format_keys = {
161
  'NAME_STUDENT': 'name',
162
  'EMAIL': 'email',
@@ -167,8 +175,6 @@ class PIIDataAugmenter:
167
  'USERNAME': 'username'
168
  }
169
 
170
- # ========== PII Generators ==========
171
-
172
  def generate_name(self):
173
  """Generate realistic person names."""
174
  return self.fake.name()
@@ -179,6 +185,7 @@ class PIIDataAugmenter:
179
 
180
  def generate_phone(self):
181
  """Generate realistic phone numbers in various formats."""
 
182
  formats = [
183
  "555-{:03d}-{:04d}",
184
  "(555) {:03d}-{:04d}",
@@ -186,6 +193,7 @@ class PIIDataAugmenter:
186
  "+1-555-{:03d}-{:04d}",
187
  "555{:03d}{:04d}"
188
  ]
 
189
  format_choice = random.choice(formats)
190
  area = random.randint(100, 999)
191
  number = random.randint(1000, 9999)
@@ -193,10 +201,12 @@ class PIIDataAugmenter:
193
 
194
  def generate_address(self):
195
  """Generate realistic street addresses."""
 
196
  return self.fake.address().replace('\n', ', ')
197
 
198
  def generate_id_num(self):
199
  """Generate various ID number formats."""
 
200
  formats = [
201
  "{:06d}", # 6-digit ID
202
  "{:08d}", # 8-digit ID
@@ -207,6 +217,7 @@ class PIIDataAugmenter:
207
  ]
208
  format_choice = random.choice(formats)
209
 
 
210
  if '-' in format_choice:
211
  return format_choice.format(
212
  random.randint(1000, 9999),
@@ -217,6 +228,7 @@ class PIIDataAugmenter:
217
 
218
  def generate_url(self):
219
  """Generate personal website URLs."""
 
220
  domains = ['github.com', 'linkedin.com', 'portfolio.com',
221
  'personal.com', 'website.com']
222
  username = self.fake.user_name()
@@ -227,72 +239,53 @@ class PIIDataAugmenter:
227
  """Generate usernames."""
228
  return self.fake.user_name()
229
 
230
- # ========== Synthetic Example Creation ==========
231
-
232
  def create_synthetic_example(self, pii_type, add_context=True):
233
- """
234
- Create a synthetic example with proper BIO labeling.
235
-
236
- Args:
237
- pii_type: Type of PII to generate
238
- add_context: Whether to add context phrases
239
-
240
- Returns:
241
- Tuple of (tokens, labels)
242
- """
243
- # Generate PII value
244
  pii_value = self.generators[pii_type]()
245
 
246
- # Select and fill template
247
  template = random.choice(self.templates[pii_type])
248
  format_key = self.format_keys[pii_type]
249
  sentence = template.format(**{format_key: pii_value})
250
 
251
- # Add context if requested
252
  if add_context and random.random() > 0.3:
253
  sentence = self._add_context(sentence)
254
 
255
- # Tokenize and label
256
  tokens, labels = self._tokenize_and_label(sentence, pii_value, pii_type)
257
 
258
  return tokens, labels
259
 
260
  def create_mixed_example(self, pii_types, num_pii=2):
261
- """
262
- Create examples with multiple PII types.
263
-
264
- Args:
265
- pii_types: List of PII types to include
266
- num_pii: Number of PII entities to include
267
-
268
- Returns:
269
- Tuple of (tokens, labels)
270
- """
271
  selected_types = random.sample(pii_types, min(num_pii, len(pii_types)))
272
 
273
  all_tokens = []
274
  all_labels = []
275
 
276
- # Add context prefix
277
  if random.random() > 0.3:
278
  prefix = random.choice(self.context_prefix)
279
  all_tokens.extend(prefix.split())
280
  all_labels.extend(['O'] * len(prefix.split()))
281
 
282
- # Add each PII with connectors
283
  for i, pii_type in enumerate(selected_types):
284
- # Add connector between PIIs
285
  if i > 0 and random.random() > 0.5:
286
  connector = random.choice(self.connectors)
287
  all_tokens.extend(connector.strip().split())
288
  all_labels.extend(['O'] * len(connector.strip().split()))
289
 
290
- # Create PII example without additional context
291
  tokens, labels = self.create_synthetic_example(pii_type, add_context=False)
292
  all_tokens.extend(tokens)
293
  all_labels.extend(labels)
294
 
295
- # Add context suffix
296
  if random.random() > 0.3:
297
  suffix = random.choice(self.context_suffix)
298
  all_tokens.extend(suffix.split())
@@ -302,79 +295,60 @@ class PIIDataAugmenter:
302
 
303
  def _add_context(self, sentence):
304
  """Add context phrases to make text more natural."""
 
305
  if random.random() > 0.5:
306
  sentence = random.choice(self.context_prefix) + " " + sentence
 
307
  if random.random() > 0.5:
308
  sentence = sentence + " " + random.choice(self.context_suffix)
309
  return sentence
310
 
311
  def _tokenize_and_label(self, sentence, pii_value, pii_type):
312
- """
313
- Tokenize sentence and apply BIO labels for PII.
314
-
315
- Args:
316
- sentence: The sentence containing PII
317
- pii_value: The PII value to find and label
318
- pii_type: The type of PII for labeling
319
-
320
- Returns:
321
- Tuple of (tokens, labels)
322
- """
323
  tokens = sentence.split()
324
  labels = ['O'] * len(tokens)
325
 
326
- # Tokenize PII value
327
  pii_tokens = pii_value.split()
328
 
329
- # Find and label PII in the sentence
330
  for i in range(len(tokens) - len(pii_tokens) + 1):
331
- # Check if tokens match PII value
332
  if (tokens[i:i+len(pii_tokens)] == pii_tokens or
333
  ' '.join(tokens[i:i+len(pii_tokens)]).lower() == pii_value.lower()):
334
 
335
- # Apply BIO labels
336
- labels[i] = f'B-{pii_type}'
337
  for j in range(1, len(pii_tokens)):
338
- labels[i+j] = f'I-{pii_type}'
339
  break
340
 
341
  return tokens, labels
342
 
343
- # ========== Dataset Augmentation ==========
344
-
345
  def augment_dataset(self, original_data, target_samples_per_class=1000, mix_ratio=0.3):
346
- """
347
- Augment dataset with synthetic examples to balance PII classes.
348
-
349
- Args:
350
- original_data: Original dataset DataFrame
351
- target_samples_per_class: Target number of samples per PII class
352
- mix_ratio: Ratio of mixed (multi-PII) examples
353
-
354
- Returns:
355
- Augmented dataset DataFrame
356
- """
357
- # Analyze original distribution
358
  label_counts = self._analyze_label_distribution(original_data)
359
  print("\nOriginal label distribution:")
360
  self._print_distribution(label_counts)
361
 
362
- # Generate synthetic examples
363
  synthetic_tokens, synthetic_labels = self._generate_synthetic_data(
364
  label_counts, target_samples_per_class, mix_ratio
365
  )
366
 
367
- # Add non-PII examples
368
  synthetic_tokens, synthetic_labels = self._add_non_pii_examples(
369
  synthetic_tokens, synthetic_labels
370
  )
371
 
372
- # Combine and shuffle data
373
  augmented_df = self._combine_and_shuffle(
374
  original_data, synthetic_tokens, synthetic_labels
375
  )
376
 
377
- # Analyze new distribution
378
  new_label_counts = self._analyze_label_distribution(augmented_df)
379
  print("\nAugmented label distribution:")
380
  self._print_distribution(new_label_counts)
@@ -385,10 +359,11 @@ class PIIDataAugmenter:
385
  """Analyze the distribution of PII labels in the dataset."""
386
  label_counts = Counter()
387
 
 
388
  for labels in data['labels']:
389
  for label in labels:
390
  if label != 'O':
391
- # Extract base label (remove B- or I- prefix)
392
  base_label = label.split('-')[1] if '-' in label else label
393
  label_counts[base_label] += 1
394
 
@@ -397,6 +372,7 @@ class PIIDataAugmenter:
397
  def _print_distribution(self, label_counts):
398
  """Print label distribution statistics."""
399
  total = sum(label_counts.values())
 
400
  for label, count in label_counts.most_common():
401
  percentage = (count / total * 100) if total > 0 else 0
402
  print(f" {label:15} : {count:6,} ({percentage:5.2f}%)")
@@ -406,6 +382,7 @@ class PIIDataAugmenter:
406
  synthetic_tokens = []
407
  synthetic_labels = []
408
 
 
409
  for pii_type in self.templates.keys():
410
  current_count = label_counts.get(pii_type, 0)
411
  needed = max(0, target_samples - current_count)
@@ -415,17 +392,17 @@ class PIIDataAugmenter:
415
 
416
  print(f"\nGenerating {needed} synthetic examples for {pii_type}")
417
 
418
- # Single PII examples
419
  single_count = int(needed * (1 - mix_ratio))
420
  for _ in range(single_count):
421
  tokens, labels = self.create_synthetic_example(pii_type)
422
  synthetic_tokens.append(tokens)
423
  synthetic_labels.append(labels)
424
 
425
- # Mixed PII examples
426
  mixed_count = int(needed * mix_ratio)
427
  for _ in range(mixed_count):
428
- # Ensure current PII type is included
429
  other_types = [t for t in self.templates.keys() if t != pii_type]
430
  selected_types = [pii_type] + random.sample(
431
  other_types, min(1, len(other_types))
@@ -439,6 +416,7 @@ class PIIDataAugmenter:
439
 
440
  def _add_non_pii_examples(self, synthetic_tokens, synthetic_labels):
441
  """Add examples without PII (all 'O' labels) for balance."""
 
442
  num_non_pii = int(len(synthetic_tokens) * 0.1)
443
 
444
  for _ in range(num_non_pii):
@@ -454,17 +432,17 @@ class PIIDataAugmenter:
454
 
455
  def _combine_and_shuffle(self, original_data, synthetic_tokens, synthetic_labels):
456
  """Combine original and synthetic data, then shuffle."""
457
- # Combine data
458
  all_tokens = original_data['tokens'].tolist() + synthetic_tokens
459
  all_labels = original_data['labels'].tolist() + synthetic_labels
460
 
461
- # Create DataFrame
462
  augmented_data = pd.DataFrame({
463
  'tokens': all_tokens,
464
  'labels': all_labels
465
  })
466
 
467
- # Shuffle
468
  augmented_data = augmented_data.sample(frac=1, random_state=42).reset_index(drop=True)
469
 
470
  print(f"\nTotal augmented samples: {len(augmented_data):,}")
@@ -472,17 +450,8 @@ class PIIDataAugmenter:
472
  return augmented_data
473
 
474
  def calculate_class_weights(data, label_vocab):
475
- """
476
- Calculate class weights for balanced loss function.
477
-
478
- Args:
479
- data: Dataset DataFrame with 'labels' column
480
- label_vocab: Vocabulary object with word2idx mapping
481
-
482
- Returns:
483
- Tensor of class weights
484
- """
485
- # Count label occurrences
486
  label_counts = Counter()
487
 
488
  for labels in data['labels']:
@@ -490,7 +459,7 @@ def calculate_class_weights(data, label_vocab):
490
  label_id = label_vocab.word2idx.get(label.lower(), 0)
491
  label_counts[label_id] += 1
492
 
493
- # Calculate inverse frequency weights
494
  total_samples = sum(label_counts.values())
495
  num_classes = len(label_vocab)
496
 
@@ -498,31 +467,31 @@ def calculate_class_weights(data, label_vocab):
498
 
499
  for class_id, count in label_counts.items():
500
  if count > 0:
501
- # Inverse frequency with smoothing
502
  weights[class_id] = total_samples / (num_classes * count)
503
 
504
- # Normalize weights
505
  weights = weights / weights.sum() * num_classes
506
 
507
- # Cap extreme weights to prevent instability
508
  weights = torch.clamp(weights, min=0.1, max=10.0)
509
 
510
- # Set padding weight to 0
511
  weights[0] = 0.0
512
 
513
  return weights
514
 
515
  if __name__ == '__main__':
516
  """Example usage of the augmentation module."""
517
- # Load original data
518
  print("Loading original training data...")
519
  original_data = pd.read_json('train.json')
520
  print(f"Original dataset size: {len(original_data):,}")
521
 
522
- # Initialize augmenter
523
  augmenter = PIIDataAugmenter(seed=42)
524
 
525
- # Augment dataset
526
  print("\n" + "="*60)
527
  print("Starting data augmentation...")
528
  print("="*60)
@@ -533,7 +502,7 @@ if __name__ == '__main__':
533
  mix_ratio=0.3
534
  )
535
 
536
- # Save augmented data
537
  output_path = './train_augmented.json'
538
  augmented_data.to_json(output_path, orient='records', lines=True)
539
  print(f"\nSaved augmented data to {output_path}")
 
15
 
16
  def __init__(self, seed=42):
17
  """Initialize the augmenter with random seeds for reproducibility."""
18
+ # Set random seeds for consistent results
19
  random.seed(seed)
20
  np.random.seed(seed)
21
  self.fake = Faker()
22
  Faker.seed(seed)
23
 
24
+ # Initialize data structures
25
  self._init_templates()
26
  self._init_context_phrases()
27
  self._init_generators()
28
 
29
  def _init_templates(self):
30
  """Initialize templates for different PII types."""
31
+ # Templates for generating sentences with PII
32
  self.templates = {
33
  'NAME_STUDENT': [
34
  "My name is {name}",
 
118
 
119
  def _init_context_phrases(self):
120
  """Initialize context phrases for more natural text generation."""
121
+ # Opening phrases for generated text
122
  self.context_prefix = [
123
  "Hello everyone,",
124
  "Dear Sir/Madam,",
 
132
  "I am writing to tell you that"
133
  ]
134
 
135
+ # Closing phrases for generated text
136
  self.context_suffix = [
137
  "Thank you.",
138
  "Best regards.",
 
146
  "Let me know if you have questions."
147
  ]
148
 
149
+ # Words to connect multiple PII elements
150
  self.connectors = [
151
  " and ", " or ", ", ", ". Also, ", ". Additionally, "
152
  ]
153
 
154
  def _init_generators(self):
155
  """Initialize PII generators mapping."""
156
+ # Map PII types to their generator functions
157
  self.generators = {
158
  'NAME_STUDENT': self.generate_name,
159
  'EMAIL': self.generate_email,
 
164
  'USERNAME': self.generate_username
165
  }
166
 
167
+ # Map PII types to template placeholder keys
168
  self.format_keys = {
169
  'NAME_STUDENT': 'name',
170
  'EMAIL': 'email',
 
175
  'USERNAME': 'username'
176
  }
177
 
 
 
178
  def generate_name(self):
179
  """Generate realistic person names."""
180
  return self.fake.name()
 
185
 
186
  def generate_phone(self):
187
  """Generate realistic phone numbers in various formats."""
188
+ # Different phone number formats
189
  formats = [
190
  "555-{:03d}-{:04d}",
191
  "(555) {:03d}-{:04d}",
 
193
  "+1-555-{:03d}-{:04d}",
194
  "555{:03d}{:04d}"
195
  ]
196
+ # Pick a random format and fill with random numbers
197
  format_choice = random.choice(formats)
198
  area = random.randint(100, 999)
199
  number = random.randint(1000, 9999)
 
201
 
202
  def generate_address(self):
203
  """Generate realistic street addresses."""
204
+ # Get address and replace newlines with commas
205
  return self.fake.address().replace('\n', ', ')
206
 
207
  def generate_id_num(self):
208
  """Generate various ID number formats."""
209
+ # Different ID number patterns
210
  formats = [
211
  "{:06d}", # 6-digit ID
212
  "{:08d}", # 8-digit ID
 
217
  ]
218
  format_choice = random.choice(formats)
219
 
220
+ # Handle hyphenated format differently
221
  if '-' in format_choice:
222
  return format_choice.format(
223
  random.randint(1000, 9999),
 
228
 
229
  def generate_url(self):
230
  """Generate personal website URLs."""
231
+ # Common personal website domains
232
  domains = ['github.com', 'linkedin.com', 'portfolio.com',
233
  'personal.com', 'website.com']
234
  username = self.fake.user_name()
 
239
  """Generate usernames."""
240
  return self.fake.user_name()
241
 
 
 
242
  def create_synthetic_example(self, pii_type, add_context=True):
243
+ """Create a synthetic example with proper BIO labeling."""
244
+ # Generate the PII value
 
 
 
 
 
 
 
 
 
245
  pii_value = self.generators[pii_type]()
246
 
247
+ # Choose a template and insert the PII
248
  template = random.choice(self.templates[pii_type])
249
  format_key = self.format_keys[pii_type]
250
  sentence = template.format(**{format_key: pii_value})
251
 
252
+ # Optionally add context for more natural text
253
  if add_context and random.random() > 0.3:
254
  sentence = self._add_context(sentence)
255
 
256
+ # Create tokens and labels
257
  tokens, labels = self._tokenize_and_label(sentence, pii_value, pii_type)
258
 
259
  return tokens, labels
260
 
261
  def create_mixed_example(self, pii_types, num_pii=2):
262
+ """Create examples with multiple PII types."""
263
+ # Select which PII types to include
 
 
 
 
 
 
 
 
264
  selected_types = random.sample(pii_types, min(num_pii, len(pii_types)))
265
 
266
  all_tokens = []
267
  all_labels = []
268
 
269
+ # Add opening context
270
  if random.random() > 0.3:
271
  prefix = random.choice(self.context_prefix)
272
  all_tokens.extend(prefix.split())
273
  all_labels.extend(['O'] * len(prefix.split()))
274
 
275
+ # Add each PII entity
276
  for i, pii_type in enumerate(selected_types):
277
+ # Add connector between PII entities
278
  if i > 0 and random.random() > 0.5:
279
  connector = random.choice(self.connectors)
280
  all_tokens.extend(connector.strip().split())
281
  all_labels.extend(['O'] * len(connector.strip().split()))
282
 
283
+ # Generate PII example
284
  tokens, labels = self.create_synthetic_example(pii_type, add_context=False)
285
  all_tokens.extend(tokens)
286
  all_labels.extend(labels)
287
 
288
+ # Add closing context
289
  if random.random() > 0.3:
290
  suffix = random.choice(self.context_suffix)
291
  all_tokens.extend(suffix.split())
 
295
 
296
  def _add_context(self, sentence):
297
  """Add context phrases to make text more natural."""
298
+ # Randomly add prefix
299
  if random.random() > 0.5:
300
  sentence = random.choice(self.context_prefix) + " " + sentence
301
+ # Randomly add suffix
302
  if random.random() > 0.5:
303
  sentence = sentence + " " + random.choice(self.context_suffix)
304
  return sentence
305
 
306
  def _tokenize_and_label(self, sentence, pii_value, pii_type):
307
+ """Tokenize sentence and apply BIO labels for PII."""
308
+ # Split sentence into tokens
 
 
 
 
 
 
 
 
 
309
  tokens = sentence.split()
310
  labels = ['O'] * len(tokens)
311
 
312
+ # Split PII value into tokens
313
  pii_tokens = pii_value.split()
314
 
315
+ # Find where PII appears in the sentence
316
  for i in range(len(tokens) - len(pii_tokens) + 1):
317
+ # Check if tokens match the PII value
318
  if (tokens[i:i+len(pii_tokens)] == pii_tokens or
319
  ' '.join(tokens[i:i+len(pii_tokens)]).lower() == pii_value.lower()):
320
 
321
+ # Apply BIO tagging
322
+ labels[i] = f'B-{pii_type}' # Beginning
323
  for j in range(1, len(pii_tokens)):
324
+ labels[i+j] = f'I-{pii_type}' # Inside
325
  break
326
 
327
  return tokens, labels
328
 
 
 
329
  def augment_dataset(self, original_data, target_samples_per_class=1000, mix_ratio=0.3):
330
+ """Augment dataset with synthetic examples to balance PII classes."""
331
+ # Check current distribution
 
 
 
 
 
 
 
 
 
 
332
  label_counts = self._analyze_label_distribution(original_data)
333
  print("\nOriginal label distribution:")
334
  self._print_distribution(label_counts)
335
 
336
+ # Generate synthetic data
337
  synthetic_tokens, synthetic_labels = self._generate_synthetic_data(
338
  label_counts, target_samples_per_class, mix_ratio
339
  )
340
 
341
+ # Add some non-PII examples for balance
342
  synthetic_tokens, synthetic_labels = self._add_non_pii_examples(
343
  synthetic_tokens, synthetic_labels
344
  )
345
 
346
+ # Combine original and synthetic data
347
  augmented_df = self._combine_and_shuffle(
348
  original_data, synthetic_tokens, synthetic_labels
349
  )
350
 
351
+ # Check new distribution
352
  new_label_counts = self._analyze_label_distribution(augmented_df)
353
  print("\nAugmented label distribution:")
354
  self._print_distribution(new_label_counts)
 
359
  """Analyze the distribution of PII labels in the dataset."""
360
  label_counts = Counter()
361
 
362
+ # Count each PII type
363
  for labels in data['labels']:
364
  for label in labels:
365
  if label != 'O':
366
+ # Remove B- or I- prefix to get base label
367
  base_label = label.split('-')[1] if '-' in label else label
368
  label_counts[base_label] += 1
369
 
 
372
  def _print_distribution(self, label_counts):
373
  """Print label distribution statistics."""
374
  total = sum(label_counts.values())
375
+ # Print each label count and percentage
376
  for label, count in label_counts.most_common():
377
  percentage = (count / total * 100) if total > 0 else 0
378
  print(f" {label:15} : {count:6,} ({percentage:5.2f}%)")
 
382
  synthetic_tokens = []
383
  synthetic_labels = []
384
 
385
+ # Generate examples for each PII type
386
  for pii_type in self.templates.keys():
387
  current_count = label_counts.get(pii_type, 0)
388
  needed = max(0, target_samples - current_count)
 
392
 
393
  print(f"\nGenerating {needed} synthetic examples for {pii_type}")
394
 
395
+ # Generate single PII examples
396
  single_count = int(needed * (1 - mix_ratio))
397
  for _ in range(single_count):
398
  tokens, labels = self.create_synthetic_example(pii_type)
399
  synthetic_tokens.append(tokens)
400
  synthetic_labels.append(labels)
401
 
402
+ # Generate mixed PII examples
403
  mixed_count = int(needed * mix_ratio)
404
  for _ in range(mixed_count):
405
+ # Make sure current PII type is included
406
  other_types = [t for t in self.templates.keys() if t != pii_type]
407
  selected_types = [pii_type] + random.sample(
408
  other_types, min(1, len(other_types))
 
416
 
417
  def _add_non_pii_examples(self, synthetic_tokens, synthetic_labels):
418
  """Add examples without PII (all 'O' labels) for balance."""
419
+ # Add 10% non-PII examples
420
  num_non_pii = int(len(synthetic_tokens) * 0.1)
421
 
422
  for _ in range(num_non_pii):
 
432
 
433
  def _combine_and_shuffle(self, original_data, synthetic_tokens, synthetic_labels):
434
  """Combine original and synthetic data, then shuffle."""
435
+ # Merge all data
436
  all_tokens = original_data['tokens'].tolist() + synthetic_tokens
437
  all_labels = original_data['labels'].tolist() + synthetic_labels
438
 
439
+ # Create new dataframe
440
  augmented_data = pd.DataFrame({
441
  'tokens': all_tokens,
442
  'labels': all_labels
443
  })
444
 
445
+ # Shuffle the data
446
  augmented_data = augmented_data.sample(frac=1, random_state=42).reset_index(drop=True)
447
 
448
  print(f"\nTotal augmented samples: {len(augmented_data):,}")
 
450
  return augmented_data
451
 
452
  def calculate_class_weights(data, label_vocab):
453
+ """Calculate class weights for balanced loss function."""
454
+ # Count occurrences of each label
 
 
 
 
 
 
 
 
 
455
  label_counts = Counter()
456
 
457
  for labels in data['labels']:
 
459
  label_id = label_vocab.word2idx.get(label.lower(), 0)
460
  label_counts[label_id] += 1
461
 
462
+ # Calculate weights based on inverse frequency
463
  total_samples = sum(label_counts.values())
464
  num_classes = len(label_vocab)
465
 
 
467
 
468
  for class_id, count in label_counts.items():
469
  if count > 0:
470
+ # Inverse frequency weighting
471
  weights[class_id] = total_samples / (num_classes * count)
472
 
473
+ # Normalize the weights
474
  weights = weights / weights.sum() * num_classes
475
 
476
+ # Prevent extreme weights
477
  weights = torch.clamp(weights, min=0.1, max=10.0)
478
 
479
+ # Don't weight padding tokens
480
  weights[0] = 0.0
481
 
482
  return weights
483
 
484
  if __name__ == '__main__':
485
  """Example usage of the augmentation module."""
486
+ # Load original training data
487
  print("Loading original training data...")
488
  original_data = pd.read_json('train.json')
489
  print(f"Original dataset size: {len(original_data):,}")
490
 
491
+ # Create augmenter instance
492
  augmenter = PIIDataAugmenter(seed=42)
493
 
494
+ # Run augmentation
495
  print("\n" + "="*60)
496
  print("Starting data augmentation...")
497
  print("="*60)
 
502
  mix_ratio=0.3
503
  )
504
 
505
+ # Save the augmented dataset
506
  output_path = './train_augmented.json'
507
  augmented_data.to_json(output_path, orient='records', lines=True)
508
  print(f"\nSaved augmented data to {output_path}")
lstm.py CHANGED
@@ -12,28 +12,28 @@ class LSTMCell(nn.Module):
12
  self.input_size = input_size
13
  self.hidden_size = hidden_size
14
 
15
- # Initialize weight matrices and bias vectors for LSTM gates
16
- # Input gate
17
  self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
18
  self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
19
  self.b_i = nn.Parameter(torch.Tensor(hidden_size))
20
 
21
- # Forget gate
22
  self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
23
  self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
24
  self.b_f = nn.Parameter(torch.Tensor(hidden_size))
25
 
26
- # Input node (candidate)
27
  self.W_in = nn.Parameter(torch.Tensor(input_size, hidden_size))
28
  self.W_hn = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
29
  self.b_n = nn.Parameter(torch.Tensor(hidden_size))
30
 
31
- # Output gate
32
  self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
33
  self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
34
  self.b_o = nn.Parameter(torch.Tensor(hidden_size))
35
 
36
- # Initialize all weights with xavier_uniform and biases with zeros
37
  for name, param in self.named_parameters():
38
  if 'W_' in name:
39
  nn.init.xavier_uniform_(param)
@@ -41,35 +41,26 @@ class LSTMCell(nn.Module):
41
  nn.init.zeros_(param)
42
 
43
  def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
44
- """
45
- Forward pass for one time step
46
- Args:
47
- input: input at current time step [batch_size, input_size]
48
- states: tuple of (hidden_state, cell_state) from previous time step
49
- both with shape [batch_size, hidden_size]
50
- Returns:
51
- new_hidden: updated hidden state [batch_size, hidden_size]
52
- new_cell: updated cell state [batch_size, hidden_size]
53
- """
54
  hidden, cell = states
55
 
56
- # Implement LSTM cell forward pass
57
- # Forget gate: f_t = sigmoid(W_if @ x_t + W_hf @ h_{t-1} + b_f)
58
  forget_gate = torch.sigmoid(torch.mm(input, self.W_if) + torch.mm(hidden, self.W_hf) + self.b_f)
59
 
60
- # Input gate: i_t = sigmoid(W_ii @ x_t + W_hi @ h_{t-1} + b_i)
61
  input_gate = torch.sigmoid(torch.mm(input, self.W_ii) + torch.mm(hidden, self.W_hi) + self.b_i)
62
 
63
- # Input node values: n_t = tanh(W_in @ x_t + W_hn @ h_{t-1} + b_n)
64
  candidate = torch.tanh(torch.mm(input, self.W_in) + torch.mm(hidden, self.W_hn) + self.b_n)
65
 
66
- # Output gate: o_t = sigmoid(W_io @ x_t + W_ho @ h_{t-1} + b_o)
67
  output_gate = torch.sigmoid(torch.mm(input, self.W_io) + torch.mm(hidden, self.W_ho) + self.b_o)
68
 
69
- # Update cell state: c_t = f_t * c_{t-1} + i_t * n_t
70
  new_cell = forget_gate * cell + input_gate * candidate
71
 
72
- # Update hidden state: h_t = o_t * tanh(c_t)
73
  new_hidden = output_gate * torch.tanh(new_cell)
74
 
75
  return new_hidden, new_cell
@@ -87,27 +78,31 @@ class BidirectionalLSTM(nn.Module):
87
  self.batch_first = batch_first
88
  self.dropout = dropout if num_layers > 1 else 0.0
89
 
90
- # Create forward and backward cells for each layer
91
  self.forward_cells = nn.ModuleList()
92
  self.backward_cells = nn.ModuleList()
93
  self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None
94
 
95
  for layer in range(num_layers):
96
- # Input size is input_size for first layer, hidden_size * 2 for others (bidirectional)
97
  layer_input_size = input_size if layer == 0 else hidden_size * 2
98
 
 
99
  self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
100
  self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))
101
 
 
102
  if self.dropout > 0 and layer < num_layers - 1:
103
  self.dropout_layers.append(nn.Dropout(dropout))
104
 
105
  def forward(self, input, states=None, lengths=None):
106
- # Handle PackedSequence input
107
  is_packed = isinstance(input, PackedSequence)
108
  if is_packed:
 
109
  padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
110
  outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
 
111
  packed_out = pack_padded_sequence(
112
  outputs, lengths,
113
  batch_first=self.batch_first,
@@ -118,13 +113,15 @@ class BidirectionalLSTM(nn.Module):
118
  return self._forward_unpacked(input, states, lengths)
119
 
120
  def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
 
121
  if not self.batch_first:
122
  input = input.transpose(0, 1)
123
 
124
  batch_size, seq_len, _ = input.size()
125
 
126
- # Initialize states if not provided
127
  if states is None:
 
128
  h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
129
  for _ in range(self.num_layers)]
130
  c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
@@ -134,23 +131,24 @@ class BidirectionalLSTM(nn.Module):
134
  c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
135
  for _ in range(self.num_layers)]
136
  else:
 
137
  h0, c0 = states
138
- # h0 and c0 are [num_layers * 2, batch_size, hidden_size]
139
  h_t_forward = []
140
  c_t_forward = []
141
  h_t_backward = []
142
  c_t_backward = []
143
 
 
144
  for layer in range(self.num_layers):
145
  h_t_forward.append(h0[layer * 2])
146
  c_t_forward.append(c0[layer * 2])
147
  h_t_backward.append(h0[layer * 2 + 1])
148
  c_t_backward.append(c0[layer * 2 + 1])
149
 
150
- # Process through layers
151
  layer_input = input
152
  for layer_idx in range(self.num_layers):
153
- # Forward direction
154
  forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
155
  for t in range(seq_len):
156
  x = layer_input[:, t, :]
@@ -159,7 +157,7 @@ class BidirectionalLSTM(nn.Module):
159
  c_t_forward[layer_idx] = c
160
  forward_output[:, t, :] = h
161
 
162
- # Backward direction
163
  backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
164
  for t in reversed(range(seq_len)):
165
  x = layer_input[:, t, :]
@@ -168,19 +166,20 @@ class BidirectionalLSTM(nn.Module):
168
  c_t_backward[layer_idx] = c
169
  backward_output[:, t, :] = h
170
 
171
- # Concatenate forward and backward
172
  layer_output = torch.cat([forward_output, backward_output], dim=2)
173
 
174
- # Apply dropout between layers (except last layer)
175
  if self.dropout > 0 and layer_idx < self.num_layers - 1:
176
  layer_output = self.dropout_layers[layer_idx](layer_output)
177
 
 
178
  layer_input = layer_output
179
 
180
- # Final output
181
  outputs = layer_output
182
 
183
- # Stack hidden and cell states
184
  h_n = []
185
  c_n = []
186
  for layer in range(self.num_layers):
@@ -189,6 +188,7 @@ class BidirectionalLSTM(nn.Module):
189
  h_n = torch.stack(h_n, dim=0)
190
  c_n = torch.stack(c_n, dim=0)
191
 
 
192
  if not self.batch_first:
193
  outputs = outputs.transpose(0, 1)
194
 
@@ -209,7 +209,7 @@ class LSTM(nn.Module):
209
  self.hidden_size = hidden_size
210
  self.num_layers = num_layers
211
 
212
- # Embedding layer
213
  self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
214
  self.embed_dropout = nn.Dropout(dropout)
215
 
@@ -222,8 +222,8 @@ class LSTM(nn.Module):
222
  dropout=dropout if num_layers > 1 else 0.0
223
  )
224
 
225
- # Output projection layer
226
- lstm_output_size = hidden_size * 2 # bidirectional
227
  self.fc = nn.Linear(lstm_output_size, num_classes)
228
  self.output_dropout = nn.Dropout(dropout)
229
 
@@ -236,11 +236,11 @@ class LSTM(nn.Module):
236
  Returns:
237
  logits: class predictions [batch_size, seq_len, num_classes]
238
  """
239
- # Embedding
240
- embedded = self.embedding(input_ids) # [batch_size, seq_len, embed_size]
241
  embedded = self.embed_dropout(embedded)
242
 
243
- # Pack if lengths provided for efficiency
244
  if lengths is not None:
245
  packed_embedded = pack_padded_sequence(
246
  embedded, lengths.cpu(),
@@ -248,40 +248,27 @@ class LSTM(nn.Module):
248
  enforce_sorted=False
249
  )
250
  lstm_out, _ = self.lstm(packed_embedded)
 
251
  lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
252
  else:
 
253
  lstm_out, _ = self.lstm(embedded)
254
 
255
- # Apply dropout and project to output
256
  lstm_out = self.output_dropout(lstm_out)
257
- logits = self.fc(lstm_out) # [batch_size, seq_len, num_classes]
258
 
259
  return logits
260
 
261
  def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
262
  num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
263
  dropout: float = 0.1, max_len: int = 512):
264
- """
265
- Create Bidirectional LSTM model for PII detection
266
- Note: num_heads and d_ff are ignored (kept for compatibility with transformer interface)
267
-
268
- Args:
269
- vocab_size: size of vocabulary
270
- num_classes: number of output classes (PII tags)
271
- d_model: hidden dimension size
272
- num_heads: ignored (for compatibility)
273
- d_ff: ignored (for compatibility)
274
- num_layers: number of LSTM layers
275
- dropout: dropout rate
276
- max_len: maximum sequence length
277
-
278
- Returns:
279
- LSTM
280
- """
281
  return LSTM(
282
  vocab_size=vocab_size,
283
  num_classes=num_classes,
284
- embed_size=d_model // 2, # Use half of d_model as embedding size
285
  hidden_size=d_model,
286
  num_layers=num_layers,
287
  dropout=dropout,
 
12
  self.input_size = input_size
13
  self.hidden_size = hidden_size
14
 
15
+ # Weight matrices and biases for each gate
16
+ # Input gate parameters
17
  self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
18
  self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
19
  self.b_i = nn.Parameter(torch.Tensor(hidden_size))
20
 
21
+ # Forget gate parameters
22
  self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
23
  self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
24
  self.b_f = nn.Parameter(torch.Tensor(hidden_size))
25
 
26
+ # Candidate values parameters
27
  self.W_in = nn.Parameter(torch.Tensor(input_size, hidden_size))
28
  self.W_hn = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
29
  self.b_n = nn.Parameter(torch.Tensor(hidden_size))
30
 
31
+ # Output gate parameters
32
  self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
33
  self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
34
  self.b_o = nn.Parameter(torch.Tensor(hidden_size))
35
 
36
+ # Initialize weights using Xavier initialization
37
  for name, param in self.named_parameters():
38
  if 'W_' in name:
39
  nn.init.xavier_uniform_(param)
 
41
  nn.init.zeros_(param)
42
 
43
  def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
44
+ """Forward pass for one time step"""
45
+ # Unpack previous states
 
 
 
 
 
 
 
 
46
  hidden, cell = states
47
 
48
+ # Calculate forget gate - decides what to forget from previous cell state
 
49
  forget_gate = torch.sigmoid(torch.mm(input, self.W_if) + torch.mm(hidden, self.W_hf) + self.b_f)
50
 
51
+ # Calculate input gate - decides what new information to store
52
  input_gate = torch.sigmoid(torch.mm(input, self.W_ii) + torch.mm(hidden, self.W_hi) + self.b_i)
53
 
54
+ # Calculate candidate values - new information that could be added
55
  candidate = torch.tanh(torch.mm(input, self.W_in) + torch.mm(hidden, self.W_hn) + self.b_n)
56
 
57
+ # Calculate output gate - decides what parts of cell state to output
58
  output_gate = torch.sigmoid(torch.mm(input, self.W_io) + torch.mm(hidden, self.W_ho) + self.b_o)
59
 
60
+ # Update cell state by forgetting old info and adding new info
61
  new_cell = forget_gate * cell + input_gate * candidate
62
 
63
+ # Generate new hidden state based on filtered cell state
64
  new_hidden = output_gate * torch.tanh(new_cell)
65
 
66
  return new_hidden, new_cell
 
78
  self.batch_first = batch_first
79
  self.dropout = dropout if num_layers > 1 else 0.0
80
 
81
+ # Create forward and backward LSTM cells for each layer
82
  self.forward_cells = nn.ModuleList()
83
  self.backward_cells = nn.ModuleList()
84
  self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None
85
 
86
  for layer in range(num_layers):
87
+ # First layer takes input_size, others take concatenated bidirectional output
88
  layer_input_size = input_size if layer == 0 else hidden_size * 2
89
 
90
+ # Add forward and backward cells for this layer
91
  self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
92
  self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))
93
 
94
+ # Add dropout between layers (except after last layer)
95
  if self.dropout > 0 and layer < num_layers - 1:
96
  self.dropout_layers.append(nn.Dropout(dropout))
97
 
98
  def forward(self, input, states=None, lengths=None):
99
+ # Check if input is packed sequence
100
  is_packed = isinstance(input, PackedSequence)
101
  if is_packed:
102
+ # Unpack for processing
103
  padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
104
  outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
105
+ # Pack output back
106
  packed_out = pack_padded_sequence(
107
  outputs, lengths,
108
  batch_first=self.batch_first,
 
113
  return self._forward_unpacked(input, states, lengths)
114
 
115
  def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
116
+ # Convert to batch-first if needed
117
  if not self.batch_first:
118
  input = input.transpose(0, 1)
119
 
120
  batch_size, seq_len, _ = input.size()
121
 
122
+ # Initialize hidden and cell states if not provided
123
  if states is None:
124
+ # Create zero states for each layer and direction
125
  h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
126
  for _ in range(self.num_layers)]
127
  c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
 
131
  c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
132
  for _ in range(self.num_layers)]
133
  else:
134
+ # Unpack provided states
135
  h0, c0 = states
 
136
  h_t_forward = []
137
  c_t_forward = []
138
  h_t_backward = []
139
  c_t_backward = []
140
 
141
+ # Separate forward and backward states for each layer
142
  for layer in range(self.num_layers):
143
  h_t_forward.append(h0[layer * 2])
144
  c_t_forward.append(c0[layer * 2])
145
  h_t_backward.append(h0[layer * 2 + 1])
146
  c_t_backward.append(c0[layer * 2 + 1])
147
 
148
+ # Process through each layer
149
  layer_input = input
150
  for layer_idx in range(self.num_layers):
151
+ # Process forward direction
152
  forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
153
  for t in range(seq_len):
154
  x = layer_input[:, t, :]
 
157
  c_t_forward[layer_idx] = c
158
  forward_output[:, t, :] = h
159
 
160
+ # Process backward direction
161
  backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
162
  for t in reversed(range(seq_len)):
163
  x = layer_input[:, t, :]
 
166
  c_t_backward[layer_idx] = c
167
  backward_output[:, t, :] = h
168
 
169
+ # Concatenate forward and backward outputs
170
  layer_output = torch.cat([forward_output, backward_output], dim=2)
171
 
172
+ # Apply dropout between layers
173
  if self.dropout > 0 and layer_idx < self.num_layers - 1:
174
  layer_output = self.dropout_layers[layer_idx](layer_output)
175
 
176
+ # Use this layer's output as next layer's input
177
  layer_input = layer_output
178
 
179
+ # Final output is the last layer's output
180
  outputs = layer_output
181
 
182
+ # Stack final hidden and cell states for all layers
183
  h_n = []
184
  c_n = []
185
  for layer in range(self.num_layers):
 
188
  h_n = torch.stack(h_n, dim=0)
189
  c_n = torch.stack(c_n, dim=0)
190
 
191
+ # Convert back if not batch-first
192
  if not self.batch_first:
193
  outputs = outputs.transpose(0, 1)
194
 
 
209
  self.hidden_size = hidden_size
210
  self.num_layers = num_layers
211
 
212
+ # Embedding layer to convert tokens to vectors
213
  self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
214
  self.embed_dropout = nn.Dropout(dropout)
215
 
 
222
  dropout=dropout if num_layers > 1 else 0.0
223
  )
224
 
225
+ # Output layer to predict PII labels
226
+ lstm_output_size = hidden_size * 2 # doubled for bidirectional
227
  self.fc = nn.Linear(lstm_output_size, num_classes)
228
  self.output_dropout = nn.Dropout(dropout)
229
 
 
236
  Returns:
237
  logits: class predictions [batch_size, seq_len, num_classes]
238
  """
239
+ # Convert token ids to embeddings
240
+ embedded = self.embedding(input_ids)
241
  embedded = self.embed_dropout(embedded)
242
 
243
+ # Pack sequences for efficient processing if lengths provided
244
  if lengths is not None:
245
  packed_embedded = pack_padded_sequence(
246
  embedded, lengths.cpu(),
 
248
  enforce_sorted=False
249
  )
250
  lstm_out, _ = self.lstm(packed_embedded)
251
+ # Unpack the output
252
  lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
253
  else:
254
+ # Process without packing
255
  lstm_out, _ = self.lstm(embedded)
256
 
257
+ # Apply dropout and get final predictions
258
  lstm_out = self.output_dropout(lstm_out)
259
+ logits = self.fc(lstm_out)
260
 
261
  return logits
262
 
263
  def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
264
  num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
265
  dropout: float = 0.1, max_len: int = 512):
266
+ """Create Bidirectional LSTM model for PII detection"""
267
+ # Create LSTM with appropriate dimensions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  return LSTM(
269
  vocab_size=vocab_size,
270
  num_classes=num_classes,
271
+ embed_size=d_model // 2,
272
  hidden_size=d_model,
273
  num_layers=num_layers,
274
  dropout=dropout,
lstm_training.ipynb CHANGED
@@ -42,7 +42,7 @@
42
  },
43
  {
44
  "cell_type": "code",
45
- "execution_count": 4,
46
  "id": "1207cd93",
47
  "metadata": {
48
  "execution": {
@@ -62,19 +62,23 @@
62
  },
63
  "outputs": [],
64
  "source": [
 
65
  "class Vocabulary:\n",
66
  " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
67
  " def __init__(self, max_size=100000):\n",
 
68
  " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
69
  " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
70
  " self.word_count = Counter()\n",
71
  " self.max_size = max_size\n",
72
  " \n",
73
  " def add_sentence(self, sentence):\n",
 
74
  " for word in sentence:\n",
75
  " self.word_count[word.lower()] += 1\n",
76
  " \n",
77
  " def build(self):\n",
 
78
  " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
79
  " for word, _ in most_common:\n",
80
  " if word not in self.word2idx:\n",
@@ -86,15 +90,17 @@
86
  " return len(self.word2idx)\n",
87
  " \n",
88
  " def encode(self, sentence):\n",
 
89
  " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
90
  " \n",
91
  " def decode(self, indices):\n",
 
92
  " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
93
  ]
94
  },
95
  {
96
  "cell_type": "code",
97
- "execution_count": 5,
98
  "id": "f4056292",
99
  "metadata": {
100
  "execution": {
@@ -114,6 +120,7 @@
114
  },
115
  "outputs": [],
116
  "source": [
 
117
  "class PIIDataset(Dataset):\n",
118
  " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
119
  " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
@@ -127,16 +134,16 @@
127
  " return len(self.tokens)\n",
128
  " \n",
129
  " def __getitem__(self, idx):\n",
130
- " # Add start and end tokens\n",
131
  " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
132
  " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
133
  " \n",
134
- " # Truncate if too long\n",
135
  " if len(tokens) > self.max_len:\n",
136
  " tokens = tokens[:self.max_len-1] + ['<end>']\n",
137
  " labels = labels[:self.max_len-1] + ['<end>']\n",
138
  " \n",
139
- " # Encode\n",
140
  " token_ids = self.text_vocab.encode(tokens)\n",
141
  " label_ids = self.label_vocab.encode(labels)\n",
142
  " \n",
@@ -145,7 +152,7 @@
145
  },
146
  {
147
  "cell_type": "code",
148
- "execution_count": 6,
149
  "id": "499deba2",
150
  "metadata": {
151
  "execution": {
@@ -167,7 +174,9 @@
167
  "source": [
168
  "def collate_fn(batch):\n",
169
  " \"\"\"Custom collate function for padding sequences\"\"\"\n",
 
170
  " tokens, labels = zip(*batch)\n",
 
171
  " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
172
  " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
173
  " return tokens_padded, labels_padded"
@@ -175,7 +184,7 @@
175
  },
176
  {
177
  "cell_type": "code",
178
- "execution_count": 7,
179
  "id": "7ade0505",
180
  "metadata": {
181
  "execution": {
@@ -195,6 +204,7 @@
195
  },
196
  "outputs": [],
197
  "source": [
 
198
  "class F1ScoreMetric:\n",
199
  " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
200
  " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
@@ -205,26 +215,32 @@
205
  " self.reset()\n",
206
  " \n",
207
  " def reset(self):\n",
 
208
  " self.true_positives = 0\n",
209
  " self.false_positives = 0\n",
210
  " self.false_negatives = 0\n",
211
  " self.class_metrics = {}\n",
212
  " \n",
213
  " def update(self, predictions, targets):\n",
 
214
  " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
215
  " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
216
  " \n",
 
217
  " for class_id in range(1, self.num_classes):\n",
218
  " if class_id == o_idx:\n",
219
  " continue\n",
220
- " \n",
 
221
  " pred_mask = (predictions == class_id) & mask\n",
222
  " true_mask = (targets == class_id) & mask\n",
223
  " \n",
 
224
  " tp = ((pred_mask) & (true_mask)).sum().item()\n",
225
  " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
226
  " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
227
  " \n",
 
228
  " self.true_positives += tp\n",
229
  " self.false_positives += fp\n",
230
  " self.false_negatives += fn\n",
@@ -236,6 +252,7 @@
236
  " self.class_metrics[class_id]['fn'] += fn\n",
237
  " \n",
238
  " def compute(self):\n",
 
239
  " beta_squared = self.beta ** 2\n",
240
  " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
241
  " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
@@ -243,6 +260,7 @@
243
  " return f1\n",
244
  " \n",
245
  " def get_class_metrics(self):\n",
 
246
  " results = {}\n",
247
  " for class_id, metrics in self.class_metrics.items():\n",
248
  " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
@@ -261,7 +279,7 @@
261
  },
262
  {
263
  "cell_type": "code",
264
- "execution_count": 8,
265
  "id": "361b5505",
266
  "metadata": {
267
  "execution": {
@@ -281,6 +299,7 @@
281
  },
282
  "outputs": [],
283
  "source": [
 
284
  "class FocalLoss(nn.Module):\n",
285
  " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
286
  " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
@@ -291,6 +310,7 @@
291
  " self.ignore_index = ignore_index\n",
292
  " \n",
293
  " def forward(self, inputs, targets):\n",
 
294
  " ce_loss = nn.functional.cross_entropy(\n",
295
  " inputs, targets, \n",
296
  " weight=self.alpha, \n",
@@ -298,9 +318,11 @@
298
  " ignore_index=self.ignore_index\n",
299
  " )\n",
300
  " \n",
 
301
  " pt = torch.exp(-ce_loss)\n",
302
  " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
303
  " \n",
 
304
  " if self.reduction == 'mean':\n",
305
  " return focal_loss.mean()\n",
306
  " elif self.reduction == 'sum':\n",
@@ -311,7 +333,7 @@
311
  },
312
  {
313
  "cell_type": "code",
314
- "execution_count": 9,
315
  "id": "1de646e9",
316
  "metadata": {
317
  "execution": {
@@ -337,8 +359,10 @@
337
  " total_loss = 0\n",
338
  " f1_metric.reset()\n",
339
  " \n",
 
340
  " progress_bar = tqdm(dataloader, desc='Training')\n",
341
  " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
 
342
  " tokens = tokens.to(device)\n",
343
  " labels = labels.to(device)\n",
344
  " \n",
@@ -353,6 +377,7 @@
353
  " # Calculate loss and backward pass\n",
354
  " loss = criterion(outputs_flat, labels_flat)\n",
355
  " loss.backward()\n",
 
356
  " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
357
  " optimizer.step()\n",
358
  " \n",
@@ -372,7 +397,7 @@
372
  },
373
  {
374
  "cell_type": "code",
375
- "execution_count": 10,
376
  "id": "d1ce3b0f",
377
  "metadata": {
378
  "execution": {
@@ -398,8 +423,10 @@
398
  " total_loss = 0\n",
399
  " f1_metric.reset()\n",
400
  " \n",
 
401
  " with torch.no_grad():\n",
402
  " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
 
403
  " tokens = tokens.to(device)\n",
404
  " labels = labels.to(device)\n",
405
  " \n",
@@ -421,7 +448,7 @@
421
  },
422
  {
423
  "cell_type": "code",
424
- "execution_count": 11,
425
  "id": "da3ff80c",
426
  "metadata": {
427
  "execution": {
@@ -445,6 +472,7 @@
445
  " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
446
  " sample_weights = []\n",
447
  " \n",
 
448
  " for idx in range(len(dataset)):\n",
449
  " _, labels = dataset[idx]\n",
450
  " \n",
@@ -453,12 +481,14 @@
453
  " for label_id in labels:\n",
454
  " if label_id > 3: # Skip special tokens\n",
455
  " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
 
456
  " if label_name != 'o' and 'B-' in label_name:\n",
457
  " min_weight = 10.0\n",
458
  " break\n",
459
  " \n",
460
  " sample_weights.append(min_weight)\n",
461
  " \n",
 
462
  " sampler = WeightedRandomSampler(\n",
463
  " weights=sample_weights,\n",
464
  " num_samples=len(sample_weights),\n",
@@ -470,7 +500,7 @@
470
  },
471
  {
472
  "cell_type": "code",
473
- "execution_count": 12,
474
  "id": "69b37e68",
475
  "metadata": {
476
  "execution": {
@@ -493,11 +523,14 @@
493
  "def print_label_distribution(data, title=\"Label Distribution\"):\n",
494
  " \"\"\"Print label distribution statistics\"\"\"\n",
495
  " label_counts = Counter()\n",
 
 
496
  " for label_seq in data.labels:\n",
497
  " for label in label_seq:\n",
498
  " if label not in ['<pad>', '<start>', '<end>']:\n",
499
  " label_counts[label] += 1\n",
500
  " \n",
 
501
  " print(f\"\\n{title}:\")\n",
502
  " print(\"-\" * 50)\n",
503
  " total = sum(label_counts.values())\n",
@@ -510,7 +543,7 @@
510
  },
511
  {
512
  "cell_type": "code",
513
- "execution_count": 13,
514
  "id": "4b1b4f86",
515
  "metadata": {
516
  "execution": {
@@ -534,7 +567,7 @@
534
  " \"\"\"Save model and all necessary components for deployment\"\"\"\n",
535
  " os.makedirs(save_dir, exist_ok=True)\n",
536
  " \n",
537
- " # Save model state\n",
538
  " model_path = os.path.join(save_dir, 'pii_lstm_model.pt')\n",
539
  " torch.save(model.state_dict(), model_path)\n",
540
  " \n",
@@ -560,7 +593,7 @@
560
  },
561
  {
562
  "cell_type": "code",
563
- "execution_count": 14,
564
  "id": "31d2f1b1",
565
  "metadata": {
566
  "execution": {
@@ -596,7 +629,7 @@
596
  " data = pd.read_json(data_path, lines=True)\n",
597
  " print(f\"Total samples: {len(data)}\")\n",
598
  " \n",
599
- " # Print initial label distribution\n",
600
  " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
601
  " \n",
602
  " # Build vocabularies\n",
@@ -609,6 +642,7 @@
609
  " for labels in data.labels:\n",
610
  " label_vocab.add_sentence(labels)\n",
611
  " \n",
 
612
  " text_vocab.build()\n",
613
  " label_vocab.build()\n",
614
  " \n",
@@ -616,11 +650,11 @@
616
  " print(f\" - Text vocabulary: {len(text_vocab):,}\")\n",
617
  " print(f\" - Label vocabulary: {len(label_vocab)}\")\n",
618
  " \n",
619
- " # Calculate class weights\n",
620
  " class_weights = calculate_class_weights(data, label_vocab)\n",
621
  " class_weights = class_weights.to(device)\n",
622
  " \n",
623
- " # Split data\n",
624
  " X_train, X_val, y_train, y_val = train_test_split(\n",
625
  " data.tokens.tolist(),\n",
626
  " data.labels.tolist(),\n",
@@ -632,14 +666,15 @@
632
  " print(f\" - Train samples: {len(X_train):,}\")\n",
633
  " print(f\" - Validation samples: {len(X_val):,}\")\n",
634
  " \n",
635
- " # Create datasets and dataloaders\n",
636
  " max_seq_len = 512\n",
637
  " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
638
  " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
639
  " \n",
640
- " # Use balanced sampler for training\n",
641
  " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
642
  " \n",
 
643
  " train_loader = DataLoader(\n",
644
  " train_dataset, \n",
645
  " batch_size=batch_size,\n",
@@ -668,7 +703,7 @@
668
  " 'max_len': max_seq_len\n",
669
  " }\n",
670
  " \n",
671
- " # Create model\n",
672
  " print(\"\\nCreating LSTM model...\")\n",
673
  " model = create_lstm_pii_model(**model_config).to(device)\n",
674
  " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
@@ -701,11 +736,11 @@
701
  " min_lr=1e-6\n",
702
  " )\n",
703
  " \n",
704
- " # Metrics\n",
705
  " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
706
  " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
707
  " \n",
708
- " # Training loop\n",
709
  " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
710
  " best_val_f1 = 0\n",
711
  " patience = 7\n",
@@ -714,18 +749,20 @@
714
  " print(\"\\nStarting training...\")\n",
715
  " print(\"=\" * 60)\n",
716
  " \n",
 
717
  " for epoch in range(num_epochs):\n",
718
  " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
719
  " \n",
720
- " # Train and validate\n",
721
  " train_loss, train_f1 = train_epoch(\n",
722
  " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
723
  " )\n",
 
724
  " val_loss, val_f1 = evaluate(\n",
725
  " model, val_loader, criterion, device, f1_metric_val\n",
726
  " )\n",
727
  " \n",
728
- " # Step scheduler based on validation loss\n",
729
  " scheduler.step(val_loss)\n",
730
  " \n",
731
  " # Store metrics\n",
@@ -1282,9 +1319,11 @@
1282
  }
1283
  ],
1284
  "source": [
 
1285
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1286
  "print(f\"Using device: {device}\")\n",
1287
  "\n",
 
1288
  "model, text_vocab, label_vocab = train_lstm_pii_model(\n",
1289
  " data_path='train_augmented.json',\n",
1290
  " num_epochs=20,\n",
 
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": null,
46
  "id": "1207cd93",
47
  "metadata": {
48
  "execution": {
 
62
  },
63
  "outputs": [],
64
  "source": [
65
+ "# Define vocabulary class for text and label encoding\n",
66
  "class Vocabulary:\n",
67
  " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
68
  " def __init__(self, max_size=100000):\n",
69
+ " # Initialize special tokens\n",
70
  " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
71
  " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
72
  " self.word_count = Counter()\n",
73
  " self.max_size = max_size\n",
74
  " \n",
75
  " def add_sentence(self, sentence):\n",
76
+ " # Count word frequencies\n",
77
  " for word in sentence:\n",
78
  " self.word_count[word.lower()] += 1\n",
79
  " \n",
80
  " def build(self):\n",
81
+ " # Build vocabulary from most common words\n",
82
  " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
83
  " for word, _ in most_common:\n",
84
  " if word not in self.word2idx:\n",
 
90
  " return len(self.word2idx)\n",
91
  " \n",
92
  " def encode(self, sentence):\n",
93
+ " # Convert words to indices\n",
94
  " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
95
  " \n",
96
  " def decode(self, indices):\n",
97
+ " # Convert indices back to words\n",
98
  " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
99
  ]
100
  },
101
  {
102
  "cell_type": "code",
103
+ "execution_count": null,
104
  "id": "f4056292",
105
  "metadata": {
106
  "execution": {
 
120
  },
121
  "outputs": [],
122
  "source": [
123
+ "# Dataset class for PII detection\n",
124
  "class PIIDataset(Dataset):\n",
125
  " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
126
  " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
 
134
  " return len(self.tokens)\n",
135
  " \n",
136
  " def __getitem__(self, idx):\n",
137
+ " # Add special tokens to beginning and end\n",
138
  " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
139
  " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
140
  " \n",
141
+ " # Truncate if sequence is too long\n",
142
  " if len(tokens) > self.max_len:\n",
143
  " tokens = tokens[:self.max_len-1] + ['<end>']\n",
144
  " labels = labels[:self.max_len-1] + ['<end>']\n",
145
  " \n",
146
+ " # Encode tokens and labels to indices\n",
147
  " token_ids = self.text_vocab.encode(tokens)\n",
148
  " label_ids = self.label_vocab.encode(labels)\n",
149
  " \n",
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": null,
156
  "id": "499deba2",
157
  "metadata": {
158
  "execution": {
 
174
  "source": [
175
  "def collate_fn(batch):\n",
176
  " \"\"\"Custom collate function for padding sequences\"\"\"\n",
177
+ " # Separate tokens and labels\n",
178
  " tokens, labels = zip(*batch)\n",
179
+ " # Pad sequences to same length\n",
180
  " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
181
  " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
182
  " return tokens_padded, labels_padded"
 
184
  },
185
  {
186
  "cell_type": "code",
187
+ "execution_count": null,
188
  "id": "7ade0505",
189
  "metadata": {
190
  "execution": {
 
204
  },
205
  "outputs": [],
206
  "source": [
207
+ "# F1 score metric for evaluation\n",
208
  "class F1ScoreMetric:\n",
209
  " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
210
  " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
 
215
  " self.reset()\n",
216
  " \n",
217
  " def reset(self):\n",
218
+ " # Reset counters\n",
219
  " self.true_positives = 0\n",
220
  " self.false_positives = 0\n",
221
  " self.false_negatives = 0\n",
222
  " self.class_metrics = {}\n",
223
  " \n",
224
  " def update(self, predictions, targets):\n",
225
+ " # Create mask to ignore padding and special tokens\n",
226
  " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
227
  " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
228
  " \n",
229
+ " # Calculate metrics for each class\n",
230
  " for class_id in range(1, self.num_classes):\n",
231
  " if class_id == o_idx:\n",
232
  " continue\n",
233
+ " \n",
234
+ " # Find where predictions and targets match this class\n",
235
  " pred_mask = (predictions == class_id) & mask\n",
236
  " true_mask = (targets == class_id) & mask\n",
237
  " \n",
238
+ " # Count true positives, false positives, false negatives\n",
239
  " tp = ((pred_mask) & (true_mask)).sum().item()\n",
240
  " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
241
  " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
242
  " \n",
243
+ " # Update total counts\n",
244
  " self.true_positives += tp\n",
245
  " self.false_positives += fp\n",
246
  " self.false_negatives += fn\n",
 
252
  " self.class_metrics[class_id]['fn'] += fn\n",
253
  " \n",
254
  " def compute(self):\n",
255
+ " # Calculate F-beta score\n",
256
  " beta_squared = self.beta ** 2\n",
257
  " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
258
  " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
 
260
  " return f1\n",
261
  " \n",
262
  " def get_class_metrics(self):\n",
263
+ " # Get metrics for each class\n",
264
  " results = {}\n",
265
  " for class_id, metrics in self.class_metrics.items():\n",
266
  " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
 
279
  },
280
  {
281
  "cell_type": "code",
282
+ "execution_count": null,
283
  "id": "361b5505",
284
  "metadata": {
285
  "execution": {
 
299
  },
300
  "outputs": [],
301
  "source": [
302
+ "# Focal loss for handling class imbalance\n",
303
  "class FocalLoss(nn.Module):\n",
304
  " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
305
  " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
 
310
  " self.ignore_index = ignore_index\n",
311
  " \n",
312
  " def forward(self, inputs, targets):\n",
313
+ " # Calculate cross entropy loss\n",
314
  " ce_loss = nn.functional.cross_entropy(\n",
315
  " inputs, targets, \n",
316
  " weight=self.alpha, \n",
 
318
  " ignore_index=self.ignore_index\n",
319
  " )\n",
320
  " \n",
321
+ " # Apply focal term to focus on hard examples\n",
322
  " pt = torch.exp(-ce_loss)\n",
323
  " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
324
  " \n",
325
+ " # Reduce loss based on specified method\n",
326
  " if self.reduction == 'mean':\n",
327
  " return focal_loss.mean()\n",
328
  " elif self.reduction == 'sum':\n",
 
333
  },
334
  {
335
  "cell_type": "code",
336
+ "execution_count": null,
337
  "id": "1de646e9",
338
  "metadata": {
339
  "execution": {
 
359
  " total_loss = 0\n",
360
  " f1_metric.reset()\n",
361
  " \n",
362
+ " # Progress bar for training\n",
363
  " progress_bar = tqdm(dataloader, desc='Training')\n",
364
  " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
365
+ " # Move data to device\n",
366
  " tokens = tokens.to(device)\n",
367
  " labels = labels.to(device)\n",
368
  " \n",
 
377
  " # Calculate loss and backward pass\n",
378
  " loss = criterion(outputs_flat, labels_flat)\n",
379
  " loss.backward()\n",
380
+ " # Clip gradients to prevent exploding gradients\n",
381
  " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
382
  " optimizer.step()\n",
383
  " \n",
 
397
  },
398
  {
399
  "cell_type": "code",
400
+ "execution_count": null,
401
  "id": "d1ce3b0f",
402
  "metadata": {
403
  "execution": {
 
423
  " total_loss = 0\n",
424
  " f1_metric.reset()\n",
425
  " \n",
426
+ " # No gradient computation during evaluation\n",
427
  " with torch.no_grad():\n",
428
  " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
429
+ " # Move data to device\n",
430
  " tokens = tokens.to(device)\n",
431
  " labels = labels.to(device)\n",
432
  " \n",
 
448
  },
449
  {
450
  "cell_type": "code",
451
+ "execution_count": null,
452
  "id": "da3ff80c",
453
  "metadata": {
454
  "execution": {
 
472
  " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
473
  " sample_weights = []\n",
474
  " \n",
475
+ " # Calculate weight for each sample\n",
476
  " for idx in range(len(dataset)):\n",
477
  " _, labels = dataset[idx]\n",
478
  " \n",
 
481
  " for label_id in labels:\n",
482
  " if label_id > 3: # Skip special tokens\n",
483
  " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
484
+ " # If sample contains PII, give it higher weight\n",
485
  " if label_name != 'o' and 'B-' in label_name:\n",
486
  " min_weight = 10.0\n",
487
  " break\n",
488
  " \n",
489
  " sample_weights.append(min_weight)\n",
490
  " \n",
491
+ " # Create weighted sampler\n",
492
  " sampler = WeightedRandomSampler(\n",
493
  " weights=sample_weights,\n",
494
  " num_samples=len(sample_weights),\n",
 
500
  },
501
  {
502
  "cell_type": "code",
503
+ "execution_count": null,
504
  "id": "69b37e68",
505
  "metadata": {
506
  "execution": {
 
523
  "def print_label_distribution(data, title=\"Label Distribution\"):\n",
524
  " \"\"\"Print label distribution statistics\"\"\"\n",
525
  " label_counts = Counter()\n",
526
+ "\n",
527
+ " # Count each label type\n",
528
  " for label_seq in data.labels:\n",
529
  " for label in label_seq:\n",
530
  " if label not in ['<pad>', '<start>', '<end>']:\n",
531
  " label_counts[label] += 1\n",
532
  " \n",
533
+ " # Print distribution\n",
534
  " print(f\"\\n{title}:\")\n",
535
  " print(\"-\" * 50)\n",
536
  " total = sum(label_counts.values())\n",
 
543
  },
544
  {
545
  "cell_type": "code",
546
+ "execution_count": null,
547
  "id": "4b1b4f86",
548
  "metadata": {
549
  "execution": {
 
567
  " \"\"\"Save model and all necessary components for deployment\"\"\"\n",
568
  " os.makedirs(save_dir, exist_ok=True)\n",
569
  " \n",
570
+ " # Save model weights\n",
571
  " model_path = os.path.join(save_dir, 'pii_lstm_model.pt')\n",
572
  " torch.save(model.state_dict(), model_path)\n",
573
  " \n",
 
593
  },
594
  {
595
  "cell_type": "code",
596
+ "execution_count": null,
597
  "id": "31d2f1b1",
598
  "metadata": {
599
  "execution": {
 
629
  " data = pd.read_json(data_path, lines=True)\n",
630
  " print(f\"Total samples: {len(data)}\")\n",
631
  " \n",
632
+ " # Show label distribution\n",
633
  " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
634
  " \n",
635
  " # Build vocabularies\n",
 
642
  " for labels in data.labels:\n",
643
  " label_vocab.add_sentence(labels)\n",
644
  " \n",
645
+ " # Build vocabularies from collected words\n",
646
  " text_vocab.build()\n",
647
  " label_vocab.build()\n",
648
  " \n",
 
650
  " print(f\" - Text vocabulary: {len(text_vocab):,}\")\n",
651
  " print(f\" - Label vocabulary: {len(label_vocab)}\")\n",
652
  " \n",
653
+ " # Calculate class weights for balanced loss\n",
654
  " class_weights = calculate_class_weights(data, label_vocab)\n",
655
  " class_weights = class_weights.to(device)\n",
656
  " \n",
657
+ " # Split data into train and validation sets\n",
658
  " X_train, X_val, y_train, y_val = train_test_split(\n",
659
  " data.tokens.tolist(),\n",
660
  " data.labels.tolist(),\n",
 
666
  " print(f\" - Train samples: {len(X_train):,}\")\n",
667
  " print(f\" - Validation samples: {len(X_val):,}\")\n",
668
  " \n",
669
+ " # Create datasets\n",
670
  " max_seq_len = 512\n",
671
  " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
672
  " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
673
  " \n",
674
+ " # Create balanced sampler for training\n",
675
  " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
676
  " \n",
677
+ " # Create data loaders\n",
678
  " train_loader = DataLoader(\n",
679
  " train_dataset, \n",
680
  " batch_size=batch_size,\n",
 
703
  " 'max_len': max_seq_len\n",
704
  " }\n",
705
  " \n",
706
+ " # Create LSTM model\n",
707
  " print(\"\\nCreating LSTM model...\")\n",
708
  " model = create_lstm_pii_model(**model_config).to(device)\n",
709
  " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
 
736
  " min_lr=1e-6\n",
737
  " )\n",
738
  " \n",
739
+ " # Initialize metrics\n",
740
  " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
741
  " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
742
  " \n",
743
+ " # Training history\n",
744
  " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
745
  " best_val_f1 = 0\n",
746
  " patience = 7\n",
 
749
  " print(\"\\nStarting training...\")\n",
750
  " print(\"=\" * 60)\n",
751
  " \n",
752
+ " # Training loop\n",
753
  " for epoch in range(num_epochs):\n",
754
  " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
755
  " \n",
756
+ " # Train for one epoch\n",
757
  " train_loss, train_f1 = train_epoch(\n",
758
  " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
759
  " )\n",
760
+ " # Evaluate on validation set\n",
761
  " val_loss, val_f1 = evaluate(\n",
762
  " model, val_loader, criterion, device, f1_metric_val\n",
763
  " )\n",
764
  " \n",
765
+ " # Adjust learning rate based on validation loss\n",
766
  " scheduler.step(val_loss)\n",
767
  " \n",
768
  " # Store metrics\n",
 
1319
  }
1320
  ],
1321
  "source": [
1322
+ "# Set device\n",
1323
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1324
  "print(f\"Using device: {device}\")\n",
1325
  "\n",
1326
+ "# Train the LSTM model\n",
1327
  "model, text_vocab, label_vocab = train_lstm_pii_model(\n",
1328
  " data_path='train_augmented.json',\n",
1329
  " num_epochs=20,\n",
transformer.py CHANGED
@@ -4,37 +4,25 @@ import torch.nn.functional as F
4
  import math
5
 
6
  def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
7
- """
8
- Compute scaled dot-product attention.
9
-
10
- Args:
11
- q: queries (batch_size, num_heads, seq_len_q, d_k)
12
- k: keys (batch_size, num_heads, seq_len_k, d_k)
13
- v: values (batch_size, num_heads, seq_len_v, d_v)
14
- mask: mask tensor (batch_size, 1, 1, seq_len_k) or (batch_size, 1, seq_len_q, seq_len_k)
15
- dropout: dropout layer
16
-
17
- Returns:
18
- output: attended values (batch_size, num_heads, seq_len_q, d_v)
19
- attention_weights: attention weights (batch_size, num_heads, seq_len_q, seq_len_k)
20
- """
21
  d_k = q.size(-1)
22
 
23
- # Calculate attention scores
24
  scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
25
 
26
- # Apply mask if provided
27
  if mask is not None:
28
  scores = scores.masked_fill(mask == 0, float('-inf'))
29
 
30
- # Apply softmax
31
  attention_weights = F.softmax(scores, dim=-1)
32
 
33
- # Apply dropout if provided
34
  if dropout is not None:
35
  attention_weights = dropout(attention_weights)
36
 
37
- # Apply attention to values
38
  output = torch.matmul(attention_weights, v)
39
 
40
  return output, attention_weights
@@ -48,47 +36,42 @@ class MultiHeadAttention(nn.Module):
48
 
49
  self.d_model = d_model
50
  self.num_heads = num_heads
51
- self.d_k = d_model // num_heads
52
 
53
- # Linear projections for Q, K, V
54
  self.w_q = nn.Linear(d_model, d_model)
55
  self.w_k = nn.Linear(d_model, d_model)
56
  self.w_v = nn.Linear(d_model, d_model)
57
 
58
- # Output projection
59
  self.w_o = nn.Linear(d_model, d_model)
60
 
61
- # Dropout
62
  self.dropout = nn.Dropout(dropout)
63
 
64
  def forward(self, query, key, value, mask=None):
65
  """
66
- Args:
67
- query: (batch_size, seq_len_q, d_model)
68
- key: (batch_size, seq_len_k, d_model)
69
- value: (batch_size, seq_len_v, d_model)
70
- mask: (batch_size, 1, 1, seq_len_k) or None
71
-
72
- Returns:
73
- output: (batch_size, seq_len_q, d_model)
74
- attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
75
  """
76
  batch_size = query.size(0)
77
  seq_len_q = query.size(1)
78
  seq_len_k = key.size(1)
79
  seq_len_v = value.size(1)
80
 
81
- # 1. Linear projections in batch from d_model => h x d_k
82
  Q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
83
  K = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
84
  V = self.w_v(value).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
85
 
86
- # 2. Apply attention on all the projected vectors in batch
87
  attention_output, attention_weights = scaled_dot_product_attention(
88
  Q, K, V, mask=mask, dropout=self.dropout
89
  )
90
 
91
- # 3. Concatenate heads and put through final linear layer
92
  attention_output = attention_output.transpose(1, 2).contiguous().view(
93
  batch_size, seq_len_q, self.d_model
94
  )
@@ -102,19 +85,14 @@ class PositionwiseFeedForward(nn.Module):
102
 
103
  def __init__(self, d_model, d_ff, dropout=0.1):
104
  super(PositionwiseFeedForward, self).__init__()
 
105
  self.w_1 = nn.Linear(d_model, d_ff)
106
  self.w_2 = nn.Linear(d_ff, d_model)
107
  self.dropout = nn.Dropout(dropout)
108
  self.activation = nn.ReLU()
109
 
110
  def forward(self, x):
111
- """
112
- Args:
113
- x: (batch_size, seq_len, d_model)
114
-
115
- Returns:
116
- output: (batch_size, seq_len, d_model)
117
- """
118
  return self.w_2(self.dropout(self.activation(self.w_1(x))))
119
 
120
  class EncoderLayer(nn.Module):
@@ -123,33 +101,25 @@ class EncoderLayer(nn.Module):
123
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
124
  super(EncoderLayer, self).__init__()
125
 
126
- # Multi-head attention
127
  self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
128
 
129
- # Position-wise feed forward
130
  self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
131
 
132
- # Layer normalization
133
  self.norm1 = nn.LayerNorm(d_model)
134
  self.norm2 = nn.LayerNorm(d_model)
135
 
136
- # Dropout
137
  self.dropout = nn.Dropout(dropout)
138
 
139
  def forward(self, x, mask=None):
140
- """
141
- Args:
142
- x: (batch_size, seq_len, d_model)
143
- mask: (batch_size, 1, 1, seq_len) or None
144
-
145
- Returns:
146
- output: (batch_size, seq_len, d_model)
147
- """
148
- # Self-attention with residual connection and layer norm
149
  attn_output, _ = self.self_attention(x, x, x, mask)
150
  x = self.norm1(x + self.dropout(attn_output))
151
 
152
- # Feed forward with residual connection and layer norm
153
  ff_output = self.feed_forward(x)
154
  x = self.norm2(x + self.dropout(ff_output))
155
 
@@ -161,25 +131,21 @@ class TransformerEncoder(nn.Module):
161
  def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
162
  super(TransformerEncoder, self).__init__()
163
 
 
164
  self.layers = nn.ModuleList([
165
  EncoderLayer(d_model, num_heads, d_ff, dropout)
166
  for _ in range(num_layers)
167
  ])
168
 
 
169
  self.norm = nn.LayerNorm(d_model)
170
 
171
  def forward(self, x, mask=None):
172
- """
173
- Args:
174
- x: (batch_size, seq_len, d_model)
175
- mask: (batch_size, 1, 1, seq_len) or None
176
-
177
- Returns:
178
- output: (batch_size, seq_len, d_model)
179
- """
180
  for layer in self.layers:
181
  x = layer(x, mask)
182
 
 
183
  return self.norm(x)
184
 
185
  class PositionalEncoding(nn.Module):
@@ -189,36 +155,29 @@ class PositionalEncoding(nn.Module):
189
  super(PositionalEncoding, self).__init__()
190
  self.dropout = nn.Dropout(dropout)
191
 
192
- # Create positional encoding matrix
193
  pe = torch.zeros(max_len, d_model)
194
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
195
 
196
- # Create div_term for sin/cos frequencies
197
  div_term = torch.exp(torch.arange(0, d_model, 2).float() *
198
  (-math.log(10000.0) / d_model))
199
 
200
- # Apply sin to even indices
201
  pe[:, 0::2] = torch.sin(position * div_term)
202
 
203
- # Apply cos to odd indices
204
  if d_model % 2 == 1:
205
  pe[:, 1::2] = torch.cos(position * div_term[:-1])
206
  else:
207
  pe[:, 1::2] = torch.cos(position * div_term)
208
 
209
- # Add batch dimension and register as buffer
210
  pe = pe.unsqueeze(0)
211
  self.register_buffer('pe', pe)
212
 
213
  def forward(self, x):
214
- """
215
- Args:
216
- x: (batch_size, seq_len, d_model)
217
-
218
- Returns:
219
- output: (batch_size, seq_len, d_model)
220
- """
221
- # Add positional encoding
222
  x = x + self.pe[:, :x.size(1), :]
223
  return self.dropout(x)
224
 
@@ -235,62 +194,46 @@ class TransformerPII(nn.Module):
235
  self.d_model = d_model
236
  self.pad_idx = pad_idx
237
 
238
- # Token embedding layer
239
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
240
 
241
- # Positional encoding
242
  self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
243
 
244
- # Transformer encoder stack
245
  self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
246
 
247
- # Classification head
248
  self.classifier = nn.Linear(d_model, num_classes)
249
 
250
- # Dropout
251
  self.dropout = nn.Dropout(dropout)
252
 
253
- # Initialize weights
254
  self._init_weights()
255
 
256
  def _init_weights(self):
257
  """Initialize model weights"""
258
- # Initialize embeddings
259
  nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
 
260
  if self.pad_idx is not None:
261
  nn.init.constant_(self.embedding.weight[self.pad_idx], 0)
262
 
263
- # Initialize classifier
264
  nn.init.xavier_uniform_(self.classifier.weight)
265
  if self.classifier.bias is not None:
266
  nn.init.constant_(self.classifier.bias, 0)
267
 
268
  def create_padding_mask(self, x):
269
- """
270
- Create padding mask for attention
271
-
272
- Args:
273
- x: (batch_size, seq_len) - input token indices
274
-
275
- Returns:
276
- mask: (batch_size, 1, 1, seq_len) - attention mask
277
- """
278
- # Create mask where padding tokens are marked as 0
279
  mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
280
  return mask.float()
281
 
282
  def forward(self, x, mask=None):
283
- """
284
- Forward pass for token classification
285
-
286
- Args:
287
- x: (batch_size, seq_len) - input token indices
288
- mask: Optional custom attention mask
289
-
290
- Returns:
291
- logits: (batch_size, seq_len, num_classes) - classification logits
292
- """
293
- # Check input dimensions
294
  if x.dim() != 2:
295
  raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")
296
 
@@ -300,94 +243,35 @@ class TransformerPII(nn.Module):
300
  if mask is None:
301
  mask = self.create_padding_mask(x)
302
 
303
- # Embedding with scaling
304
  x = self.embedding(x) * math.sqrt(self.d_model)
305
 
306
  # Add positional encoding
307
  x = self.positional_encoding(x)
308
 
309
- # Pass through transformer encoder
310
  encoder_output = self.encoder(x, mask)
311
 
312
  # Apply dropout before classification
313
  encoder_output = self.dropout(encoder_output)
314
 
315
- # Classify each token
316
  logits = self.classifier(encoder_output)
317
 
318
  return logits
319
 
320
  def predict(self, x):
321
- """
322
- Get predictions for inference
323
-
324
- Args:
325
- x: (batch_size, seq_len) - input token indices
326
-
327
- Returns:
328
- predictions: (batch_size, seq_len) - predicted class indices
329
- """
330
  self.eval()
331
  with torch.no_grad():
332
  logits = self.forward(x)
333
  predictions = torch.argmax(logits, dim=-1)
334
  return predictions
335
 
336
- class TransformerPIIWithCRF(TransformerPII):
337
- """
338
- Transformer with CRF layer for improved sequence labeling
339
- (Optional enhancement - requires pytorch-crf)
340
- """
341
-
342
- def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
343
- d_ff=512, num_layers=4, dropout=0.1, max_len=512, pad_idx=0):
344
- super(TransformerPIIWithCRF, self).__init__(
345
- vocab_size, num_classes, d_model, num_heads,
346
- d_ff, num_layers, dropout, max_len, pad_idx
347
- )
348
-
349
- # CRF layer would be initialized here
350
- # from torchcrf import CRF
351
- # self.crf = CRF(num_classes, batch_first=True)
352
-
353
- def forward(self, x, labels=None):
354
- """Forward pass with optional CRF"""
355
- # Get transformer outputs
356
- emissions = super().forward(x)
357
-
358
- if labels is not None:
359
- # Training mode with CRF
360
- # mask = (x != self.pad_idx)
361
- # loss = -self.crf(emissions, labels, mask=mask)
362
- # return loss
363
- pass
364
- else:
365
- # Inference mode with CRF
366
- # mask = (x != self.pad_idx)
367
- # predictions = self.crf.decode(emissions, mask=mask)
368
- # return predictions
369
- pass
370
-
371
- return emissions
372
-
373
  def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
374
  d_ff=512, num_layers=4, dropout=0.1, max_len=512):
375
- """
376
- Factory function to create transformer model for PII detection
377
-
378
- Args:
379
- vocab_size: Size of vocabulary
380
- num_classes: Number of PII classes (e.g., 20)
381
- d_model: Dimension of model (hidden size)
382
- num_heads: Number of attention heads
383
- d_ff: Dimension of feedforward network
384
- num_layers: Number of transformer layers
385
- dropout: Dropout rate
386
- max_len: Maximum sequence length
387
-
388
- Returns:
389
- TransformerPII model instance
390
- """
391
  model = TransformerPII(
392
  vocab_size=vocab_size,
393
  num_classes=num_classes,
@@ -397,7 +281,7 @@ def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads
397
  num_layers=num_layers,
398
  dropout=dropout,
399
  max_len=max_len,
400
- pad_idx=0 # Assuming 0 is padding index
401
  )
402
 
403
  return model
 
4
  import math
5
 
6
  def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
7
+ """Compute scaled dot-product attention."""
8
+ # Get dimension of keys for scaling
 
 
 
 
 
 
 
 
 
 
 
 
9
  d_k = q.size(-1)
10
 
11
+ # Compute attention scores using dot product
12
  scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
13
 
14
+ # Mask out padding positions if mask provided
15
  if mask is not None:
16
  scores = scores.masked_fill(mask == 0, float('-inf'))
17
 
18
+ # Convert scores to probabilities
19
  attention_weights = F.softmax(scores, dim=-1)
20
 
21
+ # Apply dropout to attention weights if specified
22
  if dropout is not None:
23
  attention_weights = dropout(attention_weights)
24
 
25
+ # Apply attention weights to values
26
  output = torch.matmul(attention_weights, v)
27
 
28
  return output, attention_weights
 
36
 
37
  self.d_model = d_model
38
  self.num_heads = num_heads
39
+ self.d_k = d_model // num_heads # Dimension per head
40
 
41
+ # Linear layers for projecting Q, K, V
42
  self.w_q = nn.Linear(d_model, d_model)
43
  self.w_k = nn.Linear(d_model, d_model)
44
  self.w_v = nn.Linear(d_model, d_model)
45
 
46
+ # Final output projection
47
  self.w_o = nn.Linear(d_model, d_model)
48
 
49
+ # Dropout layer
50
  self.dropout = nn.Dropout(dropout)
51
 
52
  def forward(self, query, key, value, mask=None):
53
  """
54
+ query: (batch_size, seq_len_q, d_model)
55
+ key: (batch_size, seq_len_k, d_model)
56
+ value: (batch_size, seq_len_v, d_model)
57
+ mask: (batch_size, 1, 1, seq_len_k) or None
 
 
 
 
 
58
  """
59
  batch_size = query.size(0)
60
  seq_len_q = query.size(1)
61
  seq_len_k = key.size(1)
62
  seq_len_v = value.size(1)
63
 
64
+ # Project and reshape for multiple heads
65
  Q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
66
  K = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
67
  V = self.w_v(value).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
68
 
69
+ # Apply scaled dot-product attention
70
  attention_output, attention_weights = scaled_dot_product_attention(
71
  Q, K, V, mask=mask, dropout=self.dropout
72
  )
73
 
74
+ # Concatenate heads and apply output projection
75
  attention_output = attention_output.transpose(1, 2).contiguous().view(
76
  batch_size, seq_len_q, self.d_model
77
  )
 
85
 
86
  def __init__(self, d_model, d_ff, dropout=0.1):
87
  super(PositionwiseFeedForward, self).__init__()
88
+ # Two linear layers with ReLU activation
89
  self.w_1 = nn.Linear(d_model, d_ff)
90
  self.w_2 = nn.Linear(d_ff, d_model)
91
  self.dropout = nn.Dropout(dropout)
92
  self.activation = nn.ReLU()
93
 
94
  def forward(self, x):
95
+ # Apply first linear layer, activation, dropout, then second linear layer
 
 
 
 
 
 
96
  return self.w_2(self.dropout(self.activation(self.w_1(x))))
97
 
98
  class EncoderLayer(nn.Module):
 
101
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
102
  super(EncoderLayer, self).__init__()
103
 
104
+ # Multi-head self-attention sublayer
105
  self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
106
 
107
+ # Position-wise feed forward sublayer
108
  self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
109
 
110
+ # Layer normalization for each sublayer
111
  self.norm1 = nn.LayerNorm(d_model)
112
  self.norm2 = nn.LayerNorm(d_model)
113
 
114
+ # Dropout for residual connections
115
  self.dropout = nn.Dropout(dropout)
116
 
117
  def forward(self, x, mask=None):
118
+ # Self-attention sublayer with residual connection and layer norm
 
 
 
 
 
 
 
 
119
  attn_output, _ = self.self_attention(x, x, x, mask)
120
  x = self.norm1(x + self.dropout(attn_output))
121
 
122
+ # Feed forward sublayer with residual connection and layer norm
123
  ff_output = self.feed_forward(x)
124
  x = self.norm2(x + self.dropout(ff_output))
125
 
 
131
  def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
132
  super(TransformerEncoder, self).__init__()
133
 
134
+ # Create stack of encoder layers
135
  self.layers = nn.ModuleList([
136
  EncoderLayer(d_model, num_heads, d_ff, dropout)
137
  for _ in range(num_layers)
138
  ])
139
 
140
+ # Final layer normalization
141
  self.norm = nn.LayerNorm(d_model)
142
 
143
  def forward(self, x, mask=None):
144
+ # Pass through each encoder layer sequentially
 
 
 
 
 
 
 
145
  for layer in self.layers:
146
  x = layer(x, mask)
147
 
148
+ # Apply final normalization
149
  return self.norm(x)
150
 
151
  class PositionalEncoding(nn.Module):
 
155
  super(PositionalEncoding, self).__init__()
156
  self.dropout = nn.Dropout(dropout)
157
 
158
+ # Create matrix to hold positional encodings
159
  pe = torch.zeros(max_len, d_model)
160
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
161
 
162
+ # Create frequency terms for sin/cos functions
163
  div_term = torch.exp(torch.arange(0, d_model, 2).float() *
164
  (-math.log(10000.0) / d_model))
165
 
166
+ # Apply sine to even indices
167
  pe[:, 0::2] = torch.sin(position * div_term)
168
 
169
+ # Apply cosine to odd indices
170
  if d_model % 2 == 1:
171
  pe[:, 1::2] = torch.cos(position * div_term[:-1])
172
  else:
173
  pe[:, 1::2] = torch.cos(position * div_term)
174
 
175
+ # Add batch dimension and save as buffer
176
  pe = pe.unsqueeze(0)
177
  self.register_buffer('pe', pe)
178
 
179
  def forward(self, x):
180
+ # Add positional encoding to input embeddings
 
 
 
 
 
 
 
181
  x = x + self.pe[:, :x.size(1), :]
182
  return self.dropout(x)
183
 
 
194
  self.d_model = d_model
195
  self.pad_idx = pad_idx
196
 
197
+ # Embedding layer for input tokens
198
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
199
 
200
+ # Add positional information to embeddings
201
  self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
202
 
203
+ # Stack of transformer encoder layers
204
  self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
205
 
206
+ # Classification head for token-level predictions
207
  self.classifier = nn.Linear(d_model, num_classes)
208
 
209
+ # Dropout layer
210
  self.dropout = nn.Dropout(dropout)
211
 
212
+ # Initialize model weights
213
  self._init_weights()
214
 
215
  def _init_weights(self):
216
  """Initialize model weights"""
217
+ # Initialize embeddings with normal distribution
218
  nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
219
+ # Set padding token embedding to zero
220
  if self.pad_idx is not None:
221
  nn.init.constant_(self.embedding.weight[self.pad_idx], 0)
222
 
223
+ # Initialize classifier with Xavier uniform
224
  nn.init.xavier_uniform_(self.classifier.weight)
225
  if self.classifier.bias is not None:
226
  nn.init.constant_(self.classifier.bias, 0)
227
 
228
  def create_padding_mask(self, x):
229
+ """Create padding mask for attention"""
230
+ # Create mask where non-padding tokens are marked as 1
 
 
 
 
 
 
 
 
231
  mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
232
  return mask.float()
233
 
234
  def forward(self, x, mask=None):
235
+ """Forward pass for token classification"""
236
+ # Validate input dimensions
 
 
 
 
 
 
 
 
 
237
  if x.dim() != 2:
238
  raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")
239
 
 
243
  if mask is None:
244
  mask = self.create_padding_mask(x)
245
 
246
+ # Embed and scale by sqrt(d_model)
247
  x = self.embedding(x) * math.sqrt(self.d_model)
248
 
249
  # Add positional encoding
250
  x = self.positional_encoding(x)
251
 
252
+ # Pass through transformer encoder stack
253
  encoder_output = self.encoder(x, mask)
254
 
255
  # Apply dropout before classification
256
  encoder_output = self.dropout(encoder_output)
257
 
258
+ # Get class predictions for each token
259
  logits = self.classifier(encoder_output)
260
 
261
  return logits
262
 
263
  def predict(self, x):
264
+ """Get predictions for inference"""
265
+ # Switch to evaluation mode
 
 
 
 
 
 
 
266
  self.eval()
267
  with torch.no_grad():
268
  logits = self.forward(x)
269
  predictions = torch.argmax(logits, dim=-1)
270
  return predictions
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
273
  d_ff=512, num_layers=4, dropout=0.1, max_len=512):
274
+ """Factory function to create transformer model for PII detection"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  model = TransformerPII(
276
  vocab_size=vocab_size,
277
  num_classes=num_classes,
 
281
  num_layers=num_layers,
282
  dropout=dropout,
283
  max_len=max_len,
284
+ pad_idx=0
285
  )
286
 
287
  return model
transformer_training.ipynb CHANGED
@@ -42,7 +42,7 @@
42
  },
43
  {
44
  "cell_type": "code",
45
- "execution_count": 4,
46
  "id": "ff1782dd",
47
  "metadata": {
48
  "execution": {
@@ -62,19 +62,23 @@
62
  },
63
  "outputs": [],
64
  "source": [
 
65
  "class Vocabulary:\n",
66
  " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
67
  " def __init__(self, max_size=100000):\n",
 
68
  " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
69
  " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
70
  " self.word_count = Counter()\n",
71
  " self.max_size = max_size\n",
72
  " \n",
73
  " def add_sentence(self, sentence):\n",
 
74
  " for word in sentence:\n",
75
  " self.word_count[word.lower()] += 1\n",
76
  " \n",
77
  " def build(self):\n",
 
78
  " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
79
  " for word, _ in most_common:\n",
80
  " if word not in self.word2idx:\n",
@@ -86,15 +90,17 @@
86
  " return len(self.word2idx)\n",
87
  " \n",
88
  " def encode(self, sentence):\n",
 
89
  " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
90
  " \n",
91
  " def decode(self, indices):\n",
 
92
  " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
93
  ]
94
  },
95
  {
96
  "cell_type": "code",
97
- "execution_count": 5,
98
  "id": "5b2b46d6",
99
  "metadata": {
100
  "execution": {
@@ -114,6 +120,7 @@
114
  },
115
  "outputs": [],
116
  "source": [
 
117
  "class PIIDataset(Dataset):\n",
118
  " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
119
  " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
@@ -127,16 +134,16 @@
127
  " return len(self.tokens)\n",
128
  " \n",
129
  " def __getitem__(self, idx):\n",
130
- " # Add start and end tokens\n",
131
  " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
132
  " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
133
  " \n",
134
- " # Truncate if too long\n",
135
  " if len(tokens) > self.max_len:\n",
136
  " tokens = tokens[:self.max_len-1] + ['<end>']\n",
137
  " labels = labels[:self.max_len-1] + ['<end>']\n",
138
  " \n",
139
- " # Encode\n",
140
  " token_ids = self.text_vocab.encode(tokens)\n",
141
  " label_ids = self.label_vocab.encode(labels)\n",
142
  " \n",
@@ -145,7 +152,7 @@
145
  },
146
  {
147
  "cell_type": "code",
148
- "execution_count": 6,
149
  "id": "e7ca8f8f",
150
  "metadata": {
151
  "execution": {
@@ -167,7 +174,9 @@
167
  "source": [
168
  "def collate_fn(batch):\n",
169
  " \"\"\"Custom collate function for padding sequences\"\"\"\n",
 
170
  " tokens, labels = zip(*batch)\n",
 
171
  " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
172
  " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
173
  " return tokens_padded, labels_padded"
@@ -175,7 +184,7 @@
175
  },
176
  {
177
  "cell_type": "code",
178
- "execution_count": 7,
179
  "id": "85b32e21",
180
  "metadata": {
181
  "execution": {
@@ -195,6 +204,7 @@
195
  },
196
  "outputs": [],
197
  "source": [
 
198
  "class F1ScoreMetric:\n",
199
  " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
200
  " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
@@ -205,30 +215,37 @@
205
  " self.reset()\n",
206
  " \n",
207
  " def reset(self):\n",
 
208
  " self.true_positives = 0\n",
209
  " self.false_positives = 0\n",
210
  " self.false_negatives = 0\n",
211
  " self.class_metrics = {}\n",
212
  " \n",
213
  " def update(self, predictions, targets):\n",
 
214
  " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
215
  " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
216
  " \n",
 
217
  " for class_id in range(1, self.num_classes):\n",
218
  " if class_id == o_idx:\n",
219
  " continue\n",
220
  " \n",
 
221
  " pred_mask = (predictions == class_id) & mask\n",
222
  " true_mask = (targets == class_id) & mask\n",
223
  " \n",
 
224
  " tp = ((pred_mask) & (true_mask)).sum().item()\n",
225
  " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
226
  " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
227
  " \n",
 
228
  " self.true_positives += tp\n",
229
  " self.false_positives += fp\n",
230
  " self.false_negatives += fn\n",
231
  " \n",
 
232
  " if class_id not in self.class_metrics:\n",
233
  " self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
234
  " self.class_metrics[class_id]['tp'] += tp\n",
@@ -236,6 +253,7 @@
236
  " self.class_metrics[class_id]['fn'] += fn\n",
237
  " \n",
238
  " def compute(self):\n",
 
239
  " beta_squared = self.beta ** 2\n",
240
  " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
241
  " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
@@ -243,6 +261,7 @@
243
  " return f1\n",
244
  " \n",
245
  " def get_class_metrics(self):\n",
 
246
  " results = {}\n",
247
  " for class_id, metrics in self.class_metrics.items():\n",
248
  " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
@@ -261,7 +280,7 @@
261
  },
262
  {
263
  "cell_type": "code",
264
- "execution_count": 8,
265
  "id": "60cf16eb",
266
  "metadata": {
267
  "execution": {
@@ -281,6 +300,7 @@
281
  },
282
  "outputs": [],
283
  "source": [
 
284
  "class FocalLoss(nn.Module):\n",
285
  " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
286
  " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
@@ -291,6 +311,7 @@
291
  " self.ignore_index = ignore_index\n",
292
  " \n",
293
  " def forward(self, inputs, targets):\n",
 
294
  " ce_loss = nn.functional.cross_entropy(\n",
295
  " inputs, targets, \n",
296
  " weight=self.alpha, \n",
@@ -298,9 +319,11 @@
298
  " ignore_index=self.ignore_index\n",
299
  " )\n",
300
  " \n",
 
301
  " pt = torch.exp(-ce_loss)\n",
302
  " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
303
  " \n",
 
304
  " if self.reduction == 'mean':\n",
305
  " return focal_loss.mean()\n",
306
  " elif self.reduction == 'sum':\n",
@@ -311,7 +334,7 @@
311
  },
312
  {
313
  "cell_type": "code",
314
- "execution_count": 9,
315
  "id": "4e56747c",
316
  "metadata": {
317
  "execution": {
@@ -337,8 +360,10 @@
337
  " total_loss = 0\n",
338
  " f1_metric.reset()\n",
339
  " \n",
 
340
  " progress_bar = tqdm(dataloader, desc='Training')\n",
341
  " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
 
342
  " tokens = tokens.to(device)\n",
343
  " labels = labels.to(device)\n",
344
  " \n",
@@ -350,9 +375,10 @@
350
  " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
351
  " labels_flat = labels.view(-1)\n",
352
  " \n",
353
- " # Calculate loss and backward pass\n",
354
  " loss = criterion(outputs_flat, labels_flat)\n",
355
  " loss.backward()\n",
 
356
  " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
357
  " optimizer.step()\n",
358
  " \n",
@@ -372,7 +398,7 @@
372
  },
373
  {
374
  "cell_type": "code",
375
- "execution_count": 10,
376
  "id": "8a2e8d19",
377
  "metadata": {
378
  "execution": {
@@ -398,8 +424,10 @@
398
  " total_loss = 0\n",
399
  " f1_metric.reset()\n",
400
  " \n",
 
401
  " with torch.no_grad():\n",
402
  " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
 
403
  " tokens = tokens.to(device)\n",
404
  " labels = labels.to(device)\n",
405
  " \n",
@@ -421,7 +449,7 @@
421
  },
422
  {
423
  "cell_type": "code",
424
- "execution_count": 11,
425
  "id": "6e292ace",
426
  "metadata": {
427
  "execution": {
@@ -445,6 +473,7 @@
445
  " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
446
  " sample_weights = []\n",
447
  " \n",
 
448
  " for idx in range(len(dataset)):\n",
449
  " _, labels = dataset[idx]\n",
450
  " \n",
@@ -453,24 +482,26 @@
453
  " for label_id in labels:\n",
454
  " if label_id > 3: # Skip special tokens\n",
455
  " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
 
456
  " if label_name != 'o' and 'B-' in label_name:\n",
457
  " min_weight = 10.0\n",
458
  " break\n",
459
  " \n",
460
  " sample_weights.append(min_weight)\n",
461
  " \n",
 
462
  " sampler = WeightedRandomSampler(\n",
463
  " weights=sample_weights,\n",
464
  " num_samples=len(sample_weights),\n",
465
  " replacement=True\n",
466
  " )\n",
467
  " \n",
468
- " return sampler\n"
469
  ]
470
  },
471
  {
472
  "cell_type": "code",
473
- "execution_count": 12,
474
  "id": "857335cb",
475
  "metadata": {
476
  "execution": {
@@ -493,11 +524,14 @@
493
  "def print_label_distribution(data, title=\"Label Distribution\"):\n",
494
  " \"\"\"Print label distribution statistics\"\"\"\n",
495
  " label_counts = Counter()\n",
 
 
496
  " for label_seq in data.labels:\n",
497
  " for label in label_seq:\n",
498
  " if label not in ['<pad>', '<start>', '<end>']:\n",
499
  " label_counts[label] += 1\n",
500
  " \n",
 
501
  " print(f\"\\n{title}:\")\n",
502
  " print(\"-\" * 50)\n",
503
  " total = sum(label_counts.values())\n",
@@ -510,7 +544,7 @@
510
  },
511
  {
512
  "cell_type": "code",
513
- "execution_count": 13,
514
  "id": "1738f8a9",
515
  "metadata": {
516
  "execution": {
@@ -532,9 +566,10 @@
532
  "source": [
533
  "def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
534
  " \"\"\"Save model and all necessary components for Flask deployment\"\"\"\n",
 
535
  " os.makedirs(save_dir, exist_ok=True)\n",
536
  " \n",
537
- " # Save model state\n",
538
  " model_path = os.path.join(save_dir, 'pii_transformer_model.pt')\n",
539
  " torch.save(model.state_dict(), model_path)\n",
540
  " \n",
@@ -560,7 +595,7 @@
560
  },
561
  {
562
  "cell_type": "code",
563
- "execution_count": 14,
564
  "id": "d93e7c25",
565
  "metadata": {
566
  "execution": {
@@ -591,12 +626,12 @@
591
  "):\n",
592
  " \"\"\"Main training function\"\"\"\n",
593
  " \n",
594
- " # Load data\n",
595
  " print(\"Loading augmented data...\")\n",
596
  " data = pd.read_json(data_path, lines=True)\n",
597
  " print(f\"Total samples: {len(data)}\")\n",
598
  " \n",
599
- " # Print initial label distribution\n",
600
  " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
601
  " \n",
602
  " # Build vocabularies\n",
@@ -604,19 +639,21 @@
604
  " text_vocab = Vocabulary(max_size=100000)\n",
605
  " label_vocab = Vocabulary(max_size=50)\n",
606
  " \n",
 
607
  " for tokens in data.tokens:\n",
608
  " text_vocab.add_sentence(tokens)\n",
609
  " for labels in data.labels:\n",
610
  " label_vocab.add_sentence(labels)\n",
611
  " \n",
 
612
  " text_vocab.build()\n",
613
  " label_vocab.build()\n",
614
  " \n",
615
- " # Calculate class weights\n",
616
  " class_weights = calculate_class_weights(data, label_vocab)\n",
617
  " class_weights = class_weights.to(device)\n",
618
  " \n",
619
- " # Split data\n",
620
  " X_train, X_val, y_train, y_val = train_test_split(\n",
621
  " data.tokens.tolist(),\n",
622
  " data.labels.tolist(),\n",
@@ -628,13 +665,15 @@
628
  " print(f\" - Train samples: {len(X_train):,}\")\n",
629
  " print(f\" - Validation samples: {len(X_val):,}\")\n",
630
  " \n",
631
- " # Create datasets and dataloaders\n",
632
  " max_seq_len = 512\n",
633
  " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
634
  " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
635
  " \n",
 
636
  " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
637
  " \n",
 
638
  " train_loader = DataLoader(\n",
639
  " train_dataset, \n",
640
  " batch_size=batch_size,\n",
@@ -663,7 +702,7 @@
663
  " 'max_len': max_seq_len\n",
664
  " }\n",
665
  " \n",
666
- " # Create model\n",
667
  " print(\"\\nCreating model...\")\n",
668
  " model = create_transformer_pii_model(**model_config).to(device)\n",
669
  " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
@@ -678,15 +717,15 @@
678
  " else:\n",
679
  " criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
680
  " \n",
681
- " # Setup optimizer and scheduler\n",
682
  " optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
683
  " scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)\n",
684
  " \n",
685
- " # Metrics\n",
686
  " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
687
  " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
688
  " \n",
689
- " # Training loop\n",
690
  " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
691
  " best_val_f1 = 0\n",
692
  " patience = 5\n",
@@ -695,18 +734,21 @@
695
  " print(\"\\nStarting training...\")\n",
696
  " print(\"=\" * 60)\n",
697
  " \n",
 
698
  " for epoch in range(num_epochs):\n",
699
  " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
700
  " \n",
701
- " # Train and validate\n",
702
  " train_loss, train_f1 = train_epoch(\n",
703
  " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
704
  " )\n",
 
 
705
  " val_loss, val_f1 = evaluate(\n",
706
  " model, val_loader, criterion, device, f1_metric_val\n",
707
  " )\n",
708
  " \n",
709
- " # Step scheduler\n",
710
  " scheduler.step(val_loss)\n",
711
  " \n",
712
  " # Store metrics\n",
@@ -743,7 +785,7 @@
743
  " else:\n",
744
  " patience_counter += 1\n",
745
  " \n",
746
- " # Early stopping\n",
747
  " if patience_counter >= patience and epoch > 10:\n",
748
  " print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
749
  " break\n",
@@ -751,6 +793,7 @@
751
  " # Plot training curves\n",
752
  " plt.figure(figsize=(12, 5))\n",
753
  " \n",
 
754
  " plt.subplot(1, 2, 1)\n",
755
  " plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
756
  " plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
@@ -760,6 +803,7 @@
760
  " plt.legend()\n",
761
  " plt.grid(True, alpha=0.3)\n",
762
  " \n",
 
763
  " plt.subplot(1, 2, 2)\n",
764
  " plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
765
  " plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
@@ -777,6 +821,7 @@
777
  " print(f\"Training completed!\")\n",
778
  " print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
779
  " \n",
 
780
  " save_model(model, text_vocab, label_vocab, model_config, 'saved_transformer_model')\n",
781
  " \n",
782
  " return model, text_vocab, label_vocab"
@@ -1251,9 +1296,11 @@
1251
  }
1252
  ],
1253
  "source": [
 
1254
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1255
  "print(f\"Using device: {device}\")\n",
1256
  "\n",
 
1257
  "model, text_vocab, label_vocab = train_transformer_pii_model(\n",
1258
  " data_path='train_augmented.json',\n",
1259
  " num_epochs=20,\n",
 
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": null,
46
  "id": "ff1782dd",
47
  "metadata": {
48
  "execution": {
 
62
  },
63
  "outputs": [],
64
  "source": [
65
+ "# Vocabulary class for text and label encoding\n",
66
  "class Vocabulary:\n",
67
  " \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
68
  " def __init__(self, max_size=100000):\n",
69
+ " # Initialize special tokens\n",
70
  " self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
71
  " self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
72
  " self.word_count = Counter()\n",
73
  " self.max_size = max_size\n",
74
  " \n",
75
  " def add_sentence(self, sentence):\n",
76
+ " # Count word frequencies\n",
77
  " for word in sentence:\n",
78
  " self.word_count[word.lower()] += 1\n",
79
  " \n",
80
  " def build(self):\n",
81
+ " # Build vocabulary from most common words\n",
82
  " most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
83
  " for word, _ in most_common:\n",
84
  " if word not in self.word2idx:\n",
 
90
  " return len(self.word2idx)\n",
91
  " \n",
92
  " def encode(self, sentence):\n",
93
+ " # Convert words to indices\n",
94
  " return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
95
  " \n",
96
  " def decode(self, indices):\n",
97
+ " # Convert indices back to words\n",
98
  " return [self.idx2word.get(idx, '<unk>') for idx in indices]"
99
  ]
100
  },
101
  {
102
  "cell_type": "code",
103
+ "execution_count": null,
104
  "id": "5b2b46d6",
105
  "metadata": {
106
  "execution": {
 
120
  },
121
  "outputs": [],
122
  "source": [
123
+ "# Dataset class for PII detection\n",
124
  "class PIIDataset(Dataset):\n",
125
  " \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
126
  " def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
 
134
  " return len(self.tokens)\n",
135
  " \n",
136
  " def __getitem__(self, idx):\n",
137
+ " # Add special tokens at beginning and end\n",
138
  " tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
139
  " labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
140
  " \n",
141
+ " # Truncate if sequence is too long\n",
142
  " if len(tokens) > self.max_len:\n",
143
  " tokens = tokens[:self.max_len-1] + ['<end>']\n",
144
  " labels = labels[:self.max_len-1] + ['<end>']\n",
145
  " \n",
146
+ " # Convert to indices\n",
147
  " token_ids = self.text_vocab.encode(tokens)\n",
148
  " label_ids = self.label_vocab.encode(labels)\n",
149
  " \n",
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": null,
156
  "id": "e7ca8f8f",
157
  "metadata": {
158
  "execution": {
 
174
  "source": [
175
  "def collate_fn(batch):\n",
176
  " \"\"\"Custom collate function for padding sequences\"\"\"\n",
177
+ " # Separate tokens and labels\n",
178
  " tokens, labels = zip(*batch)\n",
179
+ " # Pad sequences to same length in batch\n",
180
  " tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
181
  " labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
182
  " return tokens_padded, labels_padded"
 
184
  },
185
  {
186
  "cell_type": "code",
187
+ "execution_count": null,
188
  "id": "85b32e21",
189
  "metadata": {
190
  "execution": {
 
204
  },
205
  "outputs": [],
206
  "source": [
207
+ "# F1 score metric for evaluation\n",
208
  "class F1ScoreMetric:\n",
209
  " \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
210
  " def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
 
215
  " self.reset()\n",
216
  " \n",
217
  " def reset(self):\n",
218
+ " # Reset all counters\n",
219
  " self.true_positives = 0\n",
220
  " self.false_positives = 0\n",
221
  " self.false_negatives = 0\n",
222
  " self.class_metrics = {}\n",
223
  " \n",
224
  " def update(self, predictions, targets):\n",
225
+ " # Create mask to ignore padding and special tokens\n",
226
  " mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
227
  " o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
228
  " \n",
229
+ " # Calculate metrics for each PII class\n",
230
  " for class_id in range(1, self.num_classes):\n",
231
  " if class_id == o_idx:\n",
232
  " continue\n",
233
  " \n",
234
+ " # Find where predictions and targets match this class\n",
235
  " pred_mask = (predictions == class_id) & mask\n",
236
  " true_mask = (targets == class_id) & mask\n",
237
  " \n",
238
+ " # Count true positives, false positives, false negatives\n",
239
  " tp = ((pred_mask) & (true_mask)).sum().item()\n",
240
  " fp = ((pred_mask) & (~true_mask)).sum().item()\n",
241
  " fn = ((~pred_mask) & (true_mask)).sum().item()\n",
242
  " \n",
243
+ " # Update total counts\n",
244
  " self.true_positives += tp\n",
245
  " self.false_positives += fp\n",
246
  " self.false_negatives += fn\n",
247
  " \n",
248
+ " # Store per-class metrics\n",
249
  " if class_id not in self.class_metrics:\n",
250
  " self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
251
  " self.class_metrics[class_id]['tp'] += tp\n",
 
253
  " self.class_metrics[class_id]['fn'] += fn\n",
254
  " \n",
255
  " def compute(self):\n",
256
+ " # Calculate F-beta score\n",
257
  " beta_squared = self.beta ** 2\n",
258
  " precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
259
  " recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
 
261
  " return f1\n",
262
  " \n",
263
  " def get_class_metrics(self):\n",
264
+ " # Get detailed metrics for each class\n",
265
  " results = {}\n",
266
  " for class_id, metrics in self.class_metrics.items():\n",
267
  " if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
 
280
  },
281
  {
282
  "cell_type": "code",
283
+ "execution_count": null,
284
  "id": "60cf16eb",
285
  "metadata": {
286
  "execution": {
 
300
  },
301
  "outputs": [],
302
  "source": [
303
+ "# Focal loss for handling class imbalance\n",
304
  "class FocalLoss(nn.Module):\n",
305
  " \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
306
  " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
 
311
  " self.ignore_index = ignore_index\n",
312
  " \n",
313
  " def forward(self, inputs, targets):\n",
314
+ " # Calculate cross entropy loss\n",
315
  " ce_loss = nn.functional.cross_entropy(\n",
316
  " inputs, targets, \n",
317
  " weight=self.alpha, \n",
 
319
  " ignore_index=self.ignore_index\n",
320
  " )\n",
321
  " \n",
322
+ " # Apply focal term to focus on hard examples\n",
323
  " pt = torch.exp(-ce_loss)\n",
324
  " focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
325
  " \n",
326
+ " # Reduce loss based on specified method\n",
327
  " if self.reduction == 'mean':\n",
328
  " return focal_loss.mean()\n",
329
  " elif self.reduction == 'sum':\n",
 
334
  },
335
  {
336
  "cell_type": "code",
337
+ "execution_count": null,
338
  "id": "4e56747c",
339
  "metadata": {
340
  "execution": {
 
360
  " total_loss = 0\n",
361
  " f1_metric.reset()\n",
362
  " \n",
363
+ " # Progress bar for training\n",
364
  " progress_bar = tqdm(dataloader, desc='Training')\n",
365
  " for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
366
+ " # Move data to device\n",
367
  " tokens = tokens.to(device)\n",
368
  " labels = labels.to(device)\n",
369
  " \n",
 
375
  " outputs_flat = outputs.view(-1, outputs.size(-1))\n",
376
  " labels_flat = labels.view(-1)\n",
377
  " \n",
378
+ " # Calculate loss and backpropagate\n",
379
  " loss = criterion(outputs_flat, labels_flat)\n",
380
  " loss.backward()\n",
381
+ " # Clip gradients to prevent exploding gradients\n",
382
  " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
383
  " optimizer.step()\n",
384
  " \n",
 
398
  },
399
  {
400
  "cell_type": "code",
401
+ "execution_count": null,
402
  "id": "8a2e8d19",
403
  "metadata": {
404
  "execution": {
 
424
  " total_loss = 0\n",
425
  " f1_metric.reset()\n",
426
  " \n",
427
+ " # No gradient computation during evaluation\n",
428
  " with torch.no_grad():\n",
429
  " for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
430
+ " # Move data to device\n",
431
  " tokens = tokens.to(device)\n",
432
  " labels = labels.to(device)\n",
433
  " \n",
 
449
  },
450
  {
451
  "cell_type": "code",
452
+ "execution_count": null,
453
  "id": "6e292ace",
454
  "metadata": {
455
  "execution": {
 
473
  " \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
474
  " sample_weights = []\n",
475
  " \n",
476
+ " # Calculate weight for each sample\n",
477
  " for idx in range(len(dataset)):\n",
478
  " _, labels = dataset[idx]\n",
479
  " \n",
 
482
  " for label_id in labels:\n",
483
  " if label_id > 3: # Skip special tokens\n",
484
  " label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
485
+ " # If sample contains PII, give it higher weight\n",
486
  " if label_name != 'o' and 'B-' in label_name:\n",
487
  " min_weight = 10.0\n",
488
  " break\n",
489
  " \n",
490
  " sample_weights.append(min_weight)\n",
491
  " \n",
492
+ " # Create weighted sampler\n",
493
  " sampler = WeightedRandomSampler(\n",
494
  " weights=sample_weights,\n",
495
  " num_samples=len(sample_weights),\n",
496
  " replacement=True\n",
497
  " )\n",
498
  " \n",
499
+ " return sampler"
500
  ]
501
  },
502
  {
503
  "cell_type": "code",
504
+ "execution_count": null,
505
  "id": "857335cb",
506
  "metadata": {
507
  "execution": {
 
524
  "def print_label_distribution(data, title=\"Label Distribution\"):\n",
525
  " \"\"\"Print label distribution statistics\"\"\"\n",
526
  " label_counts = Counter()\n",
527
+ " \n",
528
+ " # Count each label type\n",
529
  " for label_seq in data.labels:\n",
530
  " for label in label_seq:\n",
531
  " if label not in ['<pad>', '<start>', '<end>']:\n",
532
  " label_counts[label] += 1\n",
533
  " \n",
534
+ " # Print distribution\n",
535
  " print(f\"\\n{title}:\")\n",
536
  " print(\"-\" * 50)\n",
537
  " total = sum(label_counts.values())\n",
 
544
  },
545
  {
546
  "cell_type": "code",
547
+ "execution_count": null,
548
  "id": "1738f8a9",
549
  "metadata": {
550
  "execution": {
 
566
  "source": [
567
  "def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
568
  " \"\"\"Save model and all necessary components for Flask deployment\"\"\"\n",
569
+ " # Create directory if it doesn't exist\n",
570
  " os.makedirs(save_dir, exist_ok=True)\n",
571
  " \n",
572
+ " # Save model weights\n",
573
  " model_path = os.path.join(save_dir, 'pii_transformer_model.pt')\n",
574
  " torch.save(model.state_dict(), model_path)\n",
575
  " \n",
 
595
  },
596
  {
597
  "cell_type": "code",
598
+ "execution_count": null,
599
  "id": "d93e7c25",
600
  "metadata": {
601
  "execution": {
 
626
  "):\n",
627
  " \"\"\"Main training function\"\"\"\n",
628
  " \n",
629
+ " # Load augmented data\n",
630
  " print(\"Loading augmented data...\")\n",
631
  " data = pd.read_json(data_path, lines=True)\n",
632
  " print(f\"Total samples: {len(data)}\")\n",
633
  " \n",
634
+ " # Show label distribution\n",
635
  " print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
636
  " \n",
637
  " # Build vocabularies\n",
 
639
  " text_vocab = Vocabulary(max_size=100000)\n",
640
  " label_vocab = Vocabulary(max_size=50)\n",
641
  " \n",
642
+ " # Add all words and labels to vocabularies\n",
643
  " for tokens in data.tokens:\n",
644
  " text_vocab.add_sentence(tokens)\n",
645
  " for labels in data.labels:\n",
646
  " label_vocab.add_sentence(labels)\n",
647
  " \n",
648
+ " # Build vocabularies from collected words\n",
649
  " text_vocab.build()\n",
650
  " label_vocab.build()\n",
651
  " \n",
652
+ " # Calculate class weights for balanced loss\n",
653
  " class_weights = calculate_class_weights(data, label_vocab)\n",
654
  " class_weights = class_weights.to(device)\n",
655
  " \n",
656
+ " # Split data into train and validation sets\n",
657
  " X_train, X_val, y_train, y_val = train_test_split(\n",
658
  " data.tokens.tolist(),\n",
659
  " data.labels.tolist(),\n",
 
665
  " print(f\" - Train samples: {len(X_train):,}\")\n",
666
  " print(f\" - Validation samples: {len(X_val):,}\")\n",
667
  " \n",
668
+ " # Create datasets\n",
669
  " max_seq_len = 512\n",
670
  " train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
671
  " val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
672
  " \n",
673
+ " # Create balanced sampler for training\n",
674
  " train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
675
  " \n",
676
+ " # Create data loaders\n",
677
  " train_loader = DataLoader(\n",
678
  " train_dataset, \n",
679
  " batch_size=batch_size,\n",
 
702
  " 'max_len': max_seq_len\n",
703
  " }\n",
704
  " \n",
705
+ " # Create transformer model\n",
706
  " print(\"\\nCreating model...\")\n",
707
  " model = create_transformer_pii_model(**model_config).to(device)\n",
708
  " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
 
717
  " else:\n",
718
  " criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
719
  " \n",
720
+ " # Setup optimizer and learning rate scheduler\n",
721
  " optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
722
  " scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)\n",
723
  " \n",
724
+ " # Initialize metrics\n",
725
  " f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
726
  " f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
727
  " \n",
728
+ " # Training history\n",
729
  " train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
730
  " best_val_f1 = 0\n",
731
  " patience = 5\n",
 
734
  " print(\"\\nStarting training...\")\n",
735
  " print(\"=\" * 60)\n",
736
  " \n",
737
+ " # Training loop\n",
738
  " for epoch in range(num_epochs):\n",
739
  " print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
740
  " \n",
741
+ " # Train for one epoch\n",
742
  " train_loss, train_f1 = train_epoch(\n",
743
  " model, train_loader, optimizer, criterion, device, f1_metric_train\n",
744
  " )\n",
745
+ " \n",
746
+ " # Evaluate on validation set\n",
747
  " val_loss, val_f1 = evaluate(\n",
748
  " model, val_loader, criterion, device, f1_metric_val\n",
749
  " )\n",
750
  " \n",
751
+ " # Adjust learning rate based on validation loss\n",
752
  " scheduler.step(val_loss)\n",
753
  " \n",
754
  " # Store metrics\n",
 
785
  " else:\n",
786
  " patience_counter += 1\n",
787
  " \n",
788
+ " # Early stopping check\n",
789
  " if patience_counter >= patience and epoch > 10:\n",
790
  " print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
791
  " break\n",
 
793
  " # Plot training curves\n",
794
  " plt.figure(figsize=(12, 5))\n",
795
  " \n",
796
+ " # Plot loss curves\n",
797
  " plt.subplot(1, 2, 1)\n",
798
  " plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
799
  " plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
 
803
  " plt.legend()\n",
804
  " plt.grid(True, alpha=0.3)\n",
805
  " \n",
806
+ " # Plot F1 score curves\n",
807
  " plt.subplot(1, 2, 2)\n",
808
  " plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
809
  " plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
 
821
  " print(f\"Training completed!\")\n",
822
  " print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
823
  " \n",
824
+ " # Save model for deployment\n",
825
  " save_model(model, text_vocab, label_vocab, model_config, 'saved_transformer_model')\n",
826
  " \n",
827
  " return model, text_vocab, label_vocab"
 
1296
  }
1297
  ],
1298
  "source": [
1299
+ "# Set device\n",
1300
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1301
  "print(f\"Using device: {device}\")\n",
1302
  "\n",
1303
+ "# Train the transformer model\n",
1304
  "model, text_vocab, label_vocab = train_transformer_pii_model(\n",
1305
  " data_path='train_augmented.json',\n",
1306
  " num_epochs=20,\n",