Spaces:
Sleeping
Sleeping
Commit
·
f53fac9
1
Parent(s):
3dea7de
add comments
Browse files- app.py +67 -47
- data_augmentation.py +64 -95
- lstm.py +48 -61
- lstm_training.ipynb +65 -26
- transformer.py +56 -172
- transformer_training.ipynb +76 -29
app.py
CHANGED
|
@@ -10,20 +10,23 @@ from collections import Counter
|
|
| 10 |
import warnings
|
| 11 |
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
-
#
|
| 14 |
class Vocabulary:
|
| 15 |
"""Vocabulary class for encoding/decoding text and labels"""
|
| 16 |
def __init__(self, max_size=100000):
|
|
|
|
| 17 |
self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
|
| 18 |
self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}
|
| 19 |
self.word_count = Counter()
|
| 20 |
self.max_size = max_size
|
| 21 |
|
| 22 |
def add_sentence(self, sentence):
|
|
|
|
| 23 |
for word in sentence:
|
| 24 |
self.word_count[word.lower()] += 1
|
| 25 |
|
| 26 |
def build(self):
|
|
|
|
| 27 |
most_common = self.word_count.most_common(self.max_size - len(self.word2idx))
|
| 28 |
for word, _ in most_common:
|
| 29 |
if word not in self.word2idx:
|
|
@@ -35,12 +38,14 @@ class Vocabulary:
|
|
| 35 |
return len(self.word2idx)
|
| 36 |
|
| 37 |
def encode(self, sentence):
|
|
|
|
| 38 |
return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]
|
| 39 |
|
| 40 |
def decode(self, indices):
|
|
|
|
| 41 |
return [self.idx2word.get(idx, '<unk>') for idx in indices]
|
| 42 |
|
| 43 |
-
#
|
| 44 |
class MultiHeadAttention(nn.Module):
|
| 45 |
def __init__(self, d_model, num_heads, dropout=0.1):
|
| 46 |
super().__init__()
|
|
@@ -49,6 +54,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 49 |
self.num_heads = num_heads
|
| 50 |
self.d_k = d_model // num_heads
|
| 51 |
|
|
|
|
| 52 |
self.w_q = nn.Linear(d_model, d_model)
|
| 53 |
self.w_k = nn.Linear(d_model, d_model)
|
| 54 |
self.w_v = nn.Linear(d_model, d_model)
|
|
@@ -59,24 +65,27 @@ class MultiHeadAttention(nn.Module):
|
|
| 59 |
def forward(self, query, key, value, mask=None):
|
| 60 |
batch_size = query.size(0)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 64 |
K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 65 |
V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 66 |
|
| 67 |
-
#
|
| 68 |
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 69 |
|
|
|
|
| 70 |
if mask is not None:
|
| 71 |
mask = mask.unsqueeze(1).unsqueeze(1)
|
| 72 |
scores = scores.masked_fill(mask, -1e9)
|
| 73 |
|
|
|
|
| 74 |
attention = F.softmax(scores, dim=-1)
|
| 75 |
attention = self.dropout(attention)
|
| 76 |
|
|
|
|
| 77 |
context = torch.matmul(attention, V)
|
| 78 |
|
| 79 |
-
#
|
| 80 |
context = context.transpose(1, 2).contiguous().view(
|
| 81 |
batch_size, -1, self.d_model
|
| 82 |
)
|
|
@@ -84,6 +93,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 84 |
output = self.w_o(context)
|
| 85 |
return output
|
| 86 |
|
|
|
|
| 87 |
class FeedForward(nn.Module):
|
| 88 |
def __init__(self, d_model, d_ff, dropout=0.1):
|
| 89 |
super().__init__()
|
|
@@ -92,8 +102,10 @@ class FeedForward(nn.Module):
|
|
| 92 |
self.dropout = nn.Dropout(dropout)
|
| 93 |
|
| 94 |
def forward(self, x):
|
|
|
|
| 95 |
return self.w_2(self.dropout(F.gelu(self.w_1(x))))
|
| 96 |
|
|
|
|
| 97 |
class EncoderLayer(nn.Module):
|
| 98 |
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
|
| 99 |
super().__init__()
|
|
@@ -104,19 +116,21 @@ class EncoderLayer(nn.Module):
|
|
| 104 |
self.dropout = nn.Dropout(dropout)
|
| 105 |
|
| 106 |
def forward(self, x, mask=None):
|
| 107 |
-
#
|
| 108 |
attn_output = self.self_attention(x, x, x, mask)
|
| 109 |
x = self.norm1(x + self.dropout(attn_output))
|
| 110 |
|
| 111 |
-
#
|
| 112 |
ff_output = self.feed_forward(x)
|
| 113 |
x = self.norm2(x + self.dropout(ff_output))
|
| 114 |
|
| 115 |
return x
|
| 116 |
|
|
|
|
| 117 |
class TransformerEncoder(nn.Module):
|
| 118 |
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
|
| 119 |
super().__init__()
|
|
|
|
| 120 |
self.layers = nn.ModuleList([
|
| 121 |
EncoderLayer(d_model, num_heads, d_ff, dropout)
|
| 122 |
for _ in range(num_layers)
|
|
@@ -124,49 +138,63 @@ class TransformerEncoder(nn.Module):
|
|
| 124 |
self.norm = nn.LayerNorm(d_model)
|
| 125 |
|
| 126 |
def forward(self, x, mask=None):
|
|
|
|
| 127 |
for layer in self.layers:
|
| 128 |
x = layer(x, mask)
|
| 129 |
return self.norm(x)
|
| 130 |
|
|
|
|
| 131 |
class PositionalEncoding(nn.Module):
|
| 132 |
def __init__(self, d_model, max_len=5000):
|
| 133 |
super().__init__()
|
| 134 |
self.d_model = d_model
|
|
|
|
|
|
|
| 135 |
pe = torch.zeros(max_len, d_model)
|
| 136 |
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 137 |
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 138 |
-(torch.log(torch.tensor(10000.0)) / d_model))
|
|
|
|
|
|
|
| 139 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 140 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 141 |
self.register_buffer('pe', pe.unsqueeze(0))
|
| 142 |
|
| 143 |
def forward(self, x):
|
|
|
|
| 144 |
return x * torch.sqrt(torch.tensor(self.d_model, dtype=x.dtype)) + self.pe[:, :x.size(1)]
|
| 145 |
|
|
|
|
| 146 |
class TransformerPIIDetector(nn.Module):
|
| 147 |
def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
|
| 148 |
d_ff=512, num_layers=4, dropout=0.1, max_len=512):
|
| 149 |
super().__init__()
|
| 150 |
|
|
|
|
| 151 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
|
| 152 |
-
self.positional_encoding = PositionalEncoding(d_model, max_len)
|
| 153 |
self.dropout = nn.Dropout(dropout)
|
| 154 |
-
|
| 155 |
-
# Custom encoder to match saved model structure
|
| 156 |
self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
|
| 157 |
self.classifier = nn.Linear(d_model, num_classes)
|
| 158 |
|
| 159 |
def forward(self, x):
|
|
|
|
| 160 |
padding_mask = (x == 0)
|
|
|
|
|
|
|
| 161 |
x = self.embedding(x)
|
| 162 |
x = self.positional_encoding(x)
|
| 163 |
x = self.dropout(x)
|
|
|
|
|
|
|
| 164 |
x = self.encoder(x, padding_mask)
|
| 165 |
return self.classifier(x)
|
| 166 |
|
| 167 |
def create_transformer_pii_model(**kwargs):
|
|
|
|
| 168 |
return TransformerPIIDetector(**kwargs)
|
| 169 |
|
|
|
|
| 170 |
class PIIDetector:
|
| 171 |
def __init__(self, model_dir='saved_transformer'):
|
| 172 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
@@ -176,13 +204,13 @@ class PIIDetector:
|
|
| 176 |
self.label_vocab = None
|
| 177 |
self.load_model()
|
| 178 |
|
| 179 |
-
#
|
| 180 |
self.highlight_color = '#FF6B6B'
|
| 181 |
|
| 182 |
def load_model(self):
|
| 183 |
"""Load the trained model and vocabularies"""
|
| 184 |
try:
|
| 185 |
-
# Load vocabularies
|
| 186 |
vocab_path = os.path.join(self.model_dir, 'vocabularies.pkl')
|
| 187 |
with open(vocab_path, 'rb') as f:
|
| 188 |
vocabs = pickle.load(f)
|
|
@@ -194,7 +222,7 @@ class PIIDetector:
|
|
| 194 |
with open(config_path, 'rb') as f:
|
| 195 |
model_config = pickle.load(f)
|
| 196 |
|
| 197 |
-
#
|
| 198 |
self.model = create_transformer_pii_model(**model_config)
|
| 199 |
model_path = os.path.join(self.model_dir, 'pii_transformer_model.pt')
|
| 200 |
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
|
@@ -211,7 +239,7 @@ class PIIDetector:
|
|
| 211 |
def tokenize(self, text: str) -> List[str]:
|
| 212 |
"""Simple tokenization by splitting on spaces and punctuation"""
|
| 213 |
import re
|
| 214 |
-
# Split
|
| 215 |
tokens = re.findall(r'\w+|[^\w\s]', text)
|
| 216 |
return tokens
|
| 217 |
|
|
@@ -220,30 +248,30 @@ class PIIDetector:
|
|
| 220 |
if not text.strip():
|
| 221 |
return []
|
| 222 |
|
| 223 |
-
# Tokenize
|
| 224 |
tokens = self.tokenize(text)
|
| 225 |
|
| 226 |
-
# Add
|
| 227 |
tokens_with_special = ['<start>'] + tokens + ['<end>']
|
| 228 |
|
| 229 |
-
#
|
| 230 |
token_ids = self.text_vocab.encode(tokens_with_special)
|
| 231 |
|
| 232 |
-
#
|
| 233 |
input_tensor = torch.tensor([token_ids]).to(self.device)
|
| 234 |
|
| 235 |
-
#
|
| 236 |
with torch.no_grad():
|
| 237 |
outputs = self.model(input_tensor)
|
| 238 |
predictions = torch.argmax(outputs, dim=-1)
|
| 239 |
|
| 240 |
-
#
|
| 241 |
predicted_labels = []
|
| 242 |
-
for idx in predictions[0][1:-1]: # Skip
|
| 243 |
label = self.label_vocab.idx2word.get(idx.item(), 'O')
|
| 244 |
predicted_labels.append(label.upper())
|
| 245 |
|
| 246 |
-
#
|
| 247 |
return list(zip(tokens, predicted_labels))
|
| 248 |
|
| 249 |
def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
|
|
@@ -254,14 +282,14 @@ class PIIDetector:
|
|
| 254 |
while i < len(token_label_pairs):
|
| 255 |
token, label = token_label_pairs[i]
|
| 256 |
|
| 257 |
-
# Check if
|
| 258 |
if label != 'O':
|
| 259 |
# Collect all tokens for this entity
|
| 260 |
entity_tokens = [token]
|
| 261 |
entity_label = label
|
| 262 |
j = i + 1
|
| 263 |
|
| 264 |
-
#
|
| 265 |
while j < len(token_label_pairs):
|
| 266 |
next_token, next_label = token_label_pairs[j]
|
| 267 |
if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
|
|
@@ -270,14 +298,14 @@ class PIIDetector:
|
|
| 270 |
else:
|
| 271 |
break
|
| 272 |
|
| 273 |
-
# Join tokens with
|
| 274 |
entity_text = ''
|
| 275 |
for k, tok in enumerate(entity_tokens):
|
| 276 |
if k > 0 and tok not in '.,!?;:':
|
| 277 |
entity_text += ' '
|
| 278 |
entity_text += tok
|
| 279 |
|
| 280 |
-
#
|
| 281 |
label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
|
| 282 |
html_parts.append(
|
| 283 |
f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
|
|
@@ -287,7 +315,7 @@ class PIIDetector:
|
|
| 287 |
|
| 288 |
i = j
|
| 289 |
else:
|
| 290 |
-
# Add
|
| 291 |
if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i-1:
|
| 292 |
prev_token, _ = token_label_pairs[i-1]
|
| 293 |
if prev_token not in '(':
|
|
@@ -306,14 +334,14 @@ class PIIDetector:
|
|
| 306 |
total_tokens = len(token_label_pairs)
|
| 307 |
pii_tokens = 0
|
| 308 |
|
|
|
|
| 309 |
for _, label in token_label_pairs:
|
| 310 |
if label != 'O':
|
| 311 |
pii_tokens += 1
|
| 312 |
-
# Clean up label for display
|
| 313 |
label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
|
| 314 |
stats[label_clean] = stats.get(label_clean, 0) + 1
|
| 315 |
|
| 316 |
-
#
|
| 317 |
stats_text = f"### Detection Summary\n\n"
|
| 318 |
stats_text += f"**Total tokens:** {total_tokens}\n\n"
|
| 319 |
stats_text += f"**PII tokens:** {pii_tokens} ({pii_tokens/total_tokens*100:.1f}%)\n\n"
|
|
@@ -323,7 +351,7 @@ class PIIDetector:
|
|
| 323 |
|
| 324 |
return stats_text
|
| 325 |
|
| 326 |
-
# Initialize the detector
|
| 327 |
print("Initializing PII Detector...")
|
| 328 |
detector = PIIDetector()
|
| 329 |
|
|
@@ -333,13 +361,13 @@ def detect_pii(text):
|
|
| 333 |
return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
|
| 334 |
|
| 335 |
try:
|
| 336 |
-
#
|
| 337 |
token_label_pairs = detector.predict(text)
|
| 338 |
|
| 339 |
-
#
|
| 340 |
highlighted_html = detector.create_highlighted_html(token_label_pairs)
|
| 341 |
|
| 342 |
-
#
|
| 343 |
stats = detector.get_statistics(token_label_pairs)
|
| 344 |
|
| 345 |
return highlighted_html, stats
|
|
@@ -349,18 +377,7 @@ def detect_pii(text):
|
|
| 349 |
error_stats = f"Error occurred: {str(e)}"
|
| 350 |
return error_html, error_stats
|
| 351 |
|
| 352 |
-
#
|
| 353 |
-
examples = [
|
| 354 |
-
"My name is John Smith and my email is john.smith@email.com. You can reach me at 555-123-4567.",
|
| 355 |
-
"Student ID: 12345678. Please send the documents to 123 Main Street, Anytown, USA 12345.",
|
| 356 |
-
"Contact Sarah Johnson at sarah_j_2023@gmail.com or visit her profile at linkedin.com/in/sarahjohnson",
|
| 357 |
-
"The project was completed by student A1B2C3D4 who lives at 456 Oak Avenue.",
|
| 358 |
-
"For verification, my phone number is (555) 987-6543 and my username is cool_user_99.",
|
| 359 |
-
"Hi, I'm Emily Chen. My student number is STU-2023-98765 and I live at 789 Pine Street, Apt 4B.",
|
| 360 |
-
"You can reach me at my personal website: www.johndoe.com or call me at +1-555-0123.",
|
| 361 |
-
]
|
| 362 |
-
|
| 363 |
-
# Create Gradio interface
|
| 364 |
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
| 365 |
gr.Markdown(
|
| 366 |
"""
|
|
@@ -371,6 +388,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
|
| 371 |
)
|
| 372 |
|
| 373 |
with gr.Column():
|
|
|
|
| 374 |
input_text = gr.Textbox(
|
| 375 |
label="Input Text",
|
| 376 |
placeholder="Enter text to analyze for PII...",
|
|
@@ -378,10 +396,12 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
|
| 378 |
max_lines=20
|
| 379 |
)
|
| 380 |
|
|
|
|
| 381 |
with gr.Row():
|
| 382 |
analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
|
| 383 |
clear_btn = gr.Button("🗑️ Clear", scale=1)
|
| 384 |
|
|
|
|
| 385 |
highlighted_output = gr.HTML(
|
| 386 |
label="Highlighted Text",
|
| 387 |
value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
|
|
@@ -392,7 +412,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
|
| 392 |
value="*Statistics will appear here...*"
|
| 393 |
)
|
| 394 |
|
| 395 |
-
#
|
| 396 |
analyze_btn.click(
|
| 397 |
fn=detect_pii,
|
| 398 |
inputs=[input_text],
|
|
@@ -404,7 +424,7 @@ with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
|
| 404 |
outputs=[input_text, highlighted_output, stats_output]
|
| 405 |
)
|
| 406 |
|
| 407 |
-
# Launch the
|
| 408 |
if __name__ == "__main__":
|
| 409 |
print("\nLaunching Gradio interface...")
|
| 410 |
demo.launch()
|
|
|
|
| 10 |
import warnings
|
| 11 |
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
+
# Vocabulary class for handling text encoding and decoding
|
| 14 |
class Vocabulary:
|
| 15 |
"""Vocabulary class for encoding/decoding text and labels"""
|
| 16 |
def __init__(self, max_size=100000):
|
| 17 |
+
# Initialize special tokens
|
| 18 |
self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
|
| 19 |
self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}
|
| 20 |
self.word_count = Counter()
|
| 21 |
self.max_size = max_size
|
| 22 |
|
| 23 |
def add_sentence(self, sentence):
|
| 24 |
+
# Count word frequencies in the sentence
|
| 25 |
for word in sentence:
|
| 26 |
self.word_count[word.lower()] += 1
|
| 27 |
|
| 28 |
def build(self):
|
| 29 |
+
# Build vocabulary from most common words
|
| 30 |
most_common = self.word_count.most_common(self.max_size - len(self.word2idx))
|
| 31 |
for word, _ in most_common:
|
| 32 |
if word not in self.word2idx:
|
|
|
|
| 38 |
return len(self.word2idx)
|
| 39 |
|
| 40 |
def encode(self, sentence):
|
| 41 |
+
# Convert words to indices
|
| 42 |
return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]
|
| 43 |
|
| 44 |
def decode(self, indices):
|
| 45 |
+
# Convert indices back to words
|
| 46 |
return [self.idx2word.get(idx, '<unk>') for idx in indices]
|
| 47 |
|
| 48 |
+
# Multi-head attention mechanism for the transformer
|
| 49 |
class MultiHeadAttention(nn.Module):
|
| 50 |
def __init__(self, d_model, num_heads, dropout=0.1):
|
| 51 |
super().__init__()
|
|
|
|
| 54 |
self.num_heads = num_heads
|
| 55 |
self.d_k = d_model // num_heads
|
| 56 |
|
| 57 |
+
# Linear layers for query, key, value, and output
|
| 58 |
self.w_q = nn.Linear(d_model, d_model)
|
| 59 |
self.w_k = nn.Linear(d_model, d_model)
|
| 60 |
self.w_v = nn.Linear(d_model, d_model)
|
|
|
|
| 65 |
def forward(self, query, key, value, mask=None):
|
| 66 |
batch_size = query.size(0)
|
| 67 |
|
| 68 |
+
# Transform and reshape for multi-head attention
|
| 69 |
Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 70 |
K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 71 |
V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
| 72 |
|
| 73 |
+
# Calculate attention scores
|
| 74 |
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 75 |
|
| 76 |
+
# Apply mask if provided
|
| 77 |
if mask is not None:
|
| 78 |
mask = mask.unsqueeze(1).unsqueeze(1)
|
| 79 |
scores = scores.masked_fill(mask, -1e9)
|
| 80 |
|
| 81 |
+
# Apply softmax and dropout
|
| 82 |
attention = F.softmax(scores, dim=-1)
|
| 83 |
attention = self.dropout(attention)
|
| 84 |
|
| 85 |
+
# Apply attention to values
|
| 86 |
context = torch.matmul(attention, V)
|
| 87 |
|
| 88 |
+
# Reshape back to original dimensions
|
| 89 |
context = context.transpose(1, 2).contiguous().view(
|
| 90 |
batch_size, -1, self.d_model
|
| 91 |
)
|
|
|
|
| 93 |
output = self.w_o(context)
|
| 94 |
return output
|
| 95 |
|
| 96 |
+
# Feed-forward network component
|
| 97 |
class FeedForward(nn.Module):
|
| 98 |
def __init__(self, d_model, d_ff, dropout=0.1):
|
| 99 |
super().__init__()
|
|
|
|
| 102 |
self.dropout = nn.Dropout(dropout)
|
| 103 |
|
| 104 |
def forward(self, x):
|
| 105 |
+
# Two linear layers with GELU activation
|
| 106 |
return self.w_2(self.dropout(F.gelu(self.w_1(x))))
|
| 107 |
|
| 108 |
+
# Single encoder layer combining attention and feed-forward
|
| 109 |
class EncoderLayer(nn.Module):
|
| 110 |
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
|
| 111 |
super().__init__()
|
|
|
|
| 116 |
self.dropout = nn.Dropout(dropout)
|
| 117 |
|
| 118 |
def forward(self, x, mask=None):
|
| 119 |
+
# Apply self-attention with residual connection
|
| 120 |
attn_output = self.self_attention(x, x, x, mask)
|
| 121 |
x = self.norm1(x + self.dropout(attn_output))
|
| 122 |
|
| 123 |
+
# Apply feed-forward with residual connection
|
| 124 |
ff_output = self.feed_forward(x)
|
| 125 |
x = self.norm2(x + self.dropout(ff_output))
|
| 126 |
|
| 127 |
return x
|
| 128 |
|
| 129 |
+
# Stack of encoder layers
|
| 130 |
class TransformerEncoder(nn.Module):
|
| 131 |
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
|
| 132 |
super().__init__()
|
| 133 |
+
# Create multiple encoder layers
|
| 134 |
self.layers = nn.ModuleList([
|
| 135 |
EncoderLayer(d_model, num_heads, d_ff, dropout)
|
| 136 |
for _ in range(num_layers)
|
|
|
|
| 138 |
self.norm = nn.LayerNorm(d_model)
|
| 139 |
|
| 140 |
def forward(self, x, mask=None):
|
| 141 |
+
# Pass through each encoder layer
|
| 142 |
for layer in self.layers:
|
| 143 |
x = layer(x, mask)
|
| 144 |
return self.norm(x)
|
| 145 |
|
| 146 |
+
# Positional encoding to add position information to embeddings
|
| 147 |
class PositionalEncoding(nn.Module):
|
| 148 |
def __init__(self, d_model, max_len=5000):
|
| 149 |
super().__init__()
|
| 150 |
self.d_model = d_model
|
| 151 |
+
|
| 152 |
+
# Create positional encoding matrix
|
| 153 |
pe = torch.zeros(max_len, d_model)
|
| 154 |
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 155 |
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 156 |
-(torch.log(torch.tensor(10000.0)) / d_model))
|
| 157 |
+
|
| 158 |
+
# Apply sine and cosine functions
|
| 159 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 160 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 161 |
self.register_buffer('pe', pe.unsqueeze(0))
|
| 162 |
|
| 163 |
def forward(self, x):
|
| 164 |
+
# Scale embeddings and add positional encoding
|
| 165 |
return x * torch.sqrt(torch.tensor(self.d_model, dtype=x.dtype)) + self.pe[:, :x.size(1)]
|
| 166 |
|
| 167 |
+
# Main transformer model for PII detection
|
| 168 |
class TransformerPIIDetector(nn.Module):
|
| 169 |
def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
|
| 170 |
d_ff=512, num_layers=4, dropout=0.1, max_len=512):
|
| 171 |
super().__init__()
|
| 172 |
|
| 173 |
+
# Model components
|
| 174 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
|
| 175 |
+
self.positional_encoding = PositionalEncoding(d_model, max_len)
|
| 176 |
self.dropout = nn.Dropout(dropout)
|
|
|
|
|
|
|
| 177 |
self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
|
| 178 |
self.classifier = nn.Linear(d_model, num_classes)
|
| 179 |
|
| 180 |
def forward(self, x):
|
| 181 |
+
# Create padding mask
|
| 182 |
padding_mask = (x == 0)
|
| 183 |
+
|
| 184 |
+
# Pass through embedding and positional encoding
|
| 185 |
x = self.embedding(x)
|
| 186 |
x = self.positional_encoding(x)
|
| 187 |
x = self.dropout(x)
|
| 188 |
+
|
| 189 |
+
# Encode and classify
|
| 190 |
x = self.encoder(x, padding_mask)
|
| 191 |
return self.classifier(x)
|
| 192 |
|
| 193 |
def create_transformer_pii_model(**kwargs):
|
| 194 |
+
# Factory function to create the model
|
| 195 |
return TransformerPIIDetector(**kwargs)
|
| 196 |
|
| 197 |
+
# Main PII detection class
|
| 198 |
class PIIDetector:
|
| 199 |
def __init__(self, model_dir='saved_transformer'):
|
| 200 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 204 |
self.label_vocab = None
|
| 205 |
self.load_model()
|
| 206 |
|
| 207 |
+
# Color for highlighting PII entities
|
| 208 |
self.highlight_color = '#FF6B6B'
|
| 209 |
|
| 210 |
def load_model(self):
|
| 211 |
"""Load the trained model and vocabularies"""
|
| 212 |
try:
|
| 213 |
+
# Load saved vocabularies
|
| 214 |
vocab_path = os.path.join(self.model_dir, 'vocabularies.pkl')
|
| 215 |
with open(vocab_path, 'rb') as f:
|
| 216 |
vocabs = pickle.load(f)
|
|
|
|
| 222 |
with open(config_path, 'rb') as f:
|
| 223 |
model_config = pickle.load(f)
|
| 224 |
|
| 225 |
+
# Initialize and load model weights
|
| 226 |
self.model = create_transformer_pii_model(**model_config)
|
| 227 |
model_path = os.path.join(self.model_dir, 'pii_transformer_model.pt')
|
| 228 |
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
|
|
|
| 239 |
def tokenize(self, text: str) -> List[str]:
|
| 240 |
"""Simple tokenization by splitting on spaces and punctuation"""
|
| 241 |
import re
|
| 242 |
+
# Split text into words and punctuation marks
|
| 243 |
tokens = re.findall(r'\w+|[^\w\s]', text)
|
| 244 |
return tokens
|
| 245 |
|
|
|
|
| 248 |
if not text.strip():
|
| 249 |
return []
|
| 250 |
|
| 251 |
+
# Tokenize input text
|
| 252 |
tokens = self.tokenize(text)
|
| 253 |
|
| 254 |
+
# Add special tokens
|
| 255 |
tokens_with_special = ['<start>'] + tokens + ['<end>']
|
| 256 |
|
| 257 |
+
# Convert tokens to indices
|
| 258 |
token_ids = self.text_vocab.encode(tokens_with_special)
|
| 259 |
|
| 260 |
+
# Prepare tensor for model
|
| 261 |
input_tensor = torch.tensor([token_ids]).to(self.device)
|
| 262 |
|
| 263 |
+
# Get predictions
|
| 264 |
with torch.no_grad():
|
| 265 |
outputs = self.model(input_tensor)
|
| 266 |
predictions = torch.argmax(outputs, dim=-1)
|
| 267 |
|
| 268 |
+
# Convert predictions to labels
|
| 269 |
predicted_labels = []
|
| 270 |
+
for idx in predictions[0][1:-1]: # Skip special tokens
|
| 271 |
label = self.label_vocab.idx2word.get(idx.item(), 'O')
|
| 272 |
predicted_labels.append(label.upper())
|
| 273 |
|
| 274 |
+
# Return token-label pairs
|
| 275 |
return list(zip(tokens, predicted_labels))
|
| 276 |
|
| 277 |
def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
|
|
|
|
| 282 |
while i < len(token_label_pairs):
|
| 283 |
token, label = token_label_pairs[i]
|
| 284 |
|
| 285 |
+
# Check if token is part of PII entity
|
| 286 |
if label != 'O':
|
| 287 |
# Collect all tokens for this entity
|
| 288 |
entity_tokens = [token]
|
| 289 |
entity_label = label
|
| 290 |
j = i + 1
|
| 291 |
|
| 292 |
+
# Find continuation tokens
|
| 293 |
while j < len(token_label_pairs):
|
| 294 |
next_token, next_label = token_label_pairs[j]
|
| 295 |
if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
|
|
|
|
| 298 |
else:
|
| 299 |
break
|
| 300 |
|
| 301 |
+
# Join entity tokens with proper spacing
|
| 302 |
entity_text = ''
|
| 303 |
for k, tok in enumerate(entity_tokens):
|
| 304 |
if k > 0 and tok not in '.,!?;:':
|
| 305 |
entity_text += ' '
|
| 306 |
entity_text += tok
|
| 307 |
|
| 308 |
+
# Create highlighted HTML for entity
|
| 309 |
label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
|
| 310 |
html_parts.append(
|
| 311 |
f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
|
|
|
|
| 315 |
|
| 316 |
i = j
|
| 317 |
else:
|
| 318 |
+
# Add non-PII token with proper spacing
|
| 319 |
if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i-1:
|
| 320 |
prev_token, _ = token_label_pairs[i-1]
|
| 321 |
if prev_token not in '(':
|
|
|
|
| 334 |
total_tokens = len(token_label_pairs)
|
| 335 |
pii_tokens = 0
|
| 336 |
|
| 337 |
+
# Count PII tokens by type
|
| 338 |
for _, label in token_label_pairs:
|
| 339 |
if label != 'O':
|
| 340 |
pii_tokens += 1
|
|
|
|
| 341 |
label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
|
| 342 |
stats[label_clean] = stats.get(label_clean, 0) + 1
|
| 343 |
|
| 344 |
+
# Format statistics text
|
| 345 |
stats_text = f"### Detection Summary\n\n"
|
| 346 |
stats_text += f"**Total tokens:** {total_tokens}\n\n"
|
| 347 |
stats_text += f"**PII tokens:** {pii_tokens} ({pii_tokens/total_tokens*100:.1f}%)\n\n"
|
|
|
|
| 351 |
|
| 352 |
return stats_text
|
| 353 |
|
| 354 |
+
# Initialize the detector when the script runs
|
| 355 |
print("Initializing PII Detector...")
|
| 356 |
detector = PIIDetector()
|
| 357 |
|
|
|
|
| 361 |
return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
|
| 362 |
|
| 363 |
try:
|
| 364 |
+
# Run PII detection
|
| 365 |
token_label_pairs = detector.predict(text)
|
| 366 |
|
| 367 |
+
# Generate highlighted output
|
| 368 |
highlighted_html = detector.create_highlighted_html(token_label_pairs)
|
| 369 |
|
| 370 |
+
# Generate statistics
|
| 371 |
stats = detector.get_statistics(token_label_pairs)
|
| 372 |
|
| 373 |
return highlighted_html, stats
|
|
|
|
| 377 |
error_stats = f"Error occurred: {str(e)}"
|
| 378 |
return error_html, error_stats
|
| 379 |
|
| 380 |
+
# Create the Gradio interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
|
| 382 |
gr.Markdown(
|
| 383 |
"""
|
|
|
|
| 388 |
)
|
| 389 |
|
| 390 |
with gr.Column():
|
| 391 |
+
# Input text area
|
| 392 |
input_text = gr.Textbox(
|
| 393 |
label="Input Text",
|
| 394 |
placeholder="Enter text to analyze for PII...",
|
|
|
|
| 396 |
max_lines=20
|
| 397 |
)
|
| 398 |
|
| 399 |
+
# Control buttons
|
| 400 |
with gr.Row():
|
| 401 |
analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
|
| 402 |
clear_btn = gr.Button("🗑️ Clear", scale=1)
|
| 403 |
|
| 404 |
+
# Output areas
|
| 405 |
highlighted_output = gr.HTML(
|
| 406 |
label="Highlighted Text",
|
| 407 |
value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
|
|
|
|
| 412 |
value="*Statistics will appear here...*"
|
| 413 |
)
|
| 414 |
|
| 415 |
+
# Connect buttons to functions
|
| 416 |
analyze_btn.click(
|
| 417 |
fn=detect_pii,
|
| 418 |
inputs=[input_text],
|
|
|
|
| 424 |
outputs=[input_text, highlighted_output, stats_output]
|
| 425 |
)
|
| 426 |
|
| 427 |
+
# Launch the application
|
| 428 |
if __name__ == "__main__":
|
| 429 |
print("\nLaunching Gradio interface...")
|
| 430 |
demo.launch()
|
data_augmentation.py
CHANGED
|
@@ -15,17 +15,20 @@ class PIIDataAugmenter:
|
|
| 15 |
|
| 16 |
def __init__(self, seed=42):
|
| 17 |
"""Initialize the augmenter with random seeds for reproducibility."""
|
|
|
|
| 18 |
random.seed(seed)
|
| 19 |
np.random.seed(seed)
|
| 20 |
self.fake = Faker()
|
| 21 |
Faker.seed(seed)
|
| 22 |
|
|
|
|
| 23 |
self._init_templates()
|
| 24 |
self._init_context_phrases()
|
| 25 |
self._init_generators()
|
| 26 |
|
| 27 |
def _init_templates(self):
|
| 28 |
"""Initialize templates for different PII types."""
|
|
|
|
| 29 |
self.templates = {
|
| 30 |
'NAME_STUDENT': [
|
| 31 |
"My name is {name}",
|
|
@@ -115,6 +118,7 @@ class PIIDataAugmenter:
|
|
| 115 |
|
| 116 |
def _init_context_phrases(self):
|
| 117 |
"""Initialize context phrases for more natural text generation."""
|
|
|
|
| 118 |
self.context_prefix = [
|
| 119 |
"Hello everyone,",
|
| 120 |
"Dear Sir/Madam,",
|
|
@@ -128,6 +132,7 @@ class PIIDataAugmenter:
|
|
| 128 |
"I am writing to tell you that"
|
| 129 |
]
|
| 130 |
|
|
|
|
| 131 |
self.context_suffix = [
|
| 132 |
"Thank you.",
|
| 133 |
"Best regards.",
|
|
@@ -141,12 +146,14 @@ class PIIDataAugmenter:
|
|
| 141 |
"Let me know if you have questions."
|
| 142 |
]
|
| 143 |
|
|
|
|
| 144 |
self.connectors = [
|
| 145 |
" and ", " or ", ", ", ". Also, ", ". Additionally, "
|
| 146 |
]
|
| 147 |
|
| 148 |
def _init_generators(self):
|
| 149 |
"""Initialize PII generators mapping."""
|
|
|
|
| 150 |
self.generators = {
|
| 151 |
'NAME_STUDENT': self.generate_name,
|
| 152 |
'EMAIL': self.generate_email,
|
|
@@ -157,6 +164,7 @@ class PIIDataAugmenter:
|
|
| 157 |
'USERNAME': self.generate_username
|
| 158 |
}
|
| 159 |
|
|
|
|
| 160 |
self.format_keys = {
|
| 161 |
'NAME_STUDENT': 'name',
|
| 162 |
'EMAIL': 'email',
|
|
@@ -167,8 +175,6 @@ class PIIDataAugmenter:
|
|
| 167 |
'USERNAME': 'username'
|
| 168 |
}
|
| 169 |
|
| 170 |
-
# ========== PII Generators ==========
|
| 171 |
-
|
| 172 |
def generate_name(self):
|
| 173 |
"""Generate realistic person names."""
|
| 174 |
return self.fake.name()
|
|
@@ -179,6 +185,7 @@ class PIIDataAugmenter:
|
|
| 179 |
|
| 180 |
def generate_phone(self):
|
| 181 |
"""Generate realistic phone numbers in various formats."""
|
|
|
|
| 182 |
formats = [
|
| 183 |
"555-{:03d}-{:04d}",
|
| 184 |
"(555) {:03d}-{:04d}",
|
|
@@ -186,6 +193,7 @@ class PIIDataAugmenter:
|
|
| 186 |
"+1-555-{:03d}-{:04d}",
|
| 187 |
"555{:03d}{:04d}"
|
| 188 |
]
|
|
|
|
| 189 |
format_choice = random.choice(formats)
|
| 190 |
area = random.randint(100, 999)
|
| 191 |
number = random.randint(1000, 9999)
|
|
@@ -193,10 +201,12 @@ class PIIDataAugmenter:
|
|
| 193 |
|
| 194 |
def generate_address(self):
|
| 195 |
"""Generate realistic street addresses."""
|
|
|
|
| 196 |
return self.fake.address().replace('\n', ', ')
|
| 197 |
|
| 198 |
def generate_id_num(self):
|
| 199 |
"""Generate various ID number formats."""
|
|
|
|
| 200 |
formats = [
|
| 201 |
"{:06d}", # 6-digit ID
|
| 202 |
"{:08d}", # 8-digit ID
|
|
@@ -207,6 +217,7 @@ class PIIDataAugmenter:
|
|
| 207 |
]
|
| 208 |
format_choice = random.choice(formats)
|
| 209 |
|
|
|
|
| 210 |
if '-' in format_choice:
|
| 211 |
return format_choice.format(
|
| 212 |
random.randint(1000, 9999),
|
|
@@ -217,6 +228,7 @@ class PIIDataAugmenter:
|
|
| 217 |
|
| 218 |
def generate_url(self):
|
| 219 |
"""Generate personal website URLs."""
|
|
|
|
| 220 |
domains = ['github.com', 'linkedin.com', 'portfolio.com',
|
| 221 |
'personal.com', 'website.com']
|
| 222 |
username = self.fake.user_name()
|
|
@@ -227,72 +239,53 @@ class PIIDataAugmenter:
|
|
| 227 |
"""Generate usernames."""
|
| 228 |
return self.fake.user_name()
|
| 229 |
|
| 230 |
-
# ========== Synthetic Example Creation ==========
|
| 231 |
-
|
| 232 |
def create_synthetic_example(self, pii_type, add_context=True):
|
| 233 |
-
"""
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
Args:
|
| 237 |
-
pii_type: Type of PII to generate
|
| 238 |
-
add_context: Whether to add context phrases
|
| 239 |
-
|
| 240 |
-
Returns:
|
| 241 |
-
Tuple of (tokens, labels)
|
| 242 |
-
"""
|
| 243 |
-
# Generate PII value
|
| 244 |
pii_value = self.generators[pii_type]()
|
| 245 |
|
| 246 |
-
#
|
| 247 |
template = random.choice(self.templates[pii_type])
|
| 248 |
format_key = self.format_keys[pii_type]
|
| 249 |
sentence = template.format(**{format_key: pii_value})
|
| 250 |
|
| 251 |
-
#
|
| 252 |
if add_context and random.random() > 0.3:
|
| 253 |
sentence = self._add_context(sentence)
|
| 254 |
|
| 255 |
-
#
|
| 256 |
tokens, labels = self._tokenize_and_label(sentence, pii_value, pii_type)
|
| 257 |
|
| 258 |
return tokens, labels
|
| 259 |
|
| 260 |
def create_mixed_example(self, pii_types, num_pii=2):
|
| 261 |
-
"""
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
Args:
|
| 265 |
-
pii_types: List of PII types to include
|
| 266 |
-
num_pii: Number of PII entities to include
|
| 267 |
-
|
| 268 |
-
Returns:
|
| 269 |
-
Tuple of (tokens, labels)
|
| 270 |
-
"""
|
| 271 |
selected_types = random.sample(pii_types, min(num_pii, len(pii_types)))
|
| 272 |
|
| 273 |
all_tokens = []
|
| 274 |
all_labels = []
|
| 275 |
|
| 276 |
-
# Add context
|
| 277 |
if random.random() > 0.3:
|
| 278 |
prefix = random.choice(self.context_prefix)
|
| 279 |
all_tokens.extend(prefix.split())
|
| 280 |
all_labels.extend(['O'] * len(prefix.split()))
|
| 281 |
|
| 282 |
-
# Add each PII
|
| 283 |
for i, pii_type in enumerate(selected_types):
|
| 284 |
-
# Add connector between
|
| 285 |
if i > 0 and random.random() > 0.5:
|
| 286 |
connector = random.choice(self.connectors)
|
| 287 |
all_tokens.extend(connector.strip().split())
|
| 288 |
all_labels.extend(['O'] * len(connector.strip().split()))
|
| 289 |
|
| 290 |
-
#
|
| 291 |
tokens, labels = self.create_synthetic_example(pii_type, add_context=False)
|
| 292 |
all_tokens.extend(tokens)
|
| 293 |
all_labels.extend(labels)
|
| 294 |
|
| 295 |
-
# Add context
|
| 296 |
if random.random() > 0.3:
|
| 297 |
suffix = random.choice(self.context_suffix)
|
| 298 |
all_tokens.extend(suffix.split())
|
|
@@ -302,79 +295,60 @@ class PIIDataAugmenter:
|
|
| 302 |
|
| 303 |
def _add_context(self, sentence):
|
| 304 |
"""Add context phrases to make text more natural."""
|
|
|
|
| 305 |
if random.random() > 0.5:
|
| 306 |
sentence = random.choice(self.context_prefix) + " " + sentence
|
|
|
|
| 307 |
if random.random() > 0.5:
|
| 308 |
sentence = sentence + " " + random.choice(self.context_suffix)
|
| 309 |
return sentence
|
| 310 |
|
| 311 |
def _tokenize_and_label(self, sentence, pii_value, pii_type):
|
| 312 |
-
"""
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
Args:
|
| 316 |
-
sentence: The sentence containing PII
|
| 317 |
-
pii_value: The PII value to find and label
|
| 318 |
-
pii_type: The type of PII for labeling
|
| 319 |
-
|
| 320 |
-
Returns:
|
| 321 |
-
Tuple of (tokens, labels)
|
| 322 |
-
"""
|
| 323 |
tokens = sentence.split()
|
| 324 |
labels = ['O'] * len(tokens)
|
| 325 |
|
| 326 |
-
#
|
| 327 |
pii_tokens = pii_value.split()
|
| 328 |
|
| 329 |
-
# Find
|
| 330 |
for i in range(len(tokens) - len(pii_tokens) + 1):
|
| 331 |
-
# Check if tokens match PII value
|
| 332 |
if (tokens[i:i+len(pii_tokens)] == pii_tokens or
|
| 333 |
' '.join(tokens[i:i+len(pii_tokens)]).lower() == pii_value.lower()):
|
| 334 |
|
| 335 |
-
# Apply BIO
|
| 336 |
-
labels[i] = f'B-{pii_type}'
|
| 337 |
for j in range(1, len(pii_tokens)):
|
| 338 |
-
labels[i+j] = f'I-{pii_type}'
|
| 339 |
break
|
| 340 |
|
| 341 |
return tokens, labels
|
| 342 |
|
| 343 |
-
# ========== Dataset Augmentation ==========
|
| 344 |
-
|
| 345 |
def augment_dataset(self, original_data, target_samples_per_class=1000, mix_ratio=0.3):
|
| 346 |
-
"""
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
Args:
|
| 350 |
-
original_data: Original dataset DataFrame
|
| 351 |
-
target_samples_per_class: Target number of samples per PII class
|
| 352 |
-
mix_ratio: Ratio of mixed (multi-PII) examples
|
| 353 |
-
|
| 354 |
-
Returns:
|
| 355 |
-
Augmented dataset DataFrame
|
| 356 |
-
"""
|
| 357 |
-
# Analyze original distribution
|
| 358 |
label_counts = self._analyze_label_distribution(original_data)
|
| 359 |
print("\nOriginal label distribution:")
|
| 360 |
self._print_distribution(label_counts)
|
| 361 |
|
| 362 |
-
# Generate synthetic
|
| 363 |
synthetic_tokens, synthetic_labels = self._generate_synthetic_data(
|
| 364 |
label_counts, target_samples_per_class, mix_ratio
|
| 365 |
)
|
| 366 |
|
| 367 |
-
# Add non-PII examples
|
| 368 |
synthetic_tokens, synthetic_labels = self._add_non_pii_examples(
|
| 369 |
synthetic_tokens, synthetic_labels
|
| 370 |
)
|
| 371 |
|
| 372 |
-
# Combine and
|
| 373 |
augmented_df = self._combine_and_shuffle(
|
| 374 |
original_data, synthetic_tokens, synthetic_labels
|
| 375 |
)
|
| 376 |
|
| 377 |
-
#
|
| 378 |
new_label_counts = self._analyze_label_distribution(augmented_df)
|
| 379 |
print("\nAugmented label distribution:")
|
| 380 |
self._print_distribution(new_label_counts)
|
|
@@ -385,10 +359,11 @@ class PIIDataAugmenter:
|
|
| 385 |
"""Analyze the distribution of PII labels in the dataset."""
|
| 386 |
label_counts = Counter()
|
| 387 |
|
|
|
|
| 388 |
for labels in data['labels']:
|
| 389 |
for label in labels:
|
| 390 |
if label != 'O':
|
| 391 |
-
#
|
| 392 |
base_label = label.split('-')[1] if '-' in label else label
|
| 393 |
label_counts[base_label] += 1
|
| 394 |
|
|
@@ -397,6 +372,7 @@ class PIIDataAugmenter:
|
|
| 397 |
def _print_distribution(self, label_counts):
|
| 398 |
"""Print label distribution statistics."""
|
| 399 |
total = sum(label_counts.values())
|
|
|
|
| 400 |
for label, count in label_counts.most_common():
|
| 401 |
percentage = (count / total * 100) if total > 0 else 0
|
| 402 |
print(f" {label:15} : {count:6,} ({percentage:5.2f}%)")
|
|
@@ -406,6 +382,7 @@ class PIIDataAugmenter:
|
|
| 406 |
synthetic_tokens = []
|
| 407 |
synthetic_labels = []
|
| 408 |
|
|
|
|
| 409 |
for pii_type in self.templates.keys():
|
| 410 |
current_count = label_counts.get(pii_type, 0)
|
| 411 |
needed = max(0, target_samples - current_count)
|
|
@@ -415,17 +392,17 @@ class PIIDataAugmenter:
|
|
| 415 |
|
| 416 |
print(f"\nGenerating {needed} synthetic examples for {pii_type}")
|
| 417 |
|
| 418 |
-
#
|
| 419 |
single_count = int(needed * (1 - mix_ratio))
|
| 420 |
for _ in range(single_count):
|
| 421 |
tokens, labels = self.create_synthetic_example(pii_type)
|
| 422 |
synthetic_tokens.append(tokens)
|
| 423 |
synthetic_labels.append(labels)
|
| 424 |
|
| 425 |
-
#
|
| 426 |
mixed_count = int(needed * mix_ratio)
|
| 427 |
for _ in range(mixed_count):
|
| 428 |
-
#
|
| 429 |
other_types = [t for t in self.templates.keys() if t != pii_type]
|
| 430 |
selected_types = [pii_type] + random.sample(
|
| 431 |
other_types, min(1, len(other_types))
|
|
@@ -439,6 +416,7 @@ class PIIDataAugmenter:
|
|
| 439 |
|
| 440 |
def _add_non_pii_examples(self, synthetic_tokens, synthetic_labels):
|
| 441 |
"""Add examples without PII (all 'O' labels) for balance."""
|
|
|
|
| 442 |
num_non_pii = int(len(synthetic_tokens) * 0.1)
|
| 443 |
|
| 444 |
for _ in range(num_non_pii):
|
|
@@ -454,17 +432,17 @@ class PIIDataAugmenter:
|
|
| 454 |
|
| 455 |
def _combine_and_shuffle(self, original_data, synthetic_tokens, synthetic_labels):
|
| 456 |
"""Combine original and synthetic data, then shuffle."""
|
| 457 |
-
#
|
| 458 |
all_tokens = original_data['tokens'].tolist() + synthetic_tokens
|
| 459 |
all_labels = original_data['labels'].tolist() + synthetic_labels
|
| 460 |
|
| 461 |
-
# Create
|
| 462 |
augmented_data = pd.DataFrame({
|
| 463 |
'tokens': all_tokens,
|
| 464 |
'labels': all_labels
|
| 465 |
})
|
| 466 |
|
| 467 |
-
# Shuffle
|
| 468 |
augmented_data = augmented_data.sample(frac=1, random_state=42).reset_index(drop=True)
|
| 469 |
|
| 470 |
print(f"\nTotal augmented samples: {len(augmented_data):,}")
|
|
@@ -472,17 +450,8 @@ class PIIDataAugmenter:
|
|
| 472 |
return augmented_data
|
| 473 |
|
| 474 |
def calculate_class_weights(data, label_vocab):
|
| 475 |
-
"""
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
Args:
|
| 479 |
-
data: Dataset DataFrame with 'labels' column
|
| 480 |
-
label_vocab: Vocabulary object with word2idx mapping
|
| 481 |
-
|
| 482 |
-
Returns:
|
| 483 |
-
Tensor of class weights
|
| 484 |
-
"""
|
| 485 |
-
# Count label occurrences
|
| 486 |
label_counts = Counter()
|
| 487 |
|
| 488 |
for labels in data['labels']:
|
|
@@ -490,7 +459,7 @@ def calculate_class_weights(data, label_vocab):
|
|
| 490 |
label_id = label_vocab.word2idx.get(label.lower(), 0)
|
| 491 |
label_counts[label_id] += 1
|
| 492 |
|
| 493 |
-
# Calculate inverse frequency
|
| 494 |
total_samples = sum(label_counts.values())
|
| 495 |
num_classes = len(label_vocab)
|
| 496 |
|
|
@@ -498,31 +467,31 @@ def calculate_class_weights(data, label_vocab):
|
|
| 498 |
|
| 499 |
for class_id, count in label_counts.items():
|
| 500 |
if count > 0:
|
| 501 |
-
# Inverse frequency
|
| 502 |
weights[class_id] = total_samples / (num_classes * count)
|
| 503 |
|
| 504 |
-
# Normalize weights
|
| 505 |
weights = weights / weights.sum() * num_classes
|
| 506 |
|
| 507 |
-
#
|
| 508 |
weights = torch.clamp(weights, min=0.1, max=10.0)
|
| 509 |
|
| 510 |
-
#
|
| 511 |
weights[0] = 0.0
|
| 512 |
|
| 513 |
return weights
|
| 514 |
|
| 515 |
if __name__ == '__main__':
|
| 516 |
"""Example usage of the augmentation module."""
|
| 517 |
-
# Load original data
|
| 518 |
print("Loading original training data...")
|
| 519 |
original_data = pd.read_json('train.json')
|
| 520 |
print(f"Original dataset size: {len(original_data):,}")
|
| 521 |
|
| 522 |
-
#
|
| 523 |
augmenter = PIIDataAugmenter(seed=42)
|
| 524 |
|
| 525 |
-
#
|
| 526 |
print("\n" + "="*60)
|
| 527 |
print("Starting data augmentation...")
|
| 528 |
print("="*60)
|
|
@@ -533,7 +502,7 @@ if __name__ == '__main__':
|
|
| 533 |
mix_ratio=0.3
|
| 534 |
)
|
| 535 |
|
| 536 |
-
# Save augmented
|
| 537 |
output_path = './train_augmented.json'
|
| 538 |
augmented_data.to_json(output_path, orient='records', lines=True)
|
| 539 |
print(f"\nSaved augmented data to {output_path}")
|
|
|
|
| 15 |
|
| 16 |
def __init__(self, seed=42):
|
| 17 |
"""Initialize the augmenter with random seeds for reproducibility."""
|
| 18 |
+
# Set random seeds for consistent results
|
| 19 |
random.seed(seed)
|
| 20 |
np.random.seed(seed)
|
| 21 |
self.fake = Faker()
|
| 22 |
Faker.seed(seed)
|
| 23 |
|
| 24 |
+
# Initialize data structures
|
| 25 |
self._init_templates()
|
| 26 |
self._init_context_phrases()
|
| 27 |
self._init_generators()
|
| 28 |
|
| 29 |
def _init_templates(self):
|
| 30 |
"""Initialize templates for different PII types."""
|
| 31 |
+
# Templates for generating sentences with PII
|
| 32 |
self.templates = {
|
| 33 |
'NAME_STUDENT': [
|
| 34 |
"My name is {name}",
|
|
|
|
| 118 |
|
| 119 |
def _init_context_phrases(self):
|
| 120 |
"""Initialize context phrases for more natural text generation."""
|
| 121 |
+
# Opening phrases for generated text
|
| 122 |
self.context_prefix = [
|
| 123 |
"Hello everyone,",
|
| 124 |
"Dear Sir/Madam,",
|
|
|
|
| 132 |
"I am writing to tell you that"
|
| 133 |
]
|
| 134 |
|
| 135 |
+
# Closing phrases for generated text
|
| 136 |
self.context_suffix = [
|
| 137 |
"Thank you.",
|
| 138 |
"Best regards.",
|
|
|
|
| 146 |
"Let me know if you have questions."
|
| 147 |
]
|
| 148 |
|
| 149 |
+
# Words to connect multiple PII elements
|
| 150 |
self.connectors = [
|
| 151 |
" and ", " or ", ", ", ". Also, ", ". Additionally, "
|
| 152 |
]
|
| 153 |
|
| 154 |
def _init_generators(self):
|
| 155 |
"""Initialize PII generators mapping."""
|
| 156 |
+
# Map PII types to their generator functions
|
| 157 |
self.generators = {
|
| 158 |
'NAME_STUDENT': self.generate_name,
|
| 159 |
'EMAIL': self.generate_email,
|
|
|
|
| 164 |
'USERNAME': self.generate_username
|
| 165 |
}
|
| 166 |
|
| 167 |
+
# Map PII types to template placeholder keys
|
| 168 |
self.format_keys = {
|
| 169 |
'NAME_STUDENT': 'name',
|
| 170 |
'EMAIL': 'email',
|
|
|
|
| 175 |
'USERNAME': 'username'
|
| 176 |
}
|
| 177 |
|
|
|
|
|
|
|
| 178 |
def generate_name(self):
|
| 179 |
"""Generate realistic person names."""
|
| 180 |
return self.fake.name()
|
|
|
|
| 185 |
|
| 186 |
def generate_phone(self):
|
| 187 |
"""Generate realistic phone numbers in various formats."""
|
| 188 |
+
# Different phone number formats
|
| 189 |
formats = [
|
| 190 |
"555-{:03d}-{:04d}",
|
| 191 |
"(555) {:03d}-{:04d}",
|
|
|
|
| 193 |
"+1-555-{:03d}-{:04d}",
|
| 194 |
"555{:03d}{:04d}"
|
| 195 |
]
|
| 196 |
+
# Pick a random format and fill with random numbers
|
| 197 |
format_choice = random.choice(formats)
|
| 198 |
area = random.randint(100, 999)
|
| 199 |
number = random.randint(1000, 9999)
|
|
|
|
| 201 |
|
| 202 |
def generate_address(self):
|
| 203 |
"""Generate realistic street addresses."""
|
| 204 |
+
# Get address and replace newlines with commas
|
| 205 |
return self.fake.address().replace('\n', ', ')
|
| 206 |
|
| 207 |
def generate_id_num(self):
|
| 208 |
"""Generate various ID number formats."""
|
| 209 |
+
# Different ID number patterns
|
| 210 |
formats = [
|
| 211 |
"{:06d}", # 6-digit ID
|
| 212 |
"{:08d}", # 8-digit ID
|
|
|
|
| 217 |
]
|
| 218 |
format_choice = random.choice(formats)
|
| 219 |
|
| 220 |
+
# Handle hyphenated format differently
|
| 221 |
if '-' in format_choice:
|
| 222 |
return format_choice.format(
|
| 223 |
random.randint(1000, 9999),
|
|
|
|
| 228 |
|
| 229 |
def generate_url(self):
|
| 230 |
"""Generate personal website URLs."""
|
| 231 |
+
# Common personal website domains
|
| 232 |
domains = ['github.com', 'linkedin.com', 'portfolio.com',
|
| 233 |
'personal.com', 'website.com']
|
| 234 |
username = self.fake.user_name()
|
|
|
|
| 239 |
"""Generate usernames."""
|
| 240 |
return self.fake.user_name()
|
| 241 |
|
|
|
|
|
|
|
| 242 |
def create_synthetic_example(self, pii_type, add_context=True):
|
| 243 |
+
"""Create a synthetic example with proper BIO labeling."""
|
| 244 |
+
# Generate the PII value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
pii_value = self.generators[pii_type]()
|
| 246 |
|
| 247 |
+
# Choose a template and insert the PII
|
| 248 |
template = random.choice(self.templates[pii_type])
|
| 249 |
format_key = self.format_keys[pii_type]
|
| 250 |
sentence = template.format(**{format_key: pii_value})
|
| 251 |
|
| 252 |
+
# Optionally add context for more natural text
|
| 253 |
if add_context and random.random() > 0.3:
|
| 254 |
sentence = self._add_context(sentence)
|
| 255 |
|
| 256 |
+
# Create tokens and labels
|
| 257 |
tokens, labels = self._tokenize_and_label(sentence, pii_value, pii_type)
|
| 258 |
|
| 259 |
return tokens, labels
|
| 260 |
|
| 261 |
def create_mixed_example(self, pii_types, num_pii=2):
|
| 262 |
+
"""Create examples with multiple PII types."""
|
| 263 |
+
# Select which PII types to include
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
selected_types = random.sample(pii_types, min(num_pii, len(pii_types)))
|
| 265 |
|
| 266 |
all_tokens = []
|
| 267 |
all_labels = []
|
| 268 |
|
| 269 |
+
# Add opening context
|
| 270 |
if random.random() > 0.3:
|
| 271 |
prefix = random.choice(self.context_prefix)
|
| 272 |
all_tokens.extend(prefix.split())
|
| 273 |
all_labels.extend(['O'] * len(prefix.split()))
|
| 274 |
|
| 275 |
+
# Add each PII entity
|
| 276 |
for i, pii_type in enumerate(selected_types):
|
| 277 |
+
# Add connector between PII entities
|
| 278 |
if i > 0 and random.random() > 0.5:
|
| 279 |
connector = random.choice(self.connectors)
|
| 280 |
all_tokens.extend(connector.strip().split())
|
| 281 |
all_labels.extend(['O'] * len(connector.strip().split()))
|
| 282 |
|
| 283 |
+
# Generate PII example
|
| 284 |
tokens, labels = self.create_synthetic_example(pii_type, add_context=False)
|
| 285 |
all_tokens.extend(tokens)
|
| 286 |
all_labels.extend(labels)
|
| 287 |
|
| 288 |
+
# Add closing context
|
| 289 |
if random.random() > 0.3:
|
| 290 |
suffix = random.choice(self.context_suffix)
|
| 291 |
all_tokens.extend(suffix.split())
|
|
|
|
| 295 |
|
| 296 |
def _add_context(self, sentence):
|
| 297 |
"""Add context phrases to make text more natural."""
|
| 298 |
+
# Randomly add prefix
|
| 299 |
if random.random() > 0.5:
|
| 300 |
sentence = random.choice(self.context_prefix) + " " + sentence
|
| 301 |
+
# Randomly add suffix
|
| 302 |
if random.random() > 0.5:
|
| 303 |
sentence = sentence + " " + random.choice(self.context_suffix)
|
| 304 |
return sentence
|
| 305 |
|
| 306 |
def _tokenize_and_label(self, sentence, pii_value, pii_type):
|
| 307 |
+
"""Tokenize sentence and apply BIO labels for PII."""
|
| 308 |
+
# Split sentence into tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
tokens = sentence.split()
|
| 310 |
labels = ['O'] * len(tokens)
|
| 311 |
|
| 312 |
+
# Split PII value into tokens
|
| 313 |
pii_tokens = pii_value.split()
|
| 314 |
|
| 315 |
+
# Find where PII appears in the sentence
|
| 316 |
for i in range(len(tokens) - len(pii_tokens) + 1):
|
| 317 |
+
# Check if tokens match the PII value
|
| 318 |
if (tokens[i:i+len(pii_tokens)] == pii_tokens or
|
| 319 |
' '.join(tokens[i:i+len(pii_tokens)]).lower() == pii_value.lower()):
|
| 320 |
|
| 321 |
+
# Apply BIO tagging
|
| 322 |
+
labels[i] = f'B-{pii_type}' # Beginning
|
| 323 |
for j in range(1, len(pii_tokens)):
|
| 324 |
+
labels[i+j] = f'I-{pii_type}' # Inside
|
| 325 |
break
|
| 326 |
|
| 327 |
return tokens, labels
|
| 328 |
|
|
|
|
|
|
|
| 329 |
def augment_dataset(self, original_data, target_samples_per_class=1000, mix_ratio=0.3):
|
| 330 |
+
"""Augment dataset with synthetic examples to balance PII classes."""
|
| 331 |
+
# Check current distribution
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
label_counts = self._analyze_label_distribution(original_data)
|
| 333 |
print("\nOriginal label distribution:")
|
| 334 |
self._print_distribution(label_counts)
|
| 335 |
|
| 336 |
+
# Generate synthetic data
|
| 337 |
synthetic_tokens, synthetic_labels = self._generate_synthetic_data(
|
| 338 |
label_counts, target_samples_per_class, mix_ratio
|
| 339 |
)
|
| 340 |
|
| 341 |
+
# Add some non-PII examples for balance
|
| 342 |
synthetic_tokens, synthetic_labels = self._add_non_pii_examples(
|
| 343 |
synthetic_tokens, synthetic_labels
|
| 344 |
)
|
| 345 |
|
| 346 |
+
# Combine original and synthetic data
|
| 347 |
augmented_df = self._combine_and_shuffle(
|
| 348 |
original_data, synthetic_tokens, synthetic_labels
|
| 349 |
)
|
| 350 |
|
| 351 |
+
# Check new distribution
|
| 352 |
new_label_counts = self._analyze_label_distribution(augmented_df)
|
| 353 |
print("\nAugmented label distribution:")
|
| 354 |
self._print_distribution(new_label_counts)
|
|
|
|
| 359 |
"""Analyze the distribution of PII labels in the dataset."""
|
| 360 |
label_counts = Counter()
|
| 361 |
|
| 362 |
+
# Count each PII type
|
| 363 |
for labels in data['labels']:
|
| 364 |
for label in labels:
|
| 365 |
if label != 'O':
|
| 366 |
+
# Remove B- or I- prefix to get base label
|
| 367 |
base_label = label.split('-')[1] if '-' in label else label
|
| 368 |
label_counts[base_label] += 1
|
| 369 |
|
|
|
|
| 372 |
def _print_distribution(self, label_counts):
|
| 373 |
"""Print label distribution statistics."""
|
| 374 |
total = sum(label_counts.values())
|
| 375 |
+
# Print each label count and percentage
|
| 376 |
for label, count in label_counts.most_common():
|
| 377 |
percentage = (count / total * 100) if total > 0 else 0
|
| 378 |
print(f" {label:15} : {count:6,} ({percentage:5.2f}%)")
|
|
|
|
| 382 |
synthetic_tokens = []
|
| 383 |
synthetic_labels = []
|
| 384 |
|
| 385 |
+
# Generate examples for each PII type
|
| 386 |
for pii_type in self.templates.keys():
|
| 387 |
current_count = label_counts.get(pii_type, 0)
|
| 388 |
needed = max(0, target_samples - current_count)
|
|
|
|
| 392 |
|
| 393 |
print(f"\nGenerating {needed} synthetic examples for {pii_type}")
|
| 394 |
|
| 395 |
+
# Generate single PII examples
|
| 396 |
single_count = int(needed * (1 - mix_ratio))
|
| 397 |
for _ in range(single_count):
|
| 398 |
tokens, labels = self.create_synthetic_example(pii_type)
|
| 399 |
synthetic_tokens.append(tokens)
|
| 400 |
synthetic_labels.append(labels)
|
| 401 |
|
| 402 |
+
# Generate mixed PII examples
|
| 403 |
mixed_count = int(needed * mix_ratio)
|
| 404 |
for _ in range(mixed_count):
|
| 405 |
+
# Make sure current PII type is included
|
| 406 |
other_types = [t for t in self.templates.keys() if t != pii_type]
|
| 407 |
selected_types = [pii_type] + random.sample(
|
| 408 |
other_types, min(1, len(other_types))
|
|
|
|
| 416 |
|
| 417 |
def _add_non_pii_examples(self, synthetic_tokens, synthetic_labels):
|
| 418 |
"""Add examples without PII (all 'O' labels) for balance."""
|
| 419 |
+
# Add 10% non-PII examples
|
| 420 |
num_non_pii = int(len(synthetic_tokens) * 0.1)
|
| 421 |
|
| 422 |
for _ in range(num_non_pii):
|
|
|
|
| 432 |
|
| 433 |
def _combine_and_shuffle(self, original_data, synthetic_tokens, synthetic_labels):
|
| 434 |
"""Combine original and synthetic data, then shuffle."""
|
| 435 |
+
# Merge all data
|
| 436 |
all_tokens = original_data['tokens'].tolist() + synthetic_tokens
|
| 437 |
all_labels = original_data['labels'].tolist() + synthetic_labels
|
| 438 |
|
| 439 |
+
# Create new dataframe
|
| 440 |
augmented_data = pd.DataFrame({
|
| 441 |
'tokens': all_tokens,
|
| 442 |
'labels': all_labels
|
| 443 |
})
|
| 444 |
|
| 445 |
+
# Shuffle the data
|
| 446 |
augmented_data = augmented_data.sample(frac=1, random_state=42).reset_index(drop=True)
|
| 447 |
|
| 448 |
print(f"\nTotal augmented samples: {len(augmented_data):,}")
|
|
|
|
| 450 |
return augmented_data
|
| 451 |
|
| 452 |
def calculate_class_weights(data, label_vocab):
|
| 453 |
+
"""Calculate class weights for balanced loss function."""
|
| 454 |
+
# Count occurrences of each label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
label_counts = Counter()
|
| 456 |
|
| 457 |
for labels in data['labels']:
|
|
|
|
| 459 |
label_id = label_vocab.word2idx.get(label.lower(), 0)
|
| 460 |
label_counts[label_id] += 1
|
| 461 |
|
| 462 |
+
# Calculate weights based on inverse frequency
|
| 463 |
total_samples = sum(label_counts.values())
|
| 464 |
num_classes = len(label_vocab)
|
| 465 |
|
|
|
|
| 467 |
|
| 468 |
for class_id, count in label_counts.items():
|
| 469 |
if count > 0:
|
| 470 |
+
# Inverse frequency weighting
|
| 471 |
weights[class_id] = total_samples / (num_classes * count)
|
| 472 |
|
| 473 |
+
# Normalize the weights
|
| 474 |
weights = weights / weights.sum() * num_classes
|
| 475 |
|
| 476 |
+
# Prevent extreme weights
|
| 477 |
weights = torch.clamp(weights, min=0.1, max=10.0)
|
| 478 |
|
| 479 |
+
# Don't weight padding tokens
|
| 480 |
weights[0] = 0.0
|
| 481 |
|
| 482 |
return weights
|
| 483 |
|
| 484 |
if __name__ == '__main__':
|
| 485 |
"""Example usage of the augmentation module."""
|
| 486 |
+
# Load original training data
|
| 487 |
print("Loading original training data...")
|
| 488 |
original_data = pd.read_json('train.json')
|
| 489 |
print(f"Original dataset size: {len(original_data):,}")
|
| 490 |
|
| 491 |
+
# Create augmenter instance
|
| 492 |
augmenter = PIIDataAugmenter(seed=42)
|
| 493 |
|
| 494 |
+
# Run augmentation
|
| 495 |
print("\n" + "="*60)
|
| 496 |
print("Starting data augmentation...")
|
| 497 |
print("="*60)
|
|
|
|
| 502 |
mix_ratio=0.3
|
| 503 |
)
|
| 504 |
|
| 505 |
+
# Save the augmented dataset
|
| 506 |
output_path = './train_augmented.json'
|
| 507 |
augmented_data.to_json(output_path, orient='records', lines=True)
|
| 508 |
print(f"\nSaved augmented data to {output_path}")
|
lstm.py
CHANGED
|
@@ -12,28 +12,28 @@ class LSTMCell(nn.Module):
|
|
| 12 |
self.input_size = input_size
|
| 13 |
self.hidden_size = hidden_size
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
# Input gate
|
| 17 |
self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 18 |
self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 19 |
self.b_i = nn.Parameter(torch.Tensor(hidden_size))
|
| 20 |
|
| 21 |
-
# Forget gate
|
| 22 |
self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 23 |
self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 24 |
self.b_f = nn.Parameter(torch.Tensor(hidden_size))
|
| 25 |
|
| 26 |
-
#
|
| 27 |
self.W_in = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 28 |
self.W_hn = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 29 |
self.b_n = nn.Parameter(torch.Tensor(hidden_size))
|
| 30 |
|
| 31 |
-
# Output gate
|
| 32 |
self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 33 |
self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 34 |
self.b_o = nn.Parameter(torch.Tensor(hidden_size))
|
| 35 |
|
| 36 |
-
# Initialize
|
| 37 |
for name, param in self.named_parameters():
|
| 38 |
if 'W_' in name:
|
| 39 |
nn.init.xavier_uniform_(param)
|
|
@@ -41,35 +41,26 @@ class LSTMCell(nn.Module):
|
|
| 41 |
nn.init.zeros_(param)
|
| 42 |
|
| 43 |
def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
input: input at current time step [batch_size, input_size]
|
| 48 |
-
states: tuple of (hidden_state, cell_state) from previous time step
|
| 49 |
-
both with shape [batch_size, hidden_size]
|
| 50 |
-
Returns:
|
| 51 |
-
new_hidden: updated hidden state [batch_size, hidden_size]
|
| 52 |
-
new_cell: updated cell state [batch_size, hidden_size]
|
| 53 |
-
"""
|
| 54 |
hidden, cell = states
|
| 55 |
|
| 56 |
-
#
|
| 57 |
-
# Forget gate: f_t = sigmoid(W_if @ x_t + W_hf @ h_{t-1} + b_f)
|
| 58 |
forget_gate = torch.sigmoid(torch.mm(input, self.W_if) + torch.mm(hidden, self.W_hf) + self.b_f)
|
| 59 |
|
| 60 |
-
#
|
| 61 |
input_gate = torch.sigmoid(torch.mm(input, self.W_ii) + torch.mm(hidden, self.W_hi) + self.b_i)
|
| 62 |
|
| 63 |
-
#
|
| 64 |
candidate = torch.tanh(torch.mm(input, self.W_in) + torch.mm(hidden, self.W_hn) + self.b_n)
|
| 65 |
|
| 66 |
-
#
|
| 67 |
output_gate = torch.sigmoid(torch.mm(input, self.W_io) + torch.mm(hidden, self.W_ho) + self.b_o)
|
| 68 |
|
| 69 |
-
# Update cell state
|
| 70 |
new_cell = forget_gate * cell + input_gate * candidate
|
| 71 |
|
| 72 |
-
#
|
| 73 |
new_hidden = output_gate * torch.tanh(new_cell)
|
| 74 |
|
| 75 |
return new_hidden, new_cell
|
|
@@ -87,27 +78,31 @@ class BidirectionalLSTM(nn.Module):
|
|
| 87 |
self.batch_first = batch_first
|
| 88 |
self.dropout = dropout if num_layers > 1 else 0.0
|
| 89 |
|
| 90 |
-
# Create forward and backward cells for each layer
|
| 91 |
self.forward_cells = nn.ModuleList()
|
| 92 |
self.backward_cells = nn.ModuleList()
|
| 93 |
self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None
|
| 94 |
|
| 95 |
for layer in range(num_layers):
|
| 96 |
-
#
|
| 97 |
layer_input_size = input_size if layer == 0 else hidden_size * 2
|
| 98 |
|
|
|
|
| 99 |
self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
|
| 100 |
self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))
|
| 101 |
|
|
|
|
| 102 |
if self.dropout > 0 and layer < num_layers - 1:
|
| 103 |
self.dropout_layers.append(nn.Dropout(dropout))
|
| 104 |
|
| 105 |
def forward(self, input, states=None, lengths=None):
|
| 106 |
-
#
|
| 107 |
is_packed = isinstance(input, PackedSequence)
|
| 108 |
if is_packed:
|
|
|
|
| 109 |
padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
|
| 110 |
outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
|
|
|
|
| 111 |
packed_out = pack_padded_sequence(
|
| 112 |
outputs, lengths,
|
| 113 |
batch_first=self.batch_first,
|
|
@@ -118,13 +113,15 @@ class BidirectionalLSTM(nn.Module):
|
|
| 118 |
return self._forward_unpacked(input, states, lengths)
|
| 119 |
|
| 120 |
def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
|
|
|
|
| 121 |
if not self.batch_first:
|
| 122 |
input = input.transpose(0, 1)
|
| 123 |
|
| 124 |
batch_size, seq_len, _ = input.size()
|
| 125 |
|
| 126 |
-
# Initialize states if not provided
|
| 127 |
if states is None:
|
|
|
|
| 128 |
h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
|
| 129 |
for _ in range(self.num_layers)]
|
| 130 |
c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
|
|
@@ -134,23 +131,24 @@ class BidirectionalLSTM(nn.Module):
|
|
| 134 |
c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
|
| 135 |
for _ in range(self.num_layers)]
|
| 136 |
else:
|
|
|
|
| 137 |
h0, c0 = states
|
| 138 |
-
# h0 and c0 are [num_layers * 2, batch_size, hidden_size]
|
| 139 |
h_t_forward = []
|
| 140 |
c_t_forward = []
|
| 141 |
h_t_backward = []
|
| 142 |
c_t_backward = []
|
| 143 |
|
|
|
|
| 144 |
for layer in range(self.num_layers):
|
| 145 |
h_t_forward.append(h0[layer * 2])
|
| 146 |
c_t_forward.append(c0[layer * 2])
|
| 147 |
h_t_backward.append(h0[layer * 2 + 1])
|
| 148 |
c_t_backward.append(c0[layer * 2 + 1])
|
| 149 |
|
| 150 |
-
# Process through
|
| 151 |
layer_input = input
|
| 152 |
for layer_idx in range(self.num_layers):
|
| 153 |
-
#
|
| 154 |
forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
|
| 155 |
for t in range(seq_len):
|
| 156 |
x = layer_input[:, t, :]
|
|
@@ -159,7 +157,7 @@ class BidirectionalLSTM(nn.Module):
|
|
| 159 |
c_t_forward[layer_idx] = c
|
| 160 |
forward_output[:, t, :] = h
|
| 161 |
|
| 162 |
-
#
|
| 163 |
backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
|
| 164 |
for t in reversed(range(seq_len)):
|
| 165 |
x = layer_input[:, t, :]
|
|
@@ -168,19 +166,20 @@ class BidirectionalLSTM(nn.Module):
|
|
| 168 |
c_t_backward[layer_idx] = c
|
| 169 |
backward_output[:, t, :] = h
|
| 170 |
|
| 171 |
-
# Concatenate forward and backward
|
| 172 |
layer_output = torch.cat([forward_output, backward_output], dim=2)
|
| 173 |
|
| 174 |
-
# Apply dropout between layers
|
| 175 |
if self.dropout > 0 and layer_idx < self.num_layers - 1:
|
| 176 |
layer_output = self.dropout_layers[layer_idx](layer_output)
|
| 177 |
|
|
|
|
| 178 |
layer_input = layer_output
|
| 179 |
|
| 180 |
-
# Final output
|
| 181 |
outputs = layer_output
|
| 182 |
|
| 183 |
-
# Stack hidden and cell states
|
| 184 |
h_n = []
|
| 185 |
c_n = []
|
| 186 |
for layer in range(self.num_layers):
|
|
@@ -189,6 +188,7 @@ class BidirectionalLSTM(nn.Module):
|
|
| 189 |
h_n = torch.stack(h_n, dim=0)
|
| 190 |
c_n = torch.stack(c_n, dim=0)
|
| 191 |
|
|
|
|
| 192 |
if not self.batch_first:
|
| 193 |
outputs = outputs.transpose(0, 1)
|
| 194 |
|
|
@@ -209,7 +209,7 @@ class LSTM(nn.Module):
|
|
| 209 |
self.hidden_size = hidden_size
|
| 210 |
self.num_layers = num_layers
|
| 211 |
|
| 212 |
-
# Embedding layer
|
| 213 |
self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
|
| 214 |
self.embed_dropout = nn.Dropout(dropout)
|
| 215 |
|
|
@@ -222,8 +222,8 @@ class LSTM(nn.Module):
|
|
| 222 |
dropout=dropout if num_layers > 1 else 0.0
|
| 223 |
)
|
| 224 |
|
| 225 |
-
# Output
|
| 226 |
-
lstm_output_size = hidden_size * 2 # bidirectional
|
| 227 |
self.fc = nn.Linear(lstm_output_size, num_classes)
|
| 228 |
self.output_dropout = nn.Dropout(dropout)
|
| 229 |
|
|
@@ -236,11 +236,11 @@ class LSTM(nn.Module):
|
|
| 236 |
Returns:
|
| 237 |
logits: class predictions [batch_size, seq_len, num_classes]
|
| 238 |
"""
|
| 239 |
-
#
|
| 240 |
-
embedded = self.embedding(input_ids)
|
| 241 |
embedded = self.embed_dropout(embedded)
|
| 242 |
|
| 243 |
-
# Pack if lengths provided
|
| 244 |
if lengths is not None:
|
| 245 |
packed_embedded = pack_padded_sequence(
|
| 246 |
embedded, lengths.cpu(),
|
|
@@ -248,40 +248,27 @@ class LSTM(nn.Module):
|
|
| 248 |
enforce_sorted=False
|
| 249 |
)
|
| 250 |
lstm_out, _ = self.lstm(packed_embedded)
|
|
|
|
| 251 |
lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
|
| 252 |
else:
|
|
|
|
| 253 |
lstm_out, _ = self.lstm(embedded)
|
| 254 |
|
| 255 |
-
# Apply dropout and
|
| 256 |
lstm_out = self.output_dropout(lstm_out)
|
| 257 |
-
logits = self.fc(lstm_out)
|
| 258 |
|
| 259 |
return logits
|
| 260 |
|
| 261 |
def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
|
| 262 |
num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
|
| 263 |
dropout: float = 0.1, max_len: int = 512):
|
| 264 |
-
"""
|
| 265 |
-
Create
|
| 266 |
-
Note: num_heads and d_ff are ignored (kept for compatibility with transformer interface)
|
| 267 |
-
|
| 268 |
-
Args:
|
| 269 |
-
vocab_size: size of vocabulary
|
| 270 |
-
num_classes: number of output classes (PII tags)
|
| 271 |
-
d_model: hidden dimension size
|
| 272 |
-
num_heads: ignored (for compatibility)
|
| 273 |
-
d_ff: ignored (for compatibility)
|
| 274 |
-
num_layers: number of LSTM layers
|
| 275 |
-
dropout: dropout rate
|
| 276 |
-
max_len: maximum sequence length
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
LSTM
|
| 280 |
-
"""
|
| 281 |
return LSTM(
|
| 282 |
vocab_size=vocab_size,
|
| 283 |
num_classes=num_classes,
|
| 284 |
-
embed_size=d_model // 2,
|
| 285 |
hidden_size=d_model,
|
| 286 |
num_layers=num_layers,
|
| 287 |
dropout=dropout,
|
|
|
|
| 12 |
self.input_size = input_size
|
| 13 |
self.hidden_size = hidden_size
|
| 14 |
|
| 15 |
+
# Weight matrices and biases for each gate
|
| 16 |
+
# Input gate parameters
|
| 17 |
self.W_ii = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 18 |
self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 19 |
self.b_i = nn.Parameter(torch.Tensor(hidden_size))
|
| 20 |
|
| 21 |
+
# Forget gate parameters
|
| 22 |
self.W_if = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 23 |
self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 24 |
self.b_f = nn.Parameter(torch.Tensor(hidden_size))
|
| 25 |
|
| 26 |
+
# Candidate values parameters
|
| 27 |
self.W_in = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 28 |
self.W_hn = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 29 |
self.b_n = nn.Parameter(torch.Tensor(hidden_size))
|
| 30 |
|
| 31 |
+
# Output gate parameters
|
| 32 |
self.W_io = nn.Parameter(torch.Tensor(input_size, hidden_size))
|
| 33 |
self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
|
| 34 |
self.b_o = nn.Parameter(torch.Tensor(hidden_size))
|
| 35 |
|
| 36 |
+
# Initialize weights using Xavier initialization
|
| 37 |
for name, param in self.named_parameters():
|
| 38 |
if 'W_' in name:
|
| 39 |
nn.init.xavier_uniform_(param)
|
|
|
|
| 41 |
nn.init.zeros_(param)
|
| 42 |
|
| 43 |
def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
| 44 |
+
"""Forward pass for one time step"""
|
| 45 |
+
# Unpack previous states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
hidden, cell = states
|
| 47 |
|
| 48 |
+
# Calculate forget gate - decides what to forget from previous cell state
|
|
|
|
| 49 |
forget_gate = torch.sigmoid(torch.mm(input, self.W_if) + torch.mm(hidden, self.W_hf) + self.b_f)
|
| 50 |
|
| 51 |
+
# Calculate input gate - decides what new information to store
|
| 52 |
input_gate = torch.sigmoid(torch.mm(input, self.W_ii) + torch.mm(hidden, self.W_hi) + self.b_i)
|
| 53 |
|
| 54 |
+
# Calculate candidate values - new information that could be added
|
| 55 |
candidate = torch.tanh(torch.mm(input, self.W_in) + torch.mm(hidden, self.W_hn) + self.b_n)
|
| 56 |
|
| 57 |
+
# Calculate output gate - decides what parts of cell state to output
|
| 58 |
output_gate = torch.sigmoid(torch.mm(input, self.W_io) + torch.mm(hidden, self.W_ho) + self.b_o)
|
| 59 |
|
| 60 |
+
# Update cell state by forgetting old info and adding new info
|
| 61 |
new_cell = forget_gate * cell + input_gate * candidate
|
| 62 |
|
| 63 |
+
# Generate new hidden state based on filtered cell state
|
| 64 |
new_hidden = output_gate * torch.tanh(new_cell)
|
| 65 |
|
| 66 |
return new_hidden, new_cell
|
|
|
|
| 78 |
self.batch_first = batch_first
|
| 79 |
self.dropout = dropout if num_layers > 1 else 0.0
|
| 80 |
|
| 81 |
+
# Create forward and backward LSTM cells for each layer
|
| 82 |
self.forward_cells = nn.ModuleList()
|
| 83 |
self.backward_cells = nn.ModuleList()
|
| 84 |
self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None
|
| 85 |
|
| 86 |
for layer in range(num_layers):
|
| 87 |
+
# First layer takes input_size, others take concatenated bidirectional output
|
| 88 |
layer_input_size = input_size if layer == 0 else hidden_size * 2
|
| 89 |
|
| 90 |
+
# Add forward and backward cells for this layer
|
| 91 |
self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
|
| 92 |
self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))
|
| 93 |
|
| 94 |
+
# Add dropout between layers (except after last layer)
|
| 95 |
if self.dropout > 0 and layer < num_layers - 1:
|
| 96 |
self.dropout_layers.append(nn.Dropout(dropout))
|
| 97 |
|
| 98 |
def forward(self, input, states=None, lengths=None):
|
| 99 |
+
# Check if input is packed sequence
|
| 100 |
is_packed = isinstance(input, PackedSequence)
|
| 101 |
if is_packed:
|
| 102 |
+
# Unpack for processing
|
| 103 |
padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
|
| 104 |
outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
|
| 105 |
+
# Pack output back
|
| 106 |
packed_out = pack_padded_sequence(
|
| 107 |
outputs, lengths,
|
| 108 |
batch_first=self.batch_first,
|
|
|
|
| 113 |
return self._forward_unpacked(input, states, lengths)
|
| 114 |
|
| 115 |
def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
|
| 116 |
+
# Convert to batch-first if needed
|
| 117 |
if not self.batch_first:
|
| 118 |
input = input.transpose(0, 1)
|
| 119 |
|
| 120 |
batch_size, seq_len, _ = input.size()
|
| 121 |
|
| 122 |
+
# Initialize hidden and cell states if not provided
|
| 123 |
if states is None:
|
| 124 |
+
# Create zero states for each layer and direction
|
| 125 |
h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
|
| 126 |
for _ in range(self.num_layers)]
|
| 127 |
c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
|
|
|
|
| 131 |
c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
|
| 132 |
for _ in range(self.num_layers)]
|
| 133 |
else:
|
| 134 |
+
# Unpack provided states
|
| 135 |
h0, c0 = states
|
|
|
|
| 136 |
h_t_forward = []
|
| 137 |
c_t_forward = []
|
| 138 |
h_t_backward = []
|
| 139 |
c_t_backward = []
|
| 140 |
|
| 141 |
+
# Separate forward and backward states for each layer
|
| 142 |
for layer in range(self.num_layers):
|
| 143 |
h_t_forward.append(h0[layer * 2])
|
| 144 |
c_t_forward.append(c0[layer * 2])
|
| 145 |
h_t_backward.append(h0[layer * 2 + 1])
|
| 146 |
c_t_backward.append(c0[layer * 2 + 1])
|
| 147 |
|
| 148 |
+
# Process through each layer
|
| 149 |
layer_input = input
|
| 150 |
for layer_idx in range(self.num_layers):
|
| 151 |
+
# Process forward direction
|
| 152 |
forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
|
| 153 |
for t in range(seq_len):
|
| 154 |
x = layer_input[:, t, :]
|
|
|
|
| 157 |
c_t_forward[layer_idx] = c
|
| 158 |
forward_output[:, t, :] = h
|
| 159 |
|
| 160 |
+
# Process backward direction
|
| 161 |
backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
|
| 162 |
for t in reversed(range(seq_len)):
|
| 163 |
x = layer_input[:, t, :]
|
|
|
|
| 166 |
c_t_backward[layer_idx] = c
|
| 167 |
backward_output[:, t, :] = h
|
| 168 |
|
| 169 |
+
# Concatenate forward and backward outputs
|
| 170 |
layer_output = torch.cat([forward_output, backward_output], dim=2)
|
| 171 |
|
| 172 |
+
# Apply dropout between layers
|
| 173 |
if self.dropout > 0 and layer_idx < self.num_layers - 1:
|
| 174 |
layer_output = self.dropout_layers[layer_idx](layer_output)
|
| 175 |
|
| 176 |
+
# Use this layer's output as next layer's input
|
| 177 |
layer_input = layer_output
|
| 178 |
|
| 179 |
+
# Final output is the last layer's output
|
| 180 |
outputs = layer_output
|
| 181 |
|
| 182 |
+
# Stack final hidden and cell states for all layers
|
| 183 |
h_n = []
|
| 184 |
c_n = []
|
| 185 |
for layer in range(self.num_layers):
|
|
|
|
| 188 |
h_n = torch.stack(h_n, dim=0)
|
| 189 |
c_n = torch.stack(c_n, dim=0)
|
| 190 |
|
| 191 |
+
# Convert back if not batch-first
|
| 192 |
if not self.batch_first:
|
| 193 |
outputs = outputs.transpose(0, 1)
|
| 194 |
|
|
|
|
| 209 |
self.hidden_size = hidden_size
|
| 210 |
self.num_layers = num_layers
|
| 211 |
|
| 212 |
+
# Embedding layer to convert tokens to vectors
|
| 213 |
self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
|
| 214 |
self.embed_dropout = nn.Dropout(dropout)
|
| 215 |
|
|
|
|
| 222 |
dropout=dropout if num_layers > 1 else 0.0
|
| 223 |
)
|
| 224 |
|
| 225 |
+
# Output layer to predict PII labels
|
| 226 |
+
lstm_output_size = hidden_size * 2 # doubled for bidirectional
|
| 227 |
self.fc = nn.Linear(lstm_output_size, num_classes)
|
| 228 |
self.output_dropout = nn.Dropout(dropout)
|
| 229 |
|
|
|
|
| 236 |
Returns:
|
| 237 |
logits: class predictions [batch_size, seq_len, num_classes]
|
| 238 |
"""
|
| 239 |
+
# Convert token ids to embeddings
|
| 240 |
+
embedded = self.embedding(input_ids)
|
| 241 |
embedded = self.embed_dropout(embedded)
|
| 242 |
|
| 243 |
+
# Pack sequences for efficient processing if lengths provided
|
| 244 |
if lengths is not None:
|
| 245 |
packed_embedded = pack_padded_sequence(
|
| 246 |
embedded, lengths.cpu(),
|
|
|
|
| 248 |
enforce_sorted=False
|
| 249 |
)
|
| 250 |
lstm_out, _ = self.lstm(packed_embedded)
|
| 251 |
+
# Unpack the output
|
| 252 |
lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
|
| 253 |
else:
|
| 254 |
+
# Process without packing
|
| 255 |
lstm_out, _ = self.lstm(embedded)
|
| 256 |
|
| 257 |
+
# Apply dropout and get final predictions
|
| 258 |
lstm_out = self.output_dropout(lstm_out)
|
| 259 |
+
logits = self.fc(lstm_out)
|
| 260 |
|
| 261 |
return logits
|
| 262 |
|
| 263 |
def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
|
| 264 |
num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
|
| 265 |
dropout: float = 0.1, max_len: int = 512):
|
| 266 |
+
"""Create Bidirectional LSTM model for PII detection"""
|
| 267 |
+
# Create LSTM with appropriate dimensions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
return LSTM(
|
| 269 |
vocab_size=vocab_size,
|
| 270 |
num_classes=num_classes,
|
| 271 |
+
embed_size=d_model // 2,
|
| 272 |
hidden_size=d_model,
|
| 273 |
num_layers=num_layers,
|
| 274 |
dropout=dropout,
|
lstm_training.ipynb
CHANGED
|
@@ -42,7 +42,7 @@
|
|
| 42 |
},
|
| 43 |
{
|
| 44 |
"cell_type": "code",
|
| 45 |
-
"execution_count":
|
| 46 |
"id": "1207cd93",
|
| 47 |
"metadata": {
|
| 48 |
"execution": {
|
|
@@ -62,19 +62,23 @@
|
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
|
|
|
| 65 |
"class Vocabulary:\n",
|
| 66 |
" \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
|
| 67 |
" def __init__(self, max_size=100000):\n",
|
|
|
|
| 68 |
" self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
|
| 69 |
" self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
|
| 70 |
" self.word_count = Counter()\n",
|
| 71 |
" self.max_size = max_size\n",
|
| 72 |
" \n",
|
| 73 |
" def add_sentence(self, sentence):\n",
|
|
|
|
| 74 |
" for word in sentence:\n",
|
| 75 |
" self.word_count[word.lower()] += 1\n",
|
| 76 |
" \n",
|
| 77 |
" def build(self):\n",
|
|
|
|
| 78 |
" most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
|
| 79 |
" for word, _ in most_common:\n",
|
| 80 |
" if word not in self.word2idx:\n",
|
|
@@ -86,15 +90,17 @@
|
|
| 86 |
" return len(self.word2idx)\n",
|
| 87 |
" \n",
|
| 88 |
" def encode(self, sentence):\n",
|
|
|
|
| 89 |
" return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
|
| 90 |
" \n",
|
| 91 |
" def decode(self, indices):\n",
|
|
|
|
| 92 |
" return [self.idx2word.get(idx, '<unk>') for idx in indices]"
|
| 93 |
]
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
-
"execution_count":
|
| 98 |
"id": "f4056292",
|
| 99 |
"metadata": {
|
| 100 |
"execution": {
|
|
@@ -114,6 +120,7 @@
|
|
| 114 |
},
|
| 115 |
"outputs": [],
|
| 116 |
"source": [
|
|
|
|
| 117 |
"class PIIDataset(Dataset):\n",
|
| 118 |
" \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
|
| 119 |
" def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
|
|
@@ -127,16 +134,16 @@
|
|
| 127 |
" return len(self.tokens)\n",
|
| 128 |
" \n",
|
| 129 |
" def __getitem__(self, idx):\n",
|
| 130 |
-
" # Add
|
| 131 |
" tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
|
| 132 |
" labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
|
| 133 |
" \n",
|
| 134 |
-
" # Truncate if too long\n",
|
| 135 |
" if len(tokens) > self.max_len:\n",
|
| 136 |
" tokens = tokens[:self.max_len-1] + ['<end>']\n",
|
| 137 |
" labels = labels[:self.max_len-1] + ['<end>']\n",
|
| 138 |
" \n",
|
| 139 |
-
" # Encode\n",
|
| 140 |
" token_ids = self.text_vocab.encode(tokens)\n",
|
| 141 |
" label_ids = self.label_vocab.encode(labels)\n",
|
| 142 |
" \n",
|
|
@@ -145,7 +152,7 @@
|
|
| 145 |
},
|
| 146 |
{
|
| 147 |
"cell_type": "code",
|
| 148 |
-
"execution_count":
|
| 149 |
"id": "499deba2",
|
| 150 |
"metadata": {
|
| 151 |
"execution": {
|
|
@@ -167,7 +174,9 @@
|
|
| 167 |
"source": [
|
| 168 |
"def collate_fn(batch):\n",
|
| 169 |
" \"\"\"Custom collate function for padding sequences\"\"\"\n",
|
|
|
|
| 170 |
" tokens, labels = zip(*batch)\n",
|
|
|
|
| 171 |
" tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
|
| 172 |
" labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
|
| 173 |
" return tokens_padded, labels_padded"
|
|
@@ -175,7 +184,7 @@
|
|
| 175 |
},
|
| 176 |
{
|
| 177 |
"cell_type": "code",
|
| 178 |
-
"execution_count":
|
| 179 |
"id": "7ade0505",
|
| 180 |
"metadata": {
|
| 181 |
"execution": {
|
|
@@ -195,6 +204,7 @@
|
|
| 195 |
},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
|
|
|
| 198 |
"class F1ScoreMetric:\n",
|
| 199 |
" \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
|
| 200 |
" def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
|
|
@@ -205,26 +215,32 @@
|
|
| 205 |
" self.reset()\n",
|
| 206 |
" \n",
|
| 207 |
" def reset(self):\n",
|
|
|
|
| 208 |
" self.true_positives = 0\n",
|
| 209 |
" self.false_positives = 0\n",
|
| 210 |
" self.false_negatives = 0\n",
|
| 211 |
" self.class_metrics = {}\n",
|
| 212 |
" \n",
|
| 213 |
" def update(self, predictions, targets):\n",
|
|
|
|
| 214 |
" mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
|
| 215 |
" o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
|
| 216 |
" \n",
|
|
|
|
| 217 |
" for class_id in range(1, self.num_classes):\n",
|
| 218 |
" if class_id == o_idx:\n",
|
| 219 |
" continue\n",
|
| 220 |
-
"
|
|
|
|
| 221 |
" pred_mask = (predictions == class_id) & mask\n",
|
| 222 |
" true_mask = (targets == class_id) & mask\n",
|
| 223 |
" \n",
|
|
|
|
| 224 |
" tp = ((pred_mask) & (true_mask)).sum().item()\n",
|
| 225 |
" fp = ((pred_mask) & (~true_mask)).sum().item()\n",
|
| 226 |
" fn = ((~pred_mask) & (true_mask)).sum().item()\n",
|
| 227 |
" \n",
|
|
|
|
| 228 |
" self.true_positives += tp\n",
|
| 229 |
" self.false_positives += fp\n",
|
| 230 |
" self.false_negatives += fn\n",
|
|
@@ -236,6 +252,7 @@
|
|
| 236 |
" self.class_metrics[class_id]['fn'] += fn\n",
|
| 237 |
" \n",
|
| 238 |
" def compute(self):\n",
|
|
|
|
| 239 |
" beta_squared = self.beta ** 2\n",
|
| 240 |
" precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
|
| 241 |
" recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
|
|
@@ -243,6 +260,7 @@
|
|
| 243 |
" return f1\n",
|
| 244 |
" \n",
|
| 245 |
" def get_class_metrics(self):\n",
|
|
|
|
| 246 |
" results = {}\n",
|
| 247 |
" for class_id, metrics in self.class_metrics.items():\n",
|
| 248 |
" if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
|
|
@@ -261,7 +279,7 @@
|
|
| 261 |
},
|
| 262 |
{
|
| 263 |
"cell_type": "code",
|
| 264 |
-
"execution_count":
|
| 265 |
"id": "361b5505",
|
| 266 |
"metadata": {
|
| 267 |
"execution": {
|
|
@@ -281,6 +299,7 @@
|
|
| 281 |
},
|
| 282 |
"outputs": [],
|
| 283 |
"source": [
|
|
|
|
| 284 |
"class FocalLoss(nn.Module):\n",
|
| 285 |
" \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
|
| 286 |
" def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
|
|
@@ -291,6 +310,7 @@
|
|
| 291 |
" self.ignore_index = ignore_index\n",
|
| 292 |
" \n",
|
| 293 |
" def forward(self, inputs, targets):\n",
|
|
|
|
| 294 |
" ce_loss = nn.functional.cross_entropy(\n",
|
| 295 |
" inputs, targets, \n",
|
| 296 |
" weight=self.alpha, \n",
|
|
@@ -298,9 +318,11 @@
|
|
| 298 |
" ignore_index=self.ignore_index\n",
|
| 299 |
" )\n",
|
| 300 |
" \n",
|
|
|
|
| 301 |
" pt = torch.exp(-ce_loss)\n",
|
| 302 |
" focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
|
| 303 |
" \n",
|
|
|
|
| 304 |
" if self.reduction == 'mean':\n",
|
| 305 |
" return focal_loss.mean()\n",
|
| 306 |
" elif self.reduction == 'sum':\n",
|
|
@@ -311,7 +333,7 @@
|
|
| 311 |
},
|
| 312 |
{
|
| 313 |
"cell_type": "code",
|
| 314 |
-
"execution_count":
|
| 315 |
"id": "1de646e9",
|
| 316 |
"metadata": {
|
| 317 |
"execution": {
|
|
@@ -337,8 +359,10 @@
|
|
| 337 |
" total_loss = 0\n",
|
| 338 |
" f1_metric.reset()\n",
|
| 339 |
" \n",
|
|
|
|
| 340 |
" progress_bar = tqdm(dataloader, desc='Training')\n",
|
| 341 |
" for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
|
|
|
|
| 342 |
" tokens = tokens.to(device)\n",
|
| 343 |
" labels = labels.to(device)\n",
|
| 344 |
" \n",
|
|
@@ -353,6 +377,7 @@
|
|
| 353 |
" # Calculate loss and backward pass\n",
|
| 354 |
" loss = criterion(outputs_flat, labels_flat)\n",
|
| 355 |
" loss.backward()\n",
|
|
|
|
| 356 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
|
| 357 |
" optimizer.step()\n",
|
| 358 |
" \n",
|
|
@@ -372,7 +397,7 @@
|
|
| 372 |
},
|
| 373 |
{
|
| 374 |
"cell_type": "code",
|
| 375 |
-
"execution_count":
|
| 376 |
"id": "d1ce3b0f",
|
| 377 |
"metadata": {
|
| 378 |
"execution": {
|
|
@@ -398,8 +423,10 @@
|
|
| 398 |
" total_loss = 0\n",
|
| 399 |
" f1_metric.reset()\n",
|
| 400 |
" \n",
|
|
|
|
| 401 |
" with torch.no_grad():\n",
|
| 402 |
" for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
|
|
|
|
| 403 |
" tokens = tokens.to(device)\n",
|
| 404 |
" labels = labels.to(device)\n",
|
| 405 |
" \n",
|
|
@@ -421,7 +448,7 @@
|
|
| 421 |
},
|
| 422 |
{
|
| 423 |
"cell_type": "code",
|
| 424 |
-
"execution_count":
|
| 425 |
"id": "da3ff80c",
|
| 426 |
"metadata": {
|
| 427 |
"execution": {
|
|
@@ -445,6 +472,7 @@
|
|
| 445 |
" \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
|
| 446 |
" sample_weights = []\n",
|
| 447 |
" \n",
|
|
|
|
| 448 |
" for idx in range(len(dataset)):\n",
|
| 449 |
" _, labels = dataset[idx]\n",
|
| 450 |
" \n",
|
|
@@ -453,12 +481,14 @@
|
|
| 453 |
" for label_id in labels:\n",
|
| 454 |
" if label_id > 3: # Skip special tokens\n",
|
| 455 |
" label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
|
|
|
|
| 456 |
" if label_name != 'o' and 'B-' in label_name:\n",
|
| 457 |
" min_weight = 10.0\n",
|
| 458 |
" break\n",
|
| 459 |
" \n",
|
| 460 |
" sample_weights.append(min_weight)\n",
|
| 461 |
" \n",
|
|
|
|
| 462 |
" sampler = WeightedRandomSampler(\n",
|
| 463 |
" weights=sample_weights,\n",
|
| 464 |
" num_samples=len(sample_weights),\n",
|
|
@@ -470,7 +500,7 @@
|
|
| 470 |
},
|
| 471 |
{
|
| 472 |
"cell_type": "code",
|
| 473 |
-
"execution_count":
|
| 474 |
"id": "69b37e68",
|
| 475 |
"metadata": {
|
| 476 |
"execution": {
|
|
@@ -493,11 +523,14 @@
|
|
| 493 |
"def print_label_distribution(data, title=\"Label Distribution\"):\n",
|
| 494 |
" \"\"\"Print label distribution statistics\"\"\"\n",
|
| 495 |
" label_counts = Counter()\n",
|
|
|
|
|
|
|
| 496 |
" for label_seq in data.labels:\n",
|
| 497 |
" for label in label_seq:\n",
|
| 498 |
" if label not in ['<pad>', '<start>', '<end>']:\n",
|
| 499 |
" label_counts[label] += 1\n",
|
| 500 |
" \n",
|
|
|
|
| 501 |
" print(f\"\\n{title}:\")\n",
|
| 502 |
" print(\"-\" * 50)\n",
|
| 503 |
" total = sum(label_counts.values())\n",
|
|
@@ -510,7 +543,7 @@
|
|
| 510 |
},
|
| 511 |
{
|
| 512 |
"cell_type": "code",
|
| 513 |
-
"execution_count":
|
| 514 |
"id": "4b1b4f86",
|
| 515 |
"metadata": {
|
| 516 |
"execution": {
|
|
@@ -534,7 +567,7 @@
|
|
| 534 |
" \"\"\"Save model and all necessary components for deployment\"\"\"\n",
|
| 535 |
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 536 |
" \n",
|
| 537 |
-
" # Save model
|
| 538 |
" model_path = os.path.join(save_dir, 'pii_lstm_model.pt')\n",
|
| 539 |
" torch.save(model.state_dict(), model_path)\n",
|
| 540 |
" \n",
|
|
@@ -560,7 +593,7 @@
|
|
| 560 |
},
|
| 561 |
{
|
| 562 |
"cell_type": "code",
|
| 563 |
-
"execution_count":
|
| 564 |
"id": "31d2f1b1",
|
| 565 |
"metadata": {
|
| 566 |
"execution": {
|
|
@@ -596,7 +629,7 @@
|
|
| 596 |
" data = pd.read_json(data_path, lines=True)\n",
|
| 597 |
" print(f\"Total samples: {len(data)}\")\n",
|
| 598 |
" \n",
|
| 599 |
-
" #
|
| 600 |
" print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
|
| 601 |
" \n",
|
| 602 |
" # Build vocabularies\n",
|
|
@@ -609,6 +642,7 @@
|
|
| 609 |
" for labels in data.labels:\n",
|
| 610 |
" label_vocab.add_sentence(labels)\n",
|
| 611 |
" \n",
|
|
|
|
| 612 |
" text_vocab.build()\n",
|
| 613 |
" label_vocab.build()\n",
|
| 614 |
" \n",
|
|
@@ -616,11 +650,11 @@
|
|
| 616 |
" print(f\" - Text vocabulary: {len(text_vocab):,}\")\n",
|
| 617 |
" print(f\" - Label vocabulary: {len(label_vocab)}\")\n",
|
| 618 |
" \n",
|
| 619 |
-
" # Calculate class weights\n",
|
| 620 |
" class_weights = calculate_class_weights(data, label_vocab)\n",
|
| 621 |
" class_weights = class_weights.to(device)\n",
|
| 622 |
" \n",
|
| 623 |
-
" # Split data\n",
|
| 624 |
" X_train, X_val, y_train, y_val = train_test_split(\n",
|
| 625 |
" data.tokens.tolist(),\n",
|
| 626 |
" data.labels.tolist(),\n",
|
|
@@ -632,14 +666,15 @@
|
|
| 632 |
" print(f\" - Train samples: {len(X_train):,}\")\n",
|
| 633 |
" print(f\" - Validation samples: {len(X_val):,}\")\n",
|
| 634 |
" \n",
|
| 635 |
-
" # Create datasets
|
| 636 |
" max_seq_len = 512\n",
|
| 637 |
" train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 638 |
" val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 639 |
" \n",
|
| 640 |
-
" #
|
| 641 |
" train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
|
| 642 |
" \n",
|
|
|
|
| 643 |
" train_loader = DataLoader(\n",
|
| 644 |
" train_dataset, \n",
|
| 645 |
" batch_size=batch_size,\n",
|
|
@@ -668,7 +703,7 @@
|
|
| 668 |
" 'max_len': max_seq_len\n",
|
| 669 |
" }\n",
|
| 670 |
" \n",
|
| 671 |
-
" # Create model\n",
|
| 672 |
" print(\"\\nCreating LSTM model...\")\n",
|
| 673 |
" model = create_lstm_pii_model(**model_config).to(device)\n",
|
| 674 |
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
|
@@ -701,11 +736,11 @@
|
|
| 701 |
" min_lr=1e-6\n",
|
| 702 |
" )\n",
|
| 703 |
" \n",
|
| 704 |
-
" #
|
| 705 |
" f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 706 |
" f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 707 |
" \n",
|
| 708 |
-
" # Training
|
| 709 |
" train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
|
| 710 |
" best_val_f1 = 0\n",
|
| 711 |
" patience = 7\n",
|
|
@@ -714,18 +749,20 @@
|
|
| 714 |
" print(\"\\nStarting training...\")\n",
|
| 715 |
" print(\"=\" * 60)\n",
|
| 716 |
" \n",
|
|
|
|
| 717 |
" for epoch in range(num_epochs):\n",
|
| 718 |
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
|
| 719 |
" \n",
|
| 720 |
-
" # Train
|
| 721 |
" train_loss, train_f1 = train_epoch(\n",
|
| 722 |
" model, train_loader, optimizer, criterion, device, f1_metric_train\n",
|
| 723 |
" )\n",
|
|
|
|
| 724 |
" val_loss, val_f1 = evaluate(\n",
|
| 725 |
" model, val_loader, criterion, device, f1_metric_val\n",
|
| 726 |
" )\n",
|
| 727 |
" \n",
|
| 728 |
-
" #
|
| 729 |
" scheduler.step(val_loss)\n",
|
| 730 |
" \n",
|
| 731 |
" # Store metrics\n",
|
|
@@ -1282,9 +1319,11 @@
|
|
| 1282 |
}
|
| 1283 |
],
|
| 1284 |
"source": [
|
|
|
|
| 1285 |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 1286 |
"print(f\"Using device: {device}\")\n",
|
| 1287 |
"\n",
|
|
|
|
| 1288 |
"model, text_vocab, label_vocab = train_lstm_pii_model(\n",
|
| 1289 |
" data_path='train_augmented.json',\n",
|
| 1290 |
" num_epochs=20,\n",
|
|
|
|
| 42 |
},
|
| 43 |
{
|
| 44 |
"cell_type": "code",
|
| 45 |
+
"execution_count": null,
|
| 46 |
"id": "1207cd93",
|
| 47 |
"metadata": {
|
| 48 |
"execution": {
|
|
|
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
| 65 |
+
"# Define vocabulary class for text and label encoding\n",
|
| 66 |
"class Vocabulary:\n",
|
| 67 |
" \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
|
| 68 |
" def __init__(self, max_size=100000):\n",
|
| 69 |
+
" # Initialize special tokens\n",
|
| 70 |
" self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
|
| 71 |
" self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
|
| 72 |
" self.word_count = Counter()\n",
|
| 73 |
" self.max_size = max_size\n",
|
| 74 |
" \n",
|
| 75 |
" def add_sentence(self, sentence):\n",
|
| 76 |
+
" # Count word frequencies\n",
|
| 77 |
" for word in sentence:\n",
|
| 78 |
" self.word_count[word.lower()] += 1\n",
|
| 79 |
" \n",
|
| 80 |
" def build(self):\n",
|
| 81 |
+
" # Build vocabulary from most common words\n",
|
| 82 |
" most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
|
| 83 |
" for word, _ in most_common:\n",
|
| 84 |
" if word not in self.word2idx:\n",
|
|
|
|
| 90 |
" return len(self.word2idx)\n",
|
| 91 |
" \n",
|
| 92 |
" def encode(self, sentence):\n",
|
| 93 |
+
" # Convert words to indices\n",
|
| 94 |
" return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
|
| 95 |
" \n",
|
| 96 |
" def decode(self, indices):\n",
|
| 97 |
+
" # Convert indices back to words\n",
|
| 98 |
" return [self.idx2word.get(idx, '<unk>') for idx in indices]"
|
| 99 |
]
|
| 100 |
},
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
| 103 |
+
"execution_count": null,
|
| 104 |
"id": "f4056292",
|
| 105 |
"metadata": {
|
| 106 |
"execution": {
|
|
|
|
| 120 |
},
|
| 121 |
"outputs": [],
|
| 122 |
"source": [
|
| 123 |
+
"# Dataset class for PII detection\n",
|
| 124 |
"class PIIDataset(Dataset):\n",
|
| 125 |
" \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
|
| 126 |
" def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
|
|
|
|
| 134 |
" return len(self.tokens)\n",
|
| 135 |
" \n",
|
| 136 |
" def __getitem__(self, idx):\n",
|
| 137 |
+
" # Add special tokens to beginning and end\n",
|
| 138 |
" tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
|
| 139 |
" labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
|
| 140 |
" \n",
|
| 141 |
+
" # Truncate if sequence is too long\n",
|
| 142 |
" if len(tokens) > self.max_len:\n",
|
| 143 |
" tokens = tokens[:self.max_len-1] + ['<end>']\n",
|
| 144 |
" labels = labels[:self.max_len-1] + ['<end>']\n",
|
| 145 |
" \n",
|
| 146 |
+
" # Encode tokens and labels to indices\n",
|
| 147 |
" token_ids = self.text_vocab.encode(tokens)\n",
|
| 148 |
" label_ids = self.label_vocab.encode(labels)\n",
|
| 149 |
" \n",
|
|
|
|
| 152 |
},
|
| 153 |
{
|
| 154 |
"cell_type": "code",
|
| 155 |
+
"execution_count": null,
|
| 156 |
"id": "499deba2",
|
| 157 |
"metadata": {
|
| 158 |
"execution": {
|
|
|
|
| 174 |
"source": [
|
| 175 |
"def collate_fn(batch):\n",
|
| 176 |
" \"\"\"Custom collate function for padding sequences\"\"\"\n",
|
| 177 |
+
" # Separate tokens and labels\n",
|
| 178 |
" tokens, labels = zip(*batch)\n",
|
| 179 |
+
" # Pad sequences to same length\n",
|
| 180 |
" tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
|
| 181 |
" labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
|
| 182 |
" return tokens_padded, labels_padded"
|
|
|
|
| 184 |
},
|
| 185 |
{
|
| 186 |
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
"id": "7ade0505",
|
| 189 |
"metadata": {
|
| 190 |
"execution": {
|
|
|
|
| 204 |
},
|
| 205 |
"outputs": [],
|
| 206 |
"source": [
|
| 207 |
+
"# F1 score metric for evaluation\n",
|
| 208 |
"class F1ScoreMetric:\n",
|
| 209 |
" \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
|
| 210 |
" def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
|
|
|
|
| 215 |
" self.reset()\n",
|
| 216 |
" \n",
|
| 217 |
" def reset(self):\n",
|
| 218 |
+
" # Reset counters\n",
|
| 219 |
" self.true_positives = 0\n",
|
| 220 |
" self.false_positives = 0\n",
|
| 221 |
" self.false_negatives = 0\n",
|
| 222 |
" self.class_metrics = {}\n",
|
| 223 |
" \n",
|
| 224 |
" def update(self, predictions, targets):\n",
|
| 225 |
+
" # Create mask to ignore padding and special tokens\n",
|
| 226 |
" mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
|
| 227 |
" o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
|
| 228 |
" \n",
|
| 229 |
+
" # Calculate metrics for each class\n",
|
| 230 |
" for class_id in range(1, self.num_classes):\n",
|
| 231 |
" if class_id == o_idx:\n",
|
| 232 |
" continue\n",
|
| 233 |
+
" \n",
|
| 234 |
+
" # Find where predictions and targets match this class\n",
|
| 235 |
" pred_mask = (predictions == class_id) & mask\n",
|
| 236 |
" true_mask = (targets == class_id) & mask\n",
|
| 237 |
" \n",
|
| 238 |
+
" # Count true positives, false positives, false negatives\n",
|
| 239 |
" tp = ((pred_mask) & (true_mask)).sum().item()\n",
|
| 240 |
" fp = ((pred_mask) & (~true_mask)).sum().item()\n",
|
| 241 |
" fn = ((~pred_mask) & (true_mask)).sum().item()\n",
|
| 242 |
" \n",
|
| 243 |
+
" # Update total counts\n",
|
| 244 |
" self.true_positives += tp\n",
|
| 245 |
" self.false_positives += fp\n",
|
| 246 |
" self.false_negatives += fn\n",
|
|
|
|
| 252 |
" self.class_metrics[class_id]['fn'] += fn\n",
|
| 253 |
" \n",
|
| 254 |
" def compute(self):\n",
|
| 255 |
+
" # Calculate F-beta score\n",
|
| 256 |
" beta_squared = self.beta ** 2\n",
|
| 257 |
" precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
|
| 258 |
" recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
|
|
|
|
| 260 |
" return f1\n",
|
| 261 |
" \n",
|
| 262 |
" def get_class_metrics(self):\n",
|
| 263 |
+
" # Get metrics for each class\n",
|
| 264 |
" results = {}\n",
|
| 265 |
" for class_id, metrics in self.class_metrics.items():\n",
|
| 266 |
" if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
|
|
|
|
| 279 |
},
|
| 280 |
{
|
| 281 |
"cell_type": "code",
|
| 282 |
+
"execution_count": null,
|
| 283 |
"id": "361b5505",
|
| 284 |
"metadata": {
|
| 285 |
"execution": {
|
|
|
|
| 299 |
},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
+
"# Focal loss for handling class imbalance\n",
|
| 303 |
"class FocalLoss(nn.Module):\n",
|
| 304 |
" \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
|
| 305 |
" def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
|
|
|
|
| 310 |
" self.ignore_index = ignore_index\n",
|
| 311 |
" \n",
|
| 312 |
" def forward(self, inputs, targets):\n",
|
| 313 |
+
" # Calculate cross entropy loss\n",
|
| 314 |
" ce_loss = nn.functional.cross_entropy(\n",
|
| 315 |
" inputs, targets, \n",
|
| 316 |
" weight=self.alpha, \n",
|
|
|
|
| 318 |
" ignore_index=self.ignore_index\n",
|
| 319 |
" )\n",
|
| 320 |
" \n",
|
| 321 |
+
" # Apply focal term to focus on hard examples\n",
|
| 322 |
" pt = torch.exp(-ce_loss)\n",
|
| 323 |
" focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
|
| 324 |
" \n",
|
| 325 |
+
" # Reduce loss based on specified method\n",
|
| 326 |
" if self.reduction == 'mean':\n",
|
| 327 |
" return focal_loss.mean()\n",
|
| 328 |
" elif self.reduction == 'sum':\n",
|
|
|
|
| 333 |
},
|
| 334 |
{
|
| 335 |
"cell_type": "code",
|
| 336 |
+
"execution_count": null,
|
| 337 |
"id": "1de646e9",
|
| 338 |
"metadata": {
|
| 339 |
"execution": {
|
|
|
|
| 359 |
" total_loss = 0\n",
|
| 360 |
" f1_metric.reset()\n",
|
| 361 |
" \n",
|
| 362 |
+
" # Progress bar for training\n",
|
| 363 |
" progress_bar = tqdm(dataloader, desc='Training')\n",
|
| 364 |
" for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
|
| 365 |
+
" # Move data to device\n",
|
| 366 |
" tokens = tokens.to(device)\n",
|
| 367 |
" labels = labels.to(device)\n",
|
| 368 |
" \n",
|
|
|
|
| 377 |
" # Calculate loss and backward pass\n",
|
| 378 |
" loss = criterion(outputs_flat, labels_flat)\n",
|
| 379 |
" loss.backward()\n",
|
| 380 |
+
" # Clip gradients to prevent exploding gradients\n",
|
| 381 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
|
| 382 |
" optimizer.step()\n",
|
| 383 |
" \n",
|
|
|
|
| 397 |
},
|
| 398 |
{
|
| 399 |
"cell_type": "code",
|
| 400 |
+
"execution_count": null,
|
| 401 |
"id": "d1ce3b0f",
|
| 402 |
"metadata": {
|
| 403 |
"execution": {
|
|
|
|
| 423 |
" total_loss = 0\n",
|
| 424 |
" f1_metric.reset()\n",
|
| 425 |
" \n",
|
| 426 |
+
" # No gradient computation during evaluation\n",
|
| 427 |
" with torch.no_grad():\n",
|
| 428 |
" for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
|
| 429 |
+
" # Move data to device\n",
|
| 430 |
" tokens = tokens.to(device)\n",
|
| 431 |
" labels = labels.to(device)\n",
|
| 432 |
" \n",
|
|
|
|
| 448 |
},
|
| 449 |
{
|
| 450 |
"cell_type": "code",
|
| 451 |
+
"execution_count": null,
|
| 452 |
"id": "da3ff80c",
|
| 453 |
"metadata": {
|
| 454 |
"execution": {
|
|
|
|
| 472 |
" \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
|
| 473 |
" sample_weights = []\n",
|
| 474 |
" \n",
|
| 475 |
+
" # Calculate weight for each sample\n",
|
| 476 |
" for idx in range(len(dataset)):\n",
|
| 477 |
" _, labels = dataset[idx]\n",
|
| 478 |
" \n",
|
|
|
|
| 481 |
" for label_id in labels:\n",
|
| 482 |
" if label_id > 3: # Skip special tokens\n",
|
| 483 |
" label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
|
| 484 |
+
" # If sample contains PII, give it higher weight\n",
|
| 485 |
" if label_name != 'o' and 'B-' in label_name:\n",
|
| 486 |
" min_weight = 10.0\n",
|
| 487 |
" break\n",
|
| 488 |
" \n",
|
| 489 |
" sample_weights.append(min_weight)\n",
|
| 490 |
" \n",
|
| 491 |
+
" # Create weighted sampler\n",
|
| 492 |
" sampler = WeightedRandomSampler(\n",
|
| 493 |
" weights=sample_weights,\n",
|
| 494 |
" num_samples=len(sample_weights),\n",
|
|
|
|
| 500 |
},
|
| 501 |
{
|
| 502 |
"cell_type": "code",
|
| 503 |
+
"execution_count": null,
|
| 504 |
"id": "69b37e68",
|
| 505 |
"metadata": {
|
| 506 |
"execution": {
|
|
|
|
| 523 |
"def print_label_distribution(data, title=\"Label Distribution\"):\n",
|
| 524 |
" \"\"\"Print label distribution statistics\"\"\"\n",
|
| 525 |
" label_counts = Counter()\n",
|
| 526 |
+
"\n",
|
| 527 |
+
" # Count each label type\n",
|
| 528 |
" for label_seq in data.labels:\n",
|
| 529 |
" for label in label_seq:\n",
|
| 530 |
" if label not in ['<pad>', '<start>', '<end>']:\n",
|
| 531 |
" label_counts[label] += 1\n",
|
| 532 |
" \n",
|
| 533 |
+
" # Print distribution\n",
|
| 534 |
" print(f\"\\n{title}:\")\n",
|
| 535 |
" print(\"-\" * 50)\n",
|
| 536 |
" total = sum(label_counts.values())\n",
|
|
|
|
| 543 |
},
|
| 544 |
{
|
| 545 |
"cell_type": "code",
|
| 546 |
+
"execution_count": null,
|
| 547 |
"id": "4b1b4f86",
|
| 548 |
"metadata": {
|
| 549 |
"execution": {
|
|
|
|
| 567 |
" \"\"\"Save model and all necessary components for deployment\"\"\"\n",
|
| 568 |
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 569 |
" \n",
|
| 570 |
+
" # Save model weights\n",
|
| 571 |
" model_path = os.path.join(save_dir, 'pii_lstm_model.pt')\n",
|
| 572 |
" torch.save(model.state_dict(), model_path)\n",
|
| 573 |
" \n",
|
|
|
|
| 593 |
},
|
| 594 |
{
|
| 595 |
"cell_type": "code",
|
| 596 |
+
"execution_count": null,
|
| 597 |
"id": "31d2f1b1",
|
| 598 |
"metadata": {
|
| 599 |
"execution": {
|
|
|
|
| 629 |
" data = pd.read_json(data_path, lines=True)\n",
|
| 630 |
" print(f\"Total samples: {len(data)}\")\n",
|
| 631 |
" \n",
|
| 632 |
+
" # Show label distribution\n",
|
| 633 |
" print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
|
| 634 |
" \n",
|
| 635 |
" # Build vocabularies\n",
|
|
|
|
| 642 |
" for labels in data.labels:\n",
|
| 643 |
" label_vocab.add_sentence(labels)\n",
|
| 644 |
" \n",
|
| 645 |
+
" # Build vocabularies from collected words\n",
|
| 646 |
" text_vocab.build()\n",
|
| 647 |
" label_vocab.build()\n",
|
| 648 |
" \n",
|
|
|
|
| 650 |
" print(f\" - Text vocabulary: {len(text_vocab):,}\")\n",
|
| 651 |
" print(f\" - Label vocabulary: {len(label_vocab)}\")\n",
|
| 652 |
" \n",
|
| 653 |
+
" # Calculate class weights for balanced loss\n",
|
| 654 |
" class_weights = calculate_class_weights(data, label_vocab)\n",
|
| 655 |
" class_weights = class_weights.to(device)\n",
|
| 656 |
" \n",
|
| 657 |
+
" # Split data into train and validation sets\n",
|
| 658 |
" X_train, X_val, y_train, y_val = train_test_split(\n",
|
| 659 |
" data.tokens.tolist(),\n",
|
| 660 |
" data.labels.tolist(),\n",
|
|
|
|
| 666 |
" print(f\" - Train samples: {len(X_train):,}\")\n",
|
| 667 |
" print(f\" - Validation samples: {len(X_val):,}\")\n",
|
| 668 |
" \n",
|
| 669 |
+
" # Create datasets\n",
|
| 670 |
" max_seq_len = 512\n",
|
| 671 |
" train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 672 |
" val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 673 |
" \n",
|
| 674 |
+
" # Create balanced sampler for training\n",
|
| 675 |
" train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
|
| 676 |
" \n",
|
| 677 |
+
" # Create data loaders\n",
|
| 678 |
" train_loader = DataLoader(\n",
|
| 679 |
" train_dataset, \n",
|
| 680 |
" batch_size=batch_size,\n",
|
|
|
|
| 703 |
" 'max_len': max_seq_len\n",
|
| 704 |
" }\n",
|
| 705 |
" \n",
|
| 706 |
+
" # Create LSTM model\n",
|
| 707 |
" print(\"\\nCreating LSTM model...\")\n",
|
| 708 |
" model = create_lstm_pii_model(**model_config).to(device)\n",
|
| 709 |
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
|
|
|
| 736 |
" min_lr=1e-6\n",
|
| 737 |
" )\n",
|
| 738 |
" \n",
|
| 739 |
+
" # Initialize metrics\n",
|
| 740 |
" f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 741 |
" f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 742 |
" \n",
|
| 743 |
+
" # Training history\n",
|
| 744 |
" train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
|
| 745 |
" best_val_f1 = 0\n",
|
| 746 |
" patience = 7\n",
|
|
|
|
| 749 |
" print(\"\\nStarting training...\")\n",
|
| 750 |
" print(\"=\" * 60)\n",
|
| 751 |
" \n",
|
| 752 |
+
" # Training loop\n",
|
| 753 |
" for epoch in range(num_epochs):\n",
|
| 754 |
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
|
| 755 |
" \n",
|
| 756 |
+
" # Train for one epoch\n",
|
| 757 |
" train_loss, train_f1 = train_epoch(\n",
|
| 758 |
" model, train_loader, optimizer, criterion, device, f1_metric_train\n",
|
| 759 |
" )\n",
|
| 760 |
+
" # Evaluate on validation set\n",
|
| 761 |
" val_loss, val_f1 = evaluate(\n",
|
| 762 |
" model, val_loader, criterion, device, f1_metric_val\n",
|
| 763 |
" )\n",
|
| 764 |
" \n",
|
| 765 |
+
" # Adjust learning rate based on validation loss\n",
|
| 766 |
" scheduler.step(val_loss)\n",
|
| 767 |
" \n",
|
| 768 |
" # Store metrics\n",
|
|
|
|
| 1319 |
}
|
| 1320 |
],
|
| 1321 |
"source": [
|
| 1322 |
+
"# Set device\n",
|
| 1323 |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 1324 |
"print(f\"Using device: {device}\")\n",
|
| 1325 |
"\n",
|
| 1326 |
+
"# Train the LSTM model\n",
|
| 1327 |
"model, text_vocab, label_vocab = train_lstm_pii_model(\n",
|
| 1328 |
" data_path='train_augmented.json',\n",
|
| 1329 |
" num_epochs=20,\n",
|
transformer.py
CHANGED
|
@@ -4,37 +4,25 @@ import torch.nn.functional as F
|
|
| 4 |
import math
|
| 5 |
|
| 6 |
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
Args:
|
| 11 |
-
q: queries (batch_size, num_heads, seq_len_q, d_k)
|
| 12 |
-
k: keys (batch_size, num_heads, seq_len_k, d_k)
|
| 13 |
-
v: values (batch_size, num_heads, seq_len_v, d_v)
|
| 14 |
-
mask: mask tensor (batch_size, 1, 1, seq_len_k) or (batch_size, 1, seq_len_q, seq_len_k)
|
| 15 |
-
dropout: dropout layer
|
| 16 |
-
|
| 17 |
-
Returns:
|
| 18 |
-
output: attended values (batch_size, num_heads, seq_len_q, d_v)
|
| 19 |
-
attention_weights: attention weights (batch_size, num_heads, seq_len_q, seq_len_k)
|
| 20 |
-
"""
|
| 21 |
d_k = q.size(-1)
|
| 22 |
|
| 23 |
-
#
|
| 24 |
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
|
| 25 |
|
| 26 |
-
#
|
| 27 |
if mask is not None:
|
| 28 |
scores = scores.masked_fill(mask == 0, float('-inf'))
|
| 29 |
|
| 30 |
-
#
|
| 31 |
attention_weights = F.softmax(scores, dim=-1)
|
| 32 |
|
| 33 |
-
# Apply dropout if
|
| 34 |
if dropout is not None:
|
| 35 |
attention_weights = dropout(attention_weights)
|
| 36 |
|
| 37 |
-
# Apply attention to values
|
| 38 |
output = torch.matmul(attention_weights, v)
|
| 39 |
|
| 40 |
return output, attention_weights
|
|
@@ -48,47 +36,42 @@ class MultiHeadAttention(nn.Module):
|
|
| 48 |
|
| 49 |
self.d_model = d_model
|
| 50 |
self.num_heads = num_heads
|
| 51 |
-
self.d_k = d_model // num_heads
|
| 52 |
|
| 53 |
-
# Linear
|
| 54 |
self.w_q = nn.Linear(d_model, d_model)
|
| 55 |
self.w_k = nn.Linear(d_model, d_model)
|
| 56 |
self.w_v = nn.Linear(d_model, d_model)
|
| 57 |
|
| 58 |
-
#
|
| 59 |
self.w_o = nn.Linear(d_model, d_model)
|
| 60 |
|
| 61 |
-
# Dropout
|
| 62 |
self.dropout = nn.Dropout(dropout)
|
| 63 |
|
| 64 |
def forward(self, query, key, value, mask=None):
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
mask: (batch_size, 1, 1, seq_len_k) or None
|
| 71 |
-
|
| 72 |
-
Returns:
|
| 73 |
-
output: (batch_size, seq_len_q, d_model)
|
| 74 |
-
attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
|
| 75 |
"""
|
| 76 |
batch_size = query.size(0)
|
| 77 |
seq_len_q = query.size(1)
|
| 78 |
seq_len_k = key.size(1)
|
| 79 |
seq_len_v = value.size(1)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
Q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
|
| 83 |
K = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
|
| 84 |
V = self.w_v(value).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
|
| 85 |
|
| 86 |
-
#
|
| 87 |
attention_output, attention_weights = scaled_dot_product_attention(
|
| 88 |
Q, K, V, mask=mask, dropout=self.dropout
|
| 89 |
)
|
| 90 |
|
| 91 |
-
#
|
| 92 |
attention_output = attention_output.transpose(1, 2).contiguous().view(
|
| 93 |
batch_size, seq_len_q, self.d_model
|
| 94 |
)
|
|
@@ -102,19 +85,14 @@ class PositionwiseFeedForward(nn.Module):
|
|
| 102 |
|
| 103 |
def __init__(self, d_model, d_ff, dropout=0.1):
|
| 104 |
super(PositionwiseFeedForward, self).__init__()
|
|
|
|
| 105 |
self.w_1 = nn.Linear(d_model, d_ff)
|
| 106 |
self.w_2 = nn.Linear(d_ff, d_model)
|
| 107 |
self.dropout = nn.Dropout(dropout)
|
| 108 |
self.activation = nn.ReLU()
|
| 109 |
|
| 110 |
def forward(self, x):
|
| 111 |
-
|
| 112 |
-
Args:
|
| 113 |
-
x: (batch_size, seq_len, d_model)
|
| 114 |
-
|
| 115 |
-
Returns:
|
| 116 |
-
output: (batch_size, seq_len, d_model)
|
| 117 |
-
"""
|
| 118 |
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
| 119 |
|
| 120 |
class EncoderLayer(nn.Module):
|
|
@@ -123,33 +101,25 @@ class EncoderLayer(nn.Module):
|
|
| 123 |
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
|
| 124 |
super(EncoderLayer, self).__init__()
|
| 125 |
|
| 126 |
-
# Multi-head attention
|
| 127 |
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
|
| 128 |
|
| 129 |
-
# Position-wise feed forward
|
| 130 |
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
|
| 131 |
|
| 132 |
-
# Layer normalization
|
| 133 |
self.norm1 = nn.LayerNorm(d_model)
|
| 134 |
self.norm2 = nn.LayerNorm(d_model)
|
| 135 |
|
| 136 |
-
# Dropout
|
| 137 |
self.dropout = nn.Dropout(dropout)
|
| 138 |
|
| 139 |
def forward(self, x, mask=None):
|
| 140 |
-
|
| 141 |
-
Args:
|
| 142 |
-
x: (batch_size, seq_len, d_model)
|
| 143 |
-
mask: (batch_size, 1, 1, seq_len) or None
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
output: (batch_size, seq_len, d_model)
|
| 147 |
-
"""
|
| 148 |
-
# Self-attention with residual connection and layer norm
|
| 149 |
attn_output, _ = self.self_attention(x, x, x, mask)
|
| 150 |
x = self.norm1(x + self.dropout(attn_output))
|
| 151 |
|
| 152 |
-
# Feed forward with residual connection and layer norm
|
| 153 |
ff_output = self.feed_forward(x)
|
| 154 |
x = self.norm2(x + self.dropout(ff_output))
|
| 155 |
|
|
@@ -161,25 +131,21 @@ class TransformerEncoder(nn.Module):
|
|
| 161 |
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
|
| 162 |
super(TransformerEncoder, self).__init__()
|
| 163 |
|
|
|
|
| 164 |
self.layers = nn.ModuleList([
|
| 165 |
EncoderLayer(d_model, num_heads, d_ff, dropout)
|
| 166 |
for _ in range(num_layers)
|
| 167 |
])
|
| 168 |
|
|
|
|
| 169 |
self.norm = nn.LayerNorm(d_model)
|
| 170 |
|
| 171 |
def forward(self, x, mask=None):
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
x: (batch_size, seq_len, d_model)
|
| 175 |
-
mask: (batch_size, 1, 1, seq_len) or None
|
| 176 |
-
|
| 177 |
-
Returns:
|
| 178 |
-
output: (batch_size, seq_len, d_model)
|
| 179 |
-
"""
|
| 180 |
for layer in self.layers:
|
| 181 |
x = layer(x, mask)
|
| 182 |
|
|
|
|
| 183 |
return self.norm(x)
|
| 184 |
|
| 185 |
class PositionalEncoding(nn.Module):
|
|
@@ -189,36 +155,29 @@ class PositionalEncoding(nn.Module):
|
|
| 189 |
super(PositionalEncoding, self).__init__()
|
| 190 |
self.dropout = nn.Dropout(dropout)
|
| 191 |
|
| 192 |
-
# Create positional
|
| 193 |
pe = torch.zeros(max_len, d_model)
|
| 194 |
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 195 |
|
| 196 |
-
# Create
|
| 197 |
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 198 |
(-math.log(10000.0) / d_model))
|
| 199 |
|
| 200 |
-
# Apply
|
| 201 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 202 |
|
| 203 |
-
# Apply
|
| 204 |
if d_model % 2 == 1:
|
| 205 |
pe[:, 1::2] = torch.cos(position * div_term[:-1])
|
| 206 |
else:
|
| 207 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 208 |
|
| 209 |
-
# Add batch dimension and
|
| 210 |
pe = pe.unsqueeze(0)
|
| 211 |
self.register_buffer('pe', pe)
|
| 212 |
|
| 213 |
def forward(self, x):
|
| 214 |
-
|
| 215 |
-
Args:
|
| 216 |
-
x: (batch_size, seq_len, d_model)
|
| 217 |
-
|
| 218 |
-
Returns:
|
| 219 |
-
output: (batch_size, seq_len, d_model)
|
| 220 |
-
"""
|
| 221 |
-
# Add positional encoding
|
| 222 |
x = x + self.pe[:, :x.size(1), :]
|
| 223 |
return self.dropout(x)
|
| 224 |
|
|
@@ -235,62 +194,46 @@ class TransformerPII(nn.Module):
|
|
| 235 |
self.d_model = d_model
|
| 236 |
self.pad_idx = pad_idx
|
| 237 |
|
| 238 |
-
#
|
| 239 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
|
| 240 |
|
| 241 |
-
#
|
| 242 |
self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
|
| 243 |
|
| 244 |
-
#
|
| 245 |
self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
|
| 246 |
|
| 247 |
-
# Classification head
|
| 248 |
self.classifier = nn.Linear(d_model, num_classes)
|
| 249 |
|
| 250 |
-
# Dropout
|
| 251 |
self.dropout = nn.Dropout(dropout)
|
| 252 |
|
| 253 |
-
# Initialize weights
|
| 254 |
self._init_weights()
|
| 255 |
|
| 256 |
def _init_weights(self):
|
| 257 |
"""Initialize model weights"""
|
| 258 |
-
# Initialize embeddings
|
| 259 |
nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
|
|
|
|
| 260 |
if self.pad_idx is not None:
|
| 261 |
nn.init.constant_(self.embedding.weight[self.pad_idx], 0)
|
| 262 |
|
| 263 |
-
# Initialize classifier
|
| 264 |
nn.init.xavier_uniform_(self.classifier.weight)
|
| 265 |
if self.classifier.bias is not None:
|
| 266 |
nn.init.constant_(self.classifier.bias, 0)
|
| 267 |
|
| 268 |
def create_padding_mask(self, x):
|
| 269 |
-
"""
|
| 270 |
-
Create padding
|
| 271 |
-
|
| 272 |
-
Args:
|
| 273 |
-
x: (batch_size, seq_len) - input token indices
|
| 274 |
-
|
| 275 |
-
Returns:
|
| 276 |
-
mask: (batch_size, 1, 1, seq_len) - attention mask
|
| 277 |
-
"""
|
| 278 |
-
# Create mask where padding tokens are marked as 0
|
| 279 |
mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
|
| 280 |
return mask.float()
|
| 281 |
|
| 282 |
def forward(self, x, mask=None):
|
| 283 |
-
"""
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
Args:
|
| 287 |
-
x: (batch_size, seq_len) - input token indices
|
| 288 |
-
mask: Optional custom attention mask
|
| 289 |
-
|
| 290 |
-
Returns:
|
| 291 |
-
logits: (batch_size, seq_len, num_classes) - classification logits
|
| 292 |
-
"""
|
| 293 |
-
# Check input dimensions
|
| 294 |
if x.dim() != 2:
|
| 295 |
raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")
|
| 296 |
|
|
@@ -300,94 +243,35 @@ class TransformerPII(nn.Module):
|
|
| 300 |
if mask is None:
|
| 301 |
mask = self.create_padding_mask(x)
|
| 302 |
|
| 303 |
-
#
|
| 304 |
x = self.embedding(x) * math.sqrt(self.d_model)
|
| 305 |
|
| 306 |
# Add positional encoding
|
| 307 |
x = self.positional_encoding(x)
|
| 308 |
|
| 309 |
-
# Pass through transformer encoder
|
| 310 |
encoder_output = self.encoder(x, mask)
|
| 311 |
|
| 312 |
# Apply dropout before classification
|
| 313 |
encoder_output = self.dropout(encoder_output)
|
| 314 |
|
| 315 |
-
#
|
| 316 |
logits = self.classifier(encoder_output)
|
| 317 |
|
| 318 |
return logits
|
| 319 |
|
| 320 |
def predict(self, x):
|
| 321 |
-
"""
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
Args:
|
| 325 |
-
x: (batch_size, seq_len) - input token indices
|
| 326 |
-
|
| 327 |
-
Returns:
|
| 328 |
-
predictions: (batch_size, seq_len) - predicted class indices
|
| 329 |
-
"""
|
| 330 |
self.eval()
|
| 331 |
with torch.no_grad():
|
| 332 |
logits = self.forward(x)
|
| 333 |
predictions = torch.argmax(logits, dim=-1)
|
| 334 |
return predictions
|
| 335 |
|
| 336 |
-
class TransformerPIIWithCRF(TransformerPII):
|
| 337 |
-
"""
|
| 338 |
-
Transformer with CRF layer for improved sequence labeling
|
| 339 |
-
(Optional enhancement - requires pytorch-crf)
|
| 340 |
-
"""
|
| 341 |
-
|
| 342 |
-
def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
|
| 343 |
-
d_ff=512, num_layers=4, dropout=0.1, max_len=512, pad_idx=0):
|
| 344 |
-
super(TransformerPIIWithCRF, self).__init__(
|
| 345 |
-
vocab_size, num_classes, d_model, num_heads,
|
| 346 |
-
d_ff, num_layers, dropout, max_len, pad_idx
|
| 347 |
-
)
|
| 348 |
-
|
| 349 |
-
# CRF layer would be initialized here
|
| 350 |
-
# from torchcrf import CRF
|
| 351 |
-
# self.crf = CRF(num_classes, batch_first=True)
|
| 352 |
-
|
| 353 |
-
def forward(self, x, labels=None):
|
| 354 |
-
"""Forward pass with optional CRF"""
|
| 355 |
-
# Get transformer outputs
|
| 356 |
-
emissions = super().forward(x)
|
| 357 |
-
|
| 358 |
-
if labels is not None:
|
| 359 |
-
# Training mode with CRF
|
| 360 |
-
# mask = (x != self.pad_idx)
|
| 361 |
-
# loss = -self.crf(emissions, labels, mask=mask)
|
| 362 |
-
# return loss
|
| 363 |
-
pass
|
| 364 |
-
else:
|
| 365 |
-
# Inference mode with CRF
|
| 366 |
-
# mask = (x != self.pad_idx)
|
| 367 |
-
# predictions = self.crf.decode(emissions, mask=mask)
|
| 368 |
-
# return predictions
|
| 369 |
-
pass
|
| 370 |
-
|
| 371 |
-
return emissions
|
| 372 |
-
|
| 373 |
def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
|
| 374 |
d_ff=512, num_layers=4, dropout=0.1, max_len=512):
|
| 375 |
-
"""
|
| 376 |
-
Factory function to create transformer model for PII detection
|
| 377 |
-
|
| 378 |
-
Args:
|
| 379 |
-
vocab_size: Size of vocabulary
|
| 380 |
-
num_classes: Number of PII classes (e.g., 20)
|
| 381 |
-
d_model: Dimension of model (hidden size)
|
| 382 |
-
num_heads: Number of attention heads
|
| 383 |
-
d_ff: Dimension of feedforward network
|
| 384 |
-
num_layers: Number of transformer layers
|
| 385 |
-
dropout: Dropout rate
|
| 386 |
-
max_len: Maximum sequence length
|
| 387 |
-
|
| 388 |
-
Returns:
|
| 389 |
-
TransformerPII model instance
|
| 390 |
-
"""
|
| 391 |
model = TransformerPII(
|
| 392 |
vocab_size=vocab_size,
|
| 393 |
num_classes=num_classes,
|
|
@@ -397,7 +281,7 @@ def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads
|
|
| 397 |
num_layers=num_layers,
|
| 398 |
dropout=dropout,
|
| 399 |
max_len=max_len,
|
| 400 |
-
pad_idx=0
|
| 401 |
)
|
| 402 |
|
| 403 |
return model
|
|
|
|
| 4 |
import math
|
| 5 |
|
| 6 |
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
|
| 7 |
+
"""Compute scaled dot-product attention."""
|
| 8 |
+
# Get dimension of keys for scaling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
d_k = q.size(-1)
|
| 10 |
|
| 11 |
+
# Compute attention scores using dot product
|
| 12 |
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
|
| 13 |
|
| 14 |
+
# Mask out padding positions if mask provided
|
| 15 |
if mask is not None:
|
| 16 |
scores = scores.masked_fill(mask == 0, float('-inf'))
|
| 17 |
|
| 18 |
+
# Convert scores to probabilities
|
| 19 |
attention_weights = F.softmax(scores, dim=-1)
|
| 20 |
|
| 21 |
+
# Apply dropout to attention weights if specified
|
| 22 |
if dropout is not None:
|
| 23 |
attention_weights = dropout(attention_weights)
|
| 24 |
|
| 25 |
+
# Apply attention weights to values
|
| 26 |
output = torch.matmul(attention_weights, v)
|
| 27 |
|
| 28 |
return output, attention_weights
|
|
|
|
| 36 |
|
| 37 |
self.d_model = d_model
|
| 38 |
self.num_heads = num_heads
|
| 39 |
+
self.d_k = d_model // num_heads # Dimension per head
|
| 40 |
|
| 41 |
+
# Linear layers for projecting Q, K, V
|
| 42 |
self.w_q = nn.Linear(d_model, d_model)
|
| 43 |
self.w_k = nn.Linear(d_model, d_model)
|
| 44 |
self.w_v = nn.Linear(d_model, d_model)
|
| 45 |
|
| 46 |
+
# Final output projection
|
| 47 |
self.w_o = nn.Linear(d_model, d_model)
|
| 48 |
|
| 49 |
+
# Dropout layer
|
| 50 |
self.dropout = nn.Dropout(dropout)
|
| 51 |
|
| 52 |
def forward(self, query, key, value, mask=None):
|
| 53 |
"""
|
| 54 |
+
query: (batch_size, seq_len_q, d_model)
|
| 55 |
+
key: (batch_size, seq_len_k, d_model)
|
| 56 |
+
value: (batch_size, seq_len_v, d_model)
|
| 57 |
+
mask: (batch_size, 1, 1, seq_len_k) or None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"""
|
| 59 |
batch_size = query.size(0)
|
| 60 |
seq_len_q = query.size(1)
|
| 61 |
seq_len_k = key.size(1)
|
| 62 |
seq_len_v = value.size(1)
|
| 63 |
|
| 64 |
+
# Project and reshape for multiple heads
|
| 65 |
Q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
|
| 66 |
K = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
|
| 67 |
V = self.w_v(value).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
|
| 68 |
|
| 69 |
+
# Apply scaled dot-product attention
|
| 70 |
attention_output, attention_weights = scaled_dot_product_attention(
|
| 71 |
Q, K, V, mask=mask, dropout=self.dropout
|
| 72 |
)
|
| 73 |
|
| 74 |
+
# Concatenate heads and apply output projection
|
| 75 |
attention_output = attention_output.transpose(1, 2).contiguous().view(
|
| 76 |
batch_size, seq_len_q, self.d_model
|
| 77 |
)
|
|
|
|
| 85 |
|
| 86 |
def __init__(self, d_model, d_ff, dropout=0.1):
|
| 87 |
super(PositionwiseFeedForward, self).__init__()
|
| 88 |
+
# Two linear layers with ReLU activation
|
| 89 |
self.w_1 = nn.Linear(d_model, d_ff)
|
| 90 |
self.w_2 = nn.Linear(d_ff, d_model)
|
| 91 |
self.dropout = nn.Dropout(dropout)
|
| 92 |
self.activation = nn.ReLU()
|
| 93 |
|
| 94 |
def forward(self, x):
|
| 95 |
+
# Apply first linear layer, activation, dropout, then second linear layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
| 97 |
|
| 98 |
class EncoderLayer(nn.Module):
|
|
|
|
| 101 |
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
|
| 102 |
super(EncoderLayer, self).__init__()
|
| 103 |
|
| 104 |
+
# Multi-head self-attention sublayer
|
| 105 |
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
|
| 106 |
|
| 107 |
+
# Position-wise feed forward sublayer
|
| 108 |
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
|
| 109 |
|
| 110 |
+
# Layer normalization for each sublayer
|
| 111 |
self.norm1 = nn.LayerNorm(d_model)
|
| 112 |
self.norm2 = nn.LayerNorm(d_model)
|
| 113 |
|
| 114 |
+
# Dropout for residual connections
|
| 115 |
self.dropout = nn.Dropout(dropout)
|
| 116 |
|
| 117 |
def forward(self, x, mask=None):
|
| 118 |
+
# Self-attention sublayer with residual connection and layer norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
attn_output, _ = self.self_attention(x, x, x, mask)
|
| 120 |
x = self.norm1(x + self.dropout(attn_output))
|
| 121 |
|
| 122 |
+
# Feed forward sublayer with residual connection and layer norm
|
| 123 |
ff_output = self.feed_forward(x)
|
| 124 |
x = self.norm2(x + self.dropout(ff_output))
|
| 125 |
|
|
|
|
| 131 |
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
|
| 132 |
super(TransformerEncoder, self).__init__()
|
| 133 |
|
| 134 |
+
# Create stack of encoder layers
|
| 135 |
self.layers = nn.ModuleList([
|
| 136 |
EncoderLayer(d_model, num_heads, d_ff, dropout)
|
| 137 |
for _ in range(num_layers)
|
| 138 |
])
|
| 139 |
|
| 140 |
+
# Final layer normalization
|
| 141 |
self.norm = nn.LayerNorm(d_model)
|
| 142 |
|
| 143 |
def forward(self, x, mask=None):
|
| 144 |
+
# Pass through each encoder layer sequentially
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
for layer in self.layers:
|
| 146 |
x = layer(x, mask)
|
| 147 |
|
| 148 |
+
# Apply final normalization
|
| 149 |
return self.norm(x)
|
| 150 |
|
| 151 |
class PositionalEncoding(nn.Module):
|
|
|
|
| 155 |
super(PositionalEncoding, self).__init__()
|
| 156 |
self.dropout = nn.Dropout(dropout)
|
| 157 |
|
| 158 |
+
# Create matrix to hold positional encodings
|
| 159 |
pe = torch.zeros(max_len, d_model)
|
| 160 |
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 161 |
|
| 162 |
+
# Create frequency terms for sin/cos functions
|
| 163 |
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 164 |
(-math.log(10000.0) / d_model))
|
| 165 |
|
| 166 |
+
# Apply sine to even indices
|
| 167 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 168 |
|
| 169 |
+
# Apply cosine to odd indices
|
| 170 |
if d_model % 2 == 1:
|
| 171 |
pe[:, 1::2] = torch.cos(position * div_term[:-1])
|
| 172 |
else:
|
| 173 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 174 |
|
| 175 |
+
# Add batch dimension and save as buffer
|
| 176 |
pe = pe.unsqueeze(0)
|
| 177 |
self.register_buffer('pe', pe)
|
| 178 |
|
| 179 |
def forward(self, x):
|
| 180 |
+
# Add positional encoding to input embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
x = x + self.pe[:, :x.size(1), :]
|
| 182 |
return self.dropout(x)
|
| 183 |
|
|
|
|
| 194 |
self.d_model = d_model
|
| 195 |
self.pad_idx = pad_idx
|
| 196 |
|
| 197 |
+
# Embedding layer for input tokens
|
| 198 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
|
| 199 |
|
| 200 |
+
# Add positional information to embeddings
|
| 201 |
self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
|
| 202 |
|
| 203 |
+
# Stack of transformer encoder layers
|
| 204 |
self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
|
| 205 |
|
| 206 |
+
# Classification head for token-level predictions
|
| 207 |
self.classifier = nn.Linear(d_model, num_classes)
|
| 208 |
|
| 209 |
+
# Dropout layer
|
| 210 |
self.dropout = nn.Dropout(dropout)
|
| 211 |
|
| 212 |
+
# Initialize model weights
|
| 213 |
self._init_weights()
|
| 214 |
|
| 215 |
def _init_weights(self):
|
| 216 |
"""Initialize model weights"""
|
| 217 |
+
# Initialize embeddings with normal distribution
|
| 218 |
nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
|
| 219 |
+
# Set padding token embedding to zero
|
| 220 |
if self.pad_idx is not None:
|
| 221 |
nn.init.constant_(self.embedding.weight[self.pad_idx], 0)
|
| 222 |
|
| 223 |
+
# Initialize classifier with Xavier uniform
|
| 224 |
nn.init.xavier_uniform_(self.classifier.weight)
|
| 225 |
if self.classifier.bias is not None:
|
| 226 |
nn.init.constant_(self.classifier.bias, 0)
|
| 227 |
|
| 228 |
def create_padding_mask(self, x):
|
| 229 |
+
"""Create padding mask for attention"""
|
| 230 |
+
# Create mask where non-padding tokens are marked as 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
|
| 232 |
return mask.float()
|
| 233 |
|
| 234 |
def forward(self, x, mask=None):
|
| 235 |
+
"""Forward pass for token classification"""
|
| 236 |
+
# Validate input dimensions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
if x.dim() != 2:
|
| 238 |
raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")
|
| 239 |
|
|
|
|
| 243 |
if mask is None:
|
| 244 |
mask = self.create_padding_mask(x)
|
| 245 |
|
| 246 |
+
# Embed and scale by sqrt(d_model)
|
| 247 |
x = self.embedding(x) * math.sqrt(self.d_model)
|
| 248 |
|
| 249 |
# Add positional encoding
|
| 250 |
x = self.positional_encoding(x)
|
| 251 |
|
| 252 |
+
# Pass through transformer encoder stack
|
| 253 |
encoder_output = self.encoder(x, mask)
|
| 254 |
|
| 255 |
# Apply dropout before classification
|
| 256 |
encoder_output = self.dropout(encoder_output)
|
| 257 |
|
| 258 |
+
# Get class predictions for each token
|
| 259 |
logits = self.classifier(encoder_output)
|
| 260 |
|
| 261 |
return logits
|
| 262 |
|
| 263 |
def predict(self, x):
|
| 264 |
+
"""Get predictions for inference"""
|
| 265 |
+
# Switch to evaluation mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
self.eval()
|
| 267 |
with torch.no_grad():
|
| 268 |
logits = self.forward(x)
|
| 269 |
predictions = torch.argmax(logits, dim=-1)
|
| 270 |
return predictions
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
|
| 273 |
d_ff=512, num_layers=4, dropout=0.1, max_len=512):
|
| 274 |
+
"""Factory function to create transformer model for PII detection"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
model = TransformerPII(
|
| 276 |
vocab_size=vocab_size,
|
| 277 |
num_classes=num_classes,
|
|
|
|
| 281 |
num_layers=num_layers,
|
| 282 |
dropout=dropout,
|
| 283 |
max_len=max_len,
|
| 284 |
+
pad_idx=0
|
| 285 |
)
|
| 286 |
|
| 287 |
return model
|
transformer_training.ipynb
CHANGED
|
@@ -42,7 +42,7 @@
|
|
| 42 |
},
|
| 43 |
{
|
| 44 |
"cell_type": "code",
|
| 45 |
-
"execution_count":
|
| 46 |
"id": "ff1782dd",
|
| 47 |
"metadata": {
|
| 48 |
"execution": {
|
|
@@ -62,19 +62,23 @@
|
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
|
|
|
| 65 |
"class Vocabulary:\n",
|
| 66 |
" \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
|
| 67 |
" def __init__(self, max_size=100000):\n",
|
|
|
|
| 68 |
" self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
|
| 69 |
" self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
|
| 70 |
" self.word_count = Counter()\n",
|
| 71 |
" self.max_size = max_size\n",
|
| 72 |
" \n",
|
| 73 |
" def add_sentence(self, sentence):\n",
|
|
|
|
| 74 |
" for word in sentence:\n",
|
| 75 |
" self.word_count[word.lower()] += 1\n",
|
| 76 |
" \n",
|
| 77 |
" def build(self):\n",
|
|
|
|
| 78 |
" most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
|
| 79 |
" for word, _ in most_common:\n",
|
| 80 |
" if word not in self.word2idx:\n",
|
|
@@ -86,15 +90,17 @@
|
|
| 86 |
" return len(self.word2idx)\n",
|
| 87 |
" \n",
|
| 88 |
" def encode(self, sentence):\n",
|
|
|
|
| 89 |
" return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
|
| 90 |
" \n",
|
| 91 |
" def decode(self, indices):\n",
|
|
|
|
| 92 |
" return [self.idx2word.get(idx, '<unk>') for idx in indices]"
|
| 93 |
]
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
-
"execution_count":
|
| 98 |
"id": "5b2b46d6",
|
| 99 |
"metadata": {
|
| 100 |
"execution": {
|
|
@@ -114,6 +120,7 @@
|
|
| 114 |
},
|
| 115 |
"outputs": [],
|
| 116 |
"source": [
|
|
|
|
| 117 |
"class PIIDataset(Dataset):\n",
|
| 118 |
" \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
|
| 119 |
" def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
|
|
@@ -127,16 +134,16 @@
|
|
| 127 |
" return len(self.tokens)\n",
|
| 128 |
" \n",
|
| 129 |
" def __getitem__(self, idx):\n",
|
| 130 |
-
" # Add
|
| 131 |
" tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
|
| 132 |
" labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
|
| 133 |
" \n",
|
| 134 |
-
" # Truncate if too long\n",
|
| 135 |
" if len(tokens) > self.max_len:\n",
|
| 136 |
" tokens = tokens[:self.max_len-1] + ['<end>']\n",
|
| 137 |
" labels = labels[:self.max_len-1] + ['<end>']\n",
|
| 138 |
" \n",
|
| 139 |
-
" #
|
| 140 |
" token_ids = self.text_vocab.encode(tokens)\n",
|
| 141 |
" label_ids = self.label_vocab.encode(labels)\n",
|
| 142 |
" \n",
|
|
@@ -145,7 +152,7 @@
|
|
| 145 |
},
|
| 146 |
{
|
| 147 |
"cell_type": "code",
|
| 148 |
-
"execution_count":
|
| 149 |
"id": "e7ca8f8f",
|
| 150 |
"metadata": {
|
| 151 |
"execution": {
|
|
@@ -167,7 +174,9 @@
|
|
| 167 |
"source": [
|
| 168 |
"def collate_fn(batch):\n",
|
| 169 |
" \"\"\"Custom collate function for padding sequences\"\"\"\n",
|
|
|
|
| 170 |
" tokens, labels = zip(*batch)\n",
|
|
|
|
| 171 |
" tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
|
| 172 |
" labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
|
| 173 |
" return tokens_padded, labels_padded"
|
|
@@ -175,7 +184,7 @@
|
|
| 175 |
},
|
| 176 |
{
|
| 177 |
"cell_type": "code",
|
| 178 |
-
"execution_count":
|
| 179 |
"id": "85b32e21",
|
| 180 |
"metadata": {
|
| 181 |
"execution": {
|
|
@@ -195,6 +204,7 @@
|
|
| 195 |
},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
|
|
|
| 198 |
"class F1ScoreMetric:\n",
|
| 199 |
" \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
|
| 200 |
" def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
|
|
@@ -205,30 +215,37 @@
|
|
| 205 |
" self.reset()\n",
|
| 206 |
" \n",
|
| 207 |
" def reset(self):\n",
|
|
|
|
| 208 |
" self.true_positives = 0\n",
|
| 209 |
" self.false_positives = 0\n",
|
| 210 |
" self.false_negatives = 0\n",
|
| 211 |
" self.class_metrics = {}\n",
|
| 212 |
" \n",
|
| 213 |
" def update(self, predictions, targets):\n",
|
|
|
|
| 214 |
" mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
|
| 215 |
" o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
|
| 216 |
" \n",
|
|
|
|
| 217 |
" for class_id in range(1, self.num_classes):\n",
|
| 218 |
" if class_id == o_idx:\n",
|
| 219 |
" continue\n",
|
| 220 |
" \n",
|
|
|
|
| 221 |
" pred_mask = (predictions == class_id) & mask\n",
|
| 222 |
" true_mask = (targets == class_id) & mask\n",
|
| 223 |
" \n",
|
|
|
|
| 224 |
" tp = ((pred_mask) & (true_mask)).sum().item()\n",
|
| 225 |
" fp = ((pred_mask) & (~true_mask)).sum().item()\n",
|
| 226 |
" fn = ((~pred_mask) & (true_mask)).sum().item()\n",
|
| 227 |
" \n",
|
|
|
|
| 228 |
" self.true_positives += tp\n",
|
| 229 |
" self.false_positives += fp\n",
|
| 230 |
" self.false_negatives += fn\n",
|
| 231 |
" \n",
|
|
|
|
| 232 |
" if class_id not in self.class_metrics:\n",
|
| 233 |
" self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
|
| 234 |
" self.class_metrics[class_id]['tp'] += tp\n",
|
|
@@ -236,6 +253,7 @@
|
|
| 236 |
" self.class_metrics[class_id]['fn'] += fn\n",
|
| 237 |
" \n",
|
| 238 |
" def compute(self):\n",
|
|
|
|
| 239 |
" beta_squared = self.beta ** 2\n",
|
| 240 |
" precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
|
| 241 |
" recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
|
|
@@ -243,6 +261,7 @@
|
|
| 243 |
" return f1\n",
|
| 244 |
" \n",
|
| 245 |
" def get_class_metrics(self):\n",
|
|
|
|
| 246 |
" results = {}\n",
|
| 247 |
" for class_id, metrics in self.class_metrics.items():\n",
|
| 248 |
" if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
|
|
@@ -261,7 +280,7 @@
|
|
| 261 |
},
|
| 262 |
{
|
| 263 |
"cell_type": "code",
|
| 264 |
-
"execution_count":
|
| 265 |
"id": "60cf16eb",
|
| 266 |
"metadata": {
|
| 267 |
"execution": {
|
|
@@ -281,6 +300,7 @@
|
|
| 281 |
},
|
| 282 |
"outputs": [],
|
| 283 |
"source": [
|
|
|
|
| 284 |
"class FocalLoss(nn.Module):\n",
|
| 285 |
" \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
|
| 286 |
" def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
|
|
@@ -291,6 +311,7 @@
|
|
| 291 |
" self.ignore_index = ignore_index\n",
|
| 292 |
" \n",
|
| 293 |
" def forward(self, inputs, targets):\n",
|
|
|
|
| 294 |
" ce_loss = nn.functional.cross_entropy(\n",
|
| 295 |
" inputs, targets, \n",
|
| 296 |
" weight=self.alpha, \n",
|
|
@@ -298,9 +319,11 @@
|
|
| 298 |
" ignore_index=self.ignore_index\n",
|
| 299 |
" )\n",
|
| 300 |
" \n",
|
|
|
|
| 301 |
" pt = torch.exp(-ce_loss)\n",
|
| 302 |
" focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
|
| 303 |
" \n",
|
|
|
|
| 304 |
" if self.reduction == 'mean':\n",
|
| 305 |
" return focal_loss.mean()\n",
|
| 306 |
" elif self.reduction == 'sum':\n",
|
|
@@ -311,7 +334,7 @@
|
|
| 311 |
},
|
| 312 |
{
|
| 313 |
"cell_type": "code",
|
| 314 |
-
"execution_count":
|
| 315 |
"id": "4e56747c",
|
| 316 |
"metadata": {
|
| 317 |
"execution": {
|
|
@@ -337,8 +360,10 @@
|
|
| 337 |
" total_loss = 0\n",
|
| 338 |
" f1_metric.reset()\n",
|
| 339 |
" \n",
|
|
|
|
| 340 |
" progress_bar = tqdm(dataloader, desc='Training')\n",
|
| 341 |
" for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
|
|
|
|
| 342 |
" tokens = tokens.to(device)\n",
|
| 343 |
" labels = labels.to(device)\n",
|
| 344 |
" \n",
|
|
@@ -350,9 +375,10 @@
|
|
| 350 |
" outputs_flat = outputs.view(-1, outputs.size(-1))\n",
|
| 351 |
" labels_flat = labels.view(-1)\n",
|
| 352 |
" \n",
|
| 353 |
-
" # Calculate loss and
|
| 354 |
" loss = criterion(outputs_flat, labels_flat)\n",
|
| 355 |
" loss.backward()\n",
|
|
|
|
| 356 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
|
| 357 |
" optimizer.step()\n",
|
| 358 |
" \n",
|
|
@@ -372,7 +398,7 @@
|
|
| 372 |
},
|
| 373 |
{
|
| 374 |
"cell_type": "code",
|
| 375 |
-
"execution_count":
|
| 376 |
"id": "8a2e8d19",
|
| 377 |
"metadata": {
|
| 378 |
"execution": {
|
|
@@ -398,8 +424,10 @@
|
|
| 398 |
" total_loss = 0\n",
|
| 399 |
" f1_metric.reset()\n",
|
| 400 |
" \n",
|
|
|
|
| 401 |
" with torch.no_grad():\n",
|
| 402 |
" for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
|
|
|
|
| 403 |
" tokens = tokens.to(device)\n",
|
| 404 |
" labels = labels.to(device)\n",
|
| 405 |
" \n",
|
|
@@ -421,7 +449,7 @@
|
|
| 421 |
},
|
| 422 |
{
|
| 423 |
"cell_type": "code",
|
| 424 |
-
"execution_count":
|
| 425 |
"id": "6e292ace",
|
| 426 |
"metadata": {
|
| 427 |
"execution": {
|
|
@@ -445,6 +473,7 @@
|
|
| 445 |
" \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
|
| 446 |
" sample_weights = []\n",
|
| 447 |
" \n",
|
|
|
|
| 448 |
" for idx in range(len(dataset)):\n",
|
| 449 |
" _, labels = dataset[idx]\n",
|
| 450 |
" \n",
|
|
@@ -453,24 +482,26 @@
|
|
| 453 |
" for label_id in labels:\n",
|
| 454 |
" if label_id > 3: # Skip special tokens\n",
|
| 455 |
" label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
|
|
|
|
| 456 |
" if label_name != 'o' and 'B-' in label_name:\n",
|
| 457 |
" min_weight = 10.0\n",
|
| 458 |
" break\n",
|
| 459 |
" \n",
|
| 460 |
" sample_weights.append(min_weight)\n",
|
| 461 |
" \n",
|
|
|
|
| 462 |
" sampler = WeightedRandomSampler(\n",
|
| 463 |
" weights=sample_weights,\n",
|
| 464 |
" num_samples=len(sample_weights),\n",
|
| 465 |
" replacement=True\n",
|
| 466 |
" )\n",
|
| 467 |
" \n",
|
| 468 |
-
" return sampler
|
| 469 |
]
|
| 470 |
},
|
| 471 |
{
|
| 472 |
"cell_type": "code",
|
| 473 |
-
"execution_count":
|
| 474 |
"id": "857335cb",
|
| 475 |
"metadata": {
|
| 476 |
"execution": {
|
|
@@ -493,11 +524,14 @@
|
|
| 493 |
"def print_label_distribution(data, title=\"Label Distribution\"):\n",
|
| 494 |
" \"\"\"Print label distribution statistics\"\"\"\n",
|
| 495 |
" label_counts = Counter()\n",
|
|
|
|
|
|
|
| 496 |
" for label_seq in data.labels:\n",
|
| 497 |
" for label in label_seq:\n",
|
| 498 |
" if label not in ['<pad>', '<start>', '<end>']:\n",
|
| 499 |
" label_counts[label] += 1\n",
|
| 500 |
" \n",
|
|
|
|
| 501 |
" print(f\"\\n{title}:\")\n",
|
| 502 |
" print(\"-\" * 50)\n",
|
| 503 |
" total = sum(label_counts.values())\n",
|
|
@@ -510,7 +544,7 @@
|
|
| 510 |
},
|
| 511 |
{
|
| 512 |
"cell_type": "code",
|
| 513 |
-
"execution_count":
|
| 514 |
"id": "1738f8a9",
|
| 515 |
"metadata": {
|
| 516 |
"execution": {
|
|
@@ -532,9 +566,10 @@
|
|
| 532 |
"source": [
|
| 533 |
"def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
|
| 534 |
" \"\"\"Save model and all necessary components for Flask deployment\"\"\"\n",
|
|
|
|
| 535 |
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 536 |
" \n",
|
| 537 |
-
" # Save model
|
| 538 |
" model_path = os.path.join(save_dir, 'pii_transformer_model.pt')\n",
|
| 539 |
" torch.save(model.state_dict(), model_path)\n",
|
| 540 |
" \n",
|
|
@@ -560,7 +595,7 @@
|
|
| 560 |
},
|
| 561 |
{
|
| 562 |
"cell_type": "code",
|
| 563 |
-
"execution_count":
|
| 564 |
"id": "d93e7c25",
|
| 565 |
"metadata": {
|
| 566 |
"execution": {
|
|
@@ -591,12 +626,12 @@
|
|
| 591 |
"):\n",
|
| 592 |
" \"\"\"Main training function\"\"\"\n",
|
| 593 |
" \n",
|
| 594 |
-
" # Load data\n",
|
| 595 |
" print(\"Loading augmented data...\")\n",
|
| 596 |
" data = pd.read_json(data_path, lines=True)\n",
|
| 597 |
" print(f\"Total samples: {len(data)}\")\n",
|
| 598 |
" \n",
|
| 599 |
-
" #
|
| 600 |
" print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
|
| 601 |
" \n",
|
| 602 |
" # Build vocabularies\n",
|
|
@@ -604,19 +639,21 @@
|
|
| 604 |
" text_vocab = Vocabulary(max_size=100000)\n",
|
| 605 |
" label_vocab = Vocabulary(max_size=50)\n",
|
| 606 |
" \n",
|
|
|
|
| 607 |
" for tokens in data.tokens:\n",
|
| 608 |
" text_vocab.add_sentence(tokens)\n",
|
| 609 |
" for labels in data.labels:\n",
|
| 610 |
" label_vocab.add_sentence(labels)\n",
|
| 611 |
" \n",
|
|
|
|
| 612 |
" text_vocab.build()\n",
|
| 613 |
" label_vocab.build()\n",
|
| 614 |
" \n",
|
| 615 |
-
" # Calculate class weights\n",
|
| 616 |
" class_weights = calculate_class_weights(data, label_vocab)\n",
|
| 617 |
" class_weights = class_weights.to(device)\n",
|
| 618 |
" \n",
|
| 619 |
-
" # Split data\n",
|
| 620 |
" X_train, X_val, y_train, y_val = train_test_split(\n",
|
| 621 |
" data.tokens.tolist(),\n",
|
| 622 |
" data.labels.tolist(),\n",
|
|
@@ -628,13 +665,15 @@
|
|
| 628 |
" print(f\" - Train samples: {len(X_train):,}\")\n",
|
| 629 |
" print(f\" - Validation samples: {len(X_val):,}\")\n",
|
| 630 |
" \n",
|
| 631 |
-
" # Create datasets
|
| 632 |
" max_seq_len = 512\n",
|
| 633 |
" train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 634 |
" val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 635 |
" \n",
|
|
|
|
| 636 |
" train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
|
| 637 |
" \n",
|
|
|
|
| 638 |
" train_loader = DataLoader(\n",
|
| 639 |
" train_dataset, \n",
|
| 640 |
" batch_size=batch_size,\n",
|
|
@@ -663,7 +702,7 @@
|
|
| 663 |
" 'max_len': max_seq_len\n",
|
| 664 |
" }\n",
|
| 665 |
" \n",
|
| 666 |
-
" # Create model\n",
|
| 667 |
" print(\"\\nCreating model...\")\n",
|
| 668 |
" model = create_transformer_pii_model(**model_config).to(device)\n",
|
| 669 |
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
|
@@ -678,15 +717,15 @@
|
|
| 678 |
" else:\n",
|
| 679 |
" criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
|
| 680 |
" \n",
|
| 681 |
-
" # Setup optimizer and scheduler\n",
|
| 682 |
" optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
|
| 683 |
" scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)\n",
|
| 684 |
" \n",
|
| 685 |
-
" #
|
| 686 |
" f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 687 |
" f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 688 |
" \n",
|
| 689 |
-
" # Training
|
| 690 |
" train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
|
| 691 |
" best_val_f1 = 0\n",
|
| 692 |
" patience = 5\n",
|
|
@@ -695,18 +734,21 @@
|
|
| 695 |
" print(\"\\nStarting training...\")\n",
|
| 696 |
" print(\"=\" * 60)\n",
|
| 697 |
" \n",
|
|
|
|
| 698 |
" for epoch in range(num_epochs):\n",
|
| 699 |
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
|
| 700 |
" \n",
|
| 701 |
-
" # Train
|
| 702 |
" train_loss, train_f1 = train_epoch(\n",
|
| 703 |
" model, train_loader, optimizer, criterion, device, f1_metric_train\n",
|
| 704 |
" )\n",
|
|
|
|
|
|
|
| 705 |
" val_loss, val_f1 = evaluate(\n",
|
| 706 |
" model, val_loader, criterion, device, f1_metric_val\n",
|
| 707 |
" )\n",
|
| 708 |
" \n",
|
| 709 |
-
" #
|
| 710 |
" scheduler.step(val_loss)\n",
|
| 711 |
" \n",
|
| 712 |
" # Store metrics\n",
|
|
@@ -743,7 +785,7 @@
|
|
| 743 |
" else:\n",
|
| 744 |
" patience_counter += 1\n",
|
| 745 |
" \n",
|
| 746 |
-
" # Early stopping\n",
|
| 747 |
" if patience_counter >= patience and epoch > 10:\n",
|
| 748 |
" print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
|
| 749 |
" break\n",
|
|
@@ -751,6 +793,7 @@
|
|
| 751 |
" # Plot training curves\n",
|
| 752 |
" plt.figure(figsize=(12, 5))\n",
|
| 753 |
" \n",
|
|
|
|
| 754 |
" plt.subplot(1, 2, 1)\n",
|
| 755 |
" plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
|
| 756 |
" plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
|
|
@@ -760,6 +803,7 @@
|
|
| 760 |
" plt.legend()\n",
|
| 761 |
" plt.grid(True, alpha=0.3)\n",
|
| 762 |
" \n",
|
|
|
|
| 763 |
" plt.subplot(1, 2, 2)\n",
|
| 764 |
" plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
|
| 765 |
" plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
|
|
@@ -777,6 +821,7 @@
|
|
| 777 |
" print(f\"Training completed!\")\n",
|
| 778 |
" print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
|
| 779 |
" \n",
|
|
|
|
| 780 |
" save_model(model, text_vocab, label_vocab, model_config, 'saved_transformer_model')\n",
|
| 781 |
" \n",
|
| 782 |
" return model, text_vocab, label_vocab"
|
|
@@ -1251,9 +1296,11 @@
|
|
| 1251 |
}
|
| 1252 |
],
|
| 1253 |
"source": [
|
|
|
|
| 1254 |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 1255 |
"print(f\"Using device: {device}\")\n",
|
| 1256 |
"\n",
|
|
|
|
| 1257 |
"model, text_vocab, label_vocab = train_transformer_pii_model(\n",
|
| 1258 |
" data_path='train_augmented.json',\n",
|
| 1259 |
" num_epochs=20,\n",
|
|
|
|
| 42 |
},
|
| 43 |
{
|
| 44 |
"cell_type": "code",
|
| 45 |
+
"execution_count": null,
|
| 46 |
"id": "ff1782dd",
|
| 47 |
"metadata": {
|
| 48 |
"execution": {
|
|
|
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
| 65 |
+
"# Vocabulary class for text and label encoding\n",
|
| 66 |
"class Vocabulary:\n",
|
| 67 |
" \"\"\"Vocabulary class for encoding/decoding text and labels\"\"\"\n",
|
| 68 |
" def __init__(self, max_size=100000):\n",
|
| 69 |
+
" # Initialize special tokens\n",
|
| 70 |
" self.word2idx = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}\n",
|
| 71 |
" self.idx2word = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}\n",
|
| 72 |
" self.word_count = Counter()\n",
|
| 73 |
" self.max_size = max_size\n",
|
| 74 |
" \n",
|
| 75 |
" def add_sentence(self, sentence):\n",
|
| 76 |
+
" # Count word frequencies\n",
|
| 77 |
" for word in sentence:\n",
|
| 78 |
" self.word_count[word.lower()] += 1\n",
|
| 79 |
" \n",
|
| 80 |
" def build(self):\n",
|
| 81 |
+
" # Build vocabulary from most common words\n",
|
| 82 |
" most_common = self.word_count.most_common(self.max_size - len(self.word2idx))\n",
|
| 83 |
" for word, _ in most_common:\n",
|
| 84 |
" if word not in self.word2idx:\n",
|
|
|
|
| 90 |
" return len(self.word2idx)\n",
|
| 91 |
" \n",
|
| 92 |
" def encode(self, sentence):\n",
|
| 93 |
+
" # Convert words to indices\n",
|
| 94 |
" return [self.word2idx.get(word.lower(), self.word2idx['<unk>']) for word in sentence]\n",
|
| 95 |
" \n",
|
| 96 |
" def decode(self, indices):\n",
|
| 97 |
+
" # Convert indices back to words\n",
|
| 98 |
" return [self.idx2word.get(idx, '<unk>') for idx in indices]"
|
| 99 |
]
|
| 100 |
},
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
| 103 |
+
"execution_count": null,
|
| 104 |
"id": "5b2b46d6",
|
| 105 |
"metadata": {
|
| 106 |
"execution": {
|
|
|
|
| 120 |
},
|
| 121 |
"outputs": [],
|
| 122 |
"source": [
|
| 123 |
+
"# Dataset class for PII detection\n",
|
| 124 |
"class PIIDataset(Dataset):\n",
|
| 125 |
" \"\"\"PyTorch Dataset for PII detection\"\"\"\n",
|
| 126 |
" def __init__(self, tokens, labels, text_vocab, label_vocab, max_len=512):\n",
|
|
|
|
| 134 |
" return len(self.tokens)\n",
|
| 135 |
" \n",
|
| 136 |
" def __getitem__(self, idx):\n",
|
| 137 |
+
" # Add special tokens at beginning and end\n",
|
| 138 |
" tokens = ['<start>'] + self.tokens[idx] + ['<end>']\n",
|
| 139 |
" labels = ['<start>'] + self.labels[idx] + ['<end>']\n",
|
| 140 |
" \n",
|
| 141 |
+
" # Truncate if sequence is too long\n",
|
| 142 |
" if len(tokens) > self.max_len:\n",
|
| 143 |
" tokens = tokens[:self.max_len-1] + ['<end>']\n",
|
| 144 |
" labels = labels[:self.max_len-1] + ['<end>']\n",
|
| 145 |
" \n",
|
| 146 |
+
" # Convert to indices\n",
|
| 147 |
" token_ids = self.text_vocab.encode(tokens)\n",
|
| 148 |
" label_ids = self.label_vocab.encode(labels)\n",
|
| 149 |
" \n",
|
|
|
|
| 152 |
},
|
| 153 |
{
|
| 154 |
"cell_type": "code",
|
| 155 |
+
"execution_count": null,
|
| 156 |
"id": "e7ca8f8f",
|
| 157 |
"metadata": {
|
| 158 |
"execution": {
|
|
|
|
| 174 |
"source": [
|
| 175 |
"def collate_fn(batch):\n",
|
| 176 |
" \"\"\"Custom collate function for padding sequences\"\"\"\n",
|
| 177 |
+
" # Separate tokens and labels\n",
|
| 178 |
" tokens, labels = zip(*batch)\n",
|
| 179 |
+
" # Pad sequences to same length in batch\n",
|
| 180 |
" tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=0)\n",
|
| 181 |
" labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)\n",
|
| 182 |
" return tokens_padded, labels_padded"
|
|
|
|
| 184 |
},
|
| 185 |
{
|
| 186 |
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
"id": "85b32e21",
|
| 189 |
"metadata": {
|
| 190 |
"execution": {
|
|
|
|
| 204 |
},
|
| 205 |
"outputs": [],
|
| 206 |
"source": [
|
| 207 |
+
"# F1 score metric for evaluation\n",
|
| 208 |
"class F1ScoreMetric:\n",
|
| 209 |
" \"\"\"Custom F1 score metric with beta parameter\"\"\"\n",
|
| 210 |
" def __init__(self, beta=5, num_classes=20, ignore_index=0, label_vocab=None):\n",
|
|
|
|
| 215 |
" self.reset()\n",
|
| 216 |
" \n",
|
| 217 |
" def reset(self):\n",
|
| 218 |
+
" # Reset all counters\n",
|
| 219 |
" self.true_positives = 0\n",
|
| 220 |
" self.false_positives = 0\n",
|
| 221 |
" self.false_negatives = 0\n",
|
| 222 |
" self.class_metrics = {}\n",
|
| 223 |
" \n",
|
| 224 |
" def update(self, predictions, targets):\n",
|
| 225 |
+
" # Create mask to ignore padding and special tokens\n",
|
| 226 |
" mask = (targets != self.ignore_index) & (targets != 2) & (targets != 3)\n",
|
| 227 |
" o_idx = self.label_vocab.word2idx.get('o', -1) if self.label_vocab else -1\n",
|
| 228 |
" \n",
|
| 229 |
+
" # Calculate metrics for each PII class\n",
|
| 230 |
" for class_id in range(1, self.num_classes):\n",
|
| 231 |
" if class_id == o_idx:\n",
|
| 232 |
" continue\n",
|
| 233 |
" \n",
|
| 234 |
+
" # Find where predictions and targets match this class\n",
|
| 235 |
" pred_mask = (predictions == class_id) & mask\n",
|
| 236 |
" true_mask = (targets == class_id) & mask\n",
|
| 237 |
" \n",
|
| 238 |
+
" # Count true positives, false positives, false negatives\n",
|
| 239 |
" tp = ((pred_mask) & (true_mask)).sum().item()\n",
|
| 240 |
" fp = ((pred_mask) & (~true_mask)).sum().item()\n",
|
| 241 |
" fn = ((~pred_mask) & (true_mask)).sum().item()\n",
|
| 242 |
" \n",
|
| 243 |
+
" # Update total counts\n",
|
| 244 |
" self.true_positives += tp\n",
|
| 245 |
" self.false_positives += fp\n",
|
| 246 |
" self.false_negatives += fn\n",
|
| 247 |
" \n",
|
| 248 |
+
" # Store per-class metrics\n",
|
| 249 |
" if class_id not in self.class_metrics:\n",
|
| 250 |
" self.class_metrics[class_id] = {'tp': 0, 'fp': 0, 'fn': 0}\n",
|
| 251 |
" self.class_metrics[class_id]['tp'] += tp\n",
|
|
|
|
| 253 |
" self.class_metrics[class_id]['fn'] += fn\n",
|
| 254 |
" \n",
|
| 255 |
" def compute(self):\n",
|
| 256 |
+
" # Calculate F-beta score\n",
|
| 257 |
" beta_squared = self.beta ** 2\n",
|
| 258 |
" precision = self.true_positives / (self.true_positives + self.false_positives + 1e-8)\n",
|
| 259 |
" recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-8)\n",
|
|
|
|
| 261 |
" return f1\n",
|
| 262 |
" \n",
|
| 263 |
" def get_class_metrics(self):\n",
|
| 264 |
+
" # Get detailed metrics for each class\n",
|
| 265 |
" results = {}\n",
|
| 266 |
" for class_id, metrics in self.class_metrics.items():\n",
|
| 267 |
" if self.label_vocab and class_id in self.label_vocab.idx2word:\n",
|
|
|
|
| 280 |
},
|
| 281 |
{
|
| 282 |
"cell_type": "code",
|
| 283 |
+
"execution_count": null,
|
| 284 |
"id": "60cf16eb",
|
| 285 |
"metadata": {
|
| 286 |
"execution": {
|
|
|
|
| 300 |
},
|
| 301 |
"outputs": [],
|
| 302 |
"source": [
|
| 303 |
+
"# Focal loss for handling class imbalance\n",
|
| 304 |
"class FocalLoss(nn.Module):\n",
|
| 305 |
" \"\"\"Focal Loss for addressing class imbalance\"\"\"\n",
|
| 306 |
" def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
|
|
|
|
| 311 |
" self.ignore_index = ignore_index\n",
|
| 312 |
" \n",
|
| 313 |
" def forward(self, inputs, targets):\n",
|
| 314 |
+
" # Calculate cross entropy loss\n",
|
| 315 |
" ce_loss = nn.functional.cross_entropy(\n",
|
| 316 |
" inputs, targets, \n",
|
| 317 |
" weight=self.alpha, \n",
|
|
|
|
| 319 |
" ignore_index=self.ignore_index\n",
|
| 320 |
" )\n",
|
| 321 |
" \n",
|
| 322 |
+
" # Apply focal term to focus on hard examples\n",
|
| 323 |
" pt = torch.exp(-ce_loss)\n",
|
| 324 |
" focal_loss = (1 - pt) ** self.gamma * ce_loss\n",
|
| 325 |
" \n",
|
| 326 |
+
" # Reduce loss based on specified method\n",
|
| 327 |
" if self.reduction == 'mean':\n",
|
| 328 |
" return focal_loss.mean()\n",
|
| 329 |
" elif self.reduction == 'sum':\n",
|
|
|
|
| 334 |
},
|
| 335 |
{
|
| 336 |
"cell_type": "code",
|
| 337 |
+
"execution_count": null,
|
| 338 |
"id": "4e56747c",
|
| 339 |
"metadata": {
|
| 340 |
"execution": {
|
|
|
|
| 360 |
" total_loss = 0\n",
|
| 361 |
" f1_metric.reset()\n",
|
| 362 |
" \n",
|
| 363 |
+
" # Progress bar for training\n",
|
| 364 |
" progress_bar = tqdm(dataloader, desc='Training')\n",
|
| 365 |
" for batch_idx, (tokens, labels) in enumerate(progress_bar):\n",
|
| 366 |
+
" # Move data to device\n",
|
| 367 |
" tokens = tokens.to(device)\n",
|
| 368 |
" labels = labels.to(device)\n",
|
| 369 |
" \n",
|
|
|
|
| 375 |
" outputs_flat = outputs.view(-1, outputs.size(-1))\n",
|
| 376 |
" labels_flat = labels.view(-1)\n",
|
| 377 |
" \n",
|
| 378 |
+
" # Calculate loss and backpropagate\n",
|
| 379 |
" loss = criterion(outputs_flat, labels_flat)\n",
|
| 380 |
" loss.backward()\n",
|
| 381 |
+
" # Clip gradients to prevent exploding gradients\n",
|
| 382 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
|
| 383 |
" optimizer.step()\n",
|
| 384 |
" \n",
|
|
|
|
| 398 |
},
|
| 399 |
{
|
| 400 |
"cell_type": "code",
|
| 401 |
+
"execution_count": null,
|
| 402 |
"id": "8a2e8d19",
|
| 403 |
"metadata": {
|
| 404 |
"execution": {
|
|
|
|
| 424 |
" total_loss = 0\n",
|
| 425 |
" f1_metric.reset()\n",
|
| 426 |
" \n",
|
| 427 |
+
" # No gradient computation during evaluation\n",
|
| 428 |
" with torch.no_grad():\n",
|
| 429 |
" for tokens, labels in tqdm(dataloader, desc='Evaluating'):\n",
|
| 430 |
+
" # Move data to device\n",
|
| 431 |
" tokens = tokens.to(device)\n",
|
| 432 |
" labels = labels.to(device)\n",
|
| 433 |
" \n",
|
|
|
|
| 449 |
},
|
| 450 |
{
|
| 451 |
"cell_type": "code",
|
| 452 |
+
"execution_count": null,
|
| 453 |
"id": "6e292ace",
|
| 454 |
"metadata": {
|
| 455 |
"execution": {
|
|
|
|
| 473 |
" \"\"\"Create a weighted sampler to balance classes during training\"\"\"\n",
|
| 474 |
" sample_weights = []\n",
|
| 475 |
" \n",
|
| 476 |
+
" # Calculate weight for each sample\n",
|
| 477 |
" for idx in range(len(dataset)):\n",
|
| 478 |
" _, labels = dataset[idx]\n",
|
| 479 |
" \n",
|
|
|
|
| 482 |
" for label_id in labels:\n",
|
| 483 |
" if label_id > 3: # Skip special tokens\n",
|
| 484 |
" label_name = label_vocab.idx2word.get(label_id.item(), 'O')\n",
|
| 485 |
+
" # If sample contains PII, give it higher weight\n",
|
| 486 |
" if label_name != 'o' and 'B-' in label_name:\n",
|
| 487 |
" min_weight = 10.0\n",
|
| 488 |
" break\n",
|
| 489 |
" \n",
|
| 490 |
" sample_weights.append(min_weight)\n",
|
| 491 |
" \n",
|
| 492 |
+
" # Create weighted sampler\n",
|
| 493 |
" sampler = WeightedRandomSampler(\n",
|
| 494 |
" weights=sample_weights,\n",
|
| 495 |
" num_samples=len(sample_weights),\n",
|
| 496 |
" replacement=True\n",
|
| 497 |
" )\n",
|
| 498 |
" \n",
|
| 499 |
+
" return sampler"
|
| 500 |
]
|
| 501 |
},
|
| 502 |
{
|
| 503 |
"cell_type": "code",
|
| 504 |
+
"execution_count": null,
|
| 505 |
"id": "857335cb",
|
| 506 |
"metadata": {
|
| 507 |
"execution": {
|
|
|
|
| 524 |
"def print_label_distribution(data, title=\"Label Distribution\"):\n",
|
| 525 |
" \"\"\"Print label distribution statistics\"\"\"\n",
|
| 526 |
" label_counts = Counter()\n",
|
| 527 |
+
" \n",
|
| 528 |
+
" # Count each label type\n",
|
| 529 |
" for label_seq in data.labels:\n",
|
| 530 |
" for label in label_seq:\n",
|
| 531 |
" if label not in ['<pad>', '<start>', '<end>']:\n",
|
| 532 |
" label_counts[label] += 1\n",
|
| 533 |
" \n",
|
| 534 |
+
" # Print distribution\n",
|
| 535 |
" print(f\"\\n{title}:\")\n",
|
| 536 |
" print(\"-\" * 50)\n",
|
| 537 |
" total = sum(label_counts.values())\n",
|
|
|
|
| 544 |
},
|
| 545 |
{
|
| 546 |
"cell_type": "code",
|
| 547 |
+
"execution_count": null,
|
| 548 |
"id": "1738f8a9",
|
| 549 |
"metadata": {
|
| 550 |
"execution": {
|
|
|
|
| 566 |
"source": [
|
| 567 |
"def save_model(model, text_vocab, label_vocab, config, save_dir):\n",
|
| 568 |
" \"\"\"Save model and all necessary components for Flask deployment\"\"\"\n",
|
| 569 |
+
" # Create directory if it doesn't exist\n",
|
| 570 |
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 571 |
" \n",
|
| 572 |
+
" # Save model weights\n",
|
| 573 |
" model_path = os.path.join(save_dir, 'pii_transformer_model.pt')\n",
|
| 574 |
" torch.save(model.state_dict(), model_path)\n",
|
| 575 |
" \n",
|
|
|
|
| 595 |
},
|
| 596 |
{
|
| 597 |
"cell_type": "code",
|
| 598 |
+
"execution_count": null,
|
| 599 |
"id": "d93e7c25",
|
| 600 |
"metadata": {
|
| 601 |
"execution": {
|
|
|
|
| 626 |
"):\n",
|
| 627 |
" \"\"\"Main training function\"\"\"\n",
|
| 628 |
" \n",
|
| 629 |
+
" # Load augmented data\n",
|
| 630 |
" print(\"Loading augmented data...\")\n",
|
| 631 |
" data = pd.read_json(data_path, lines=True)\n",
|
| 632 |
" print(f\"Total samples: {len(data)}\")\n",
|
| 633 |
" \n",
|
| 634 |
+
" # Show label distribution\n",
|
| 635 |
" print_label_distribution(data, \"Label Distribution in Augmented Data\")\n",
|
| 636 |
" \n",
|
| 637 |
" # Build vocabularies\n",
|
|
|
|
| 639 |
" text_vocab = Vocabulary(max_size=100000)\n",
|
| 640 |
" label_vocab = Vocabulary(max_size=50)\n",
|
| 641 |
" \n",
|
| 642 |
+
" # Add all words and labels to vocabularies\n",
|
| 643 |
" for tokens in data.tokens:\n",
|
| 644 |
" text_vocab.add_sentence(tokens)\n",
|
| 645 |
" for labels in data.labels:\n",
|
| 646 |
" label_vocab.add_sentence(labels)\n",
|
| 647 |
" \n",
|
| 648 |
+
" # Build vocabularies from collected words\n",
|
| 649 |
" text_vocab.build()\n",
|
| 650 |
" label_vocab.build()\n",
|
| 651 |
" \n",
|
| 652 |
+
" # Calculate class weights for balanced loss\n",
|
| 653 |
" class_weights = calculate_class_weights(data, label_vocab)\n",
|
| 654 |
" class_weights = class_weights.to(device)\n",
|
| 655 |
" \n",
|
| 656 |
+
" # Split data into train and validation sets\n",
|
| 657 |
" X_train, X_val, y_train, y_val = train_test_split(\n",
|
| 658 |
" data.tokens.tolist(),\n",
|
| 659 |
" data.labels.tolist(),\n",
|
|
|
|
| 665 |
" print(f\" - Train samples: {len(X_train):,}\")\n",
|
| 666 |
" print(f\" - Validation samples: {len(X_val):,}\")\n",
|
| 667 |
" \n",
|
| 668 |
+
" # Create datasets\n",
|
| 669 |
" max_seq_len = 512\n",
|
| 670 |
" train_dataset = PIIDataset(X_train, y_train, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 671 |
" val_dataset = PIIDataset(X_val, y_val, text_vocab, label_vocab, max_len=max_seq_len)\n",
|
| 672 |
" \n",
|
| 673 |
+
" # Create balanced sampler for training\n",
|
| 674 |
" train_sampler = create_balanced_sampler(train_dataset, label_vocab)\n",
|
| 675 |
" \n",
|
| 676 |
+
" # Create data loaders\n",
|
| 677 |
" train_loader = DataLoader(\n",
|
| 678 |
" train_dataset, \n",
|
| 679 |
" batch_size=batch_size,\n",
|
|
|
|
| 702 |
" 'max_len': max_seq_len\n",
|
| 703 |
" }\n",
|
| 704 |
" \n",
|
| 705 |
+
" # Create transformer model\n",
|
| 706 |
" print(\"\\nCreating model...\")\n",
|
| 707 |
" model = create_transformer_pii_model(**model_config).to(device)\n",
|
| 708 |
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
|
|
|
| 717 |
" else:\n",
|
| 718 |
" criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=0)\n",
|
| 719 |
" \n",
|
| 720 |
+
" # Setup optimizer and learning rate scheduler\n",
|
| 721 |
" optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)\n",
|
| 722 |
" scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)\n",
|
| 723 |
" \n",
|
| 724 |
+
" # Initialize metrics\n",
|
| 725 |
" f1_metric_train = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 726 |
" f1_metric_val = F1ScoreMetric(beta=5, num_classes=len(label_vocab), label_vocab=label_vocab)\n",
|
| 727 |
" \n",
|
| 728 |
+
" # Training history\n",
|
| 729 |
" train_losses, train_f1s, val_losses, val_f1s = [], [], [], []\n",
|
| 730 |
" best_val_f1 = 0\n",
|
| 731 |
" patience = 5\n",
|
|
|
|
| 734 |
" print(\"\\nStarting training...\")\n",
|
| 735 |
" print(\"=\" * 60)\n",
|
| 736 |
" \n",
|
| 737 |
+
" # Training loop\n",
|
| 738 |
" for epoch in range(num_epochs):\n",
|
| 739 |
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
|
| 740 |
" \n",
|
| 741 |
+
" # Train for one epoch\n",
|
| 742 |
" train_loss, train_f1 = train_epoch(\n",
|
| 743 |
" model, train_loader, optimizer, criterion, device, f1_metric_train\n",
|
| 744 |
" )\n",
|
| 745 |
+
" \n",
|
| 746 |
+
" # Evaluate on validation set\n",
|
| 747 |
" val_loss, val_f1 = evaluate(\n",
|
| 748 |
" model, val_loader, criterion, device, f1_metric_val\n",
|
| 749 |
" )\n",
|
| 750 |
" \n",
|
| 751 |
+
" # Adjust learning rate based on validation loss\n",
|
| 752 |
" scheduler.step(val_loss)\n",
|
| 753 |
" \n",
|
| 754 |
" # Store metrics\n",
|
|
|
|
| 785 |
" else:\n",
|
| 786 |
" patience_counter += 1\n",
|
| 787 |
" \n",
|
| 788 |
+
" # Early stopping check\n",
|
| 789 |
" if patience_counter >= patience and epoch > 10:\n",
|
| 790 |
" print(f\"\\nEarly stopping triggered after {patience} epochs without improvement\")\n",
|
| 791 |
" break\n",
|
|
|
|
| 793 |
" # Plot training curves\n",
|
| 794 |
" plt.figure(figsize=(12, 5))\n",
|
| 795 |
" \n",
|
| 796 |
+
" # Plot loss curves\n",
|
| 797 |
" plt.subplot(1, 2, 1)\n",
|
| 798 |
" plt.plot(train_losses, label='Train Loss', linewidth=2)\n",
|
| 799 |
" plt.plot(val_losses, label='Val Loss', linewidth=2)\n",
|
|
|
|
| 803 |
" plt.legend()\n",
|
| 804 |
" plt.grid(True, alpha=0.3)\n",
|
| 805 |
" \n",
|
| 806 |
+
" # Plot F1 score curves\n",
|
| 807 |
" plt.subplot(1, 2, 2)\n",
|
| 808 |
" plt.plot(train_f1s, label='Train F1', linewidth=2)\n",
|
| 809 |
" plt.plot(val_f1s, label='Val F1', linewidth=2)\n",
|
|
|
|
| 821 |
" print(f\"Training completed!\")\n",
|
| 822 |
" print(f\"Best validation F1: {best_val_f1:.4f}\")\n",
|
| 823 |
" \n",
|
| 824 |
+
" # Save model for deployment\n",
|
| 825 |
" save_model(model, text_vocab, label_vocab, model_config, 'saved_transformer_model')\n",
|
| 826 |
" \n",
|
| 827 |
" return model, text_vocab, label_vocab"
|
|
|
|
| 1296 |
}
|
| 1297 |
],
|
| 1298 |
"source": [
|
| 1299 |
+
"# Set device\n",
|
| 1300 |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 1301 |
"print(f\"Using device: {device}\")\n",
|
| 1302 |
"\n",
|
| 1303 |
+
"# Train the transformer model\n",
|
| 1304 |
"model, text_vocab, label_vocab = train_transformer_pii_model(\n",
|
| 1305 |
" data_path='train_augmented.json',\n",
|
| 1306 |
" num_epochs=20,\n",
|