File size: 13,375 Bytes
de46a17
 
 
 
 
 
 
 
 
 
 
 
e31f290
 
 
 
f53fac9
de46a17
 
 
f53fac9
de46a17
 
 
 
 
 
f53fac9
de46a17
 
 
 
f53fac9
de46a17
 
 
 
 
 
 
 
 
 
 
f53fac9
de46a17
 
 
f53fac9
de46a17
 
e31f290
de46a17
e31f290
de46a17
e31f290
 
 
de46a17
f53fac9
de46a17
 
e31f290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de46a17
f53fac9
e31f290
de46a17
 
e31f290
 
de46a17
 
e31f290
de46a17
 
 
e31f290
 
 
 
 
 
 
 
 
 
 
 
 
de46a17
e31f290
 
 
 
 
 
 
de46a17
 
e31f290
de46a17
 
 
 
 
f53fac9
de46a17
 
 
e31f290
 
de46a17
 
 
e31f290
 
 
 
 
 
 
 
f53fac9
de46a17
 
f53fac9
de46a17
 
f53fac9
e31f290
de46a17
f53fac9
de46a17
 
f53fac9
de46a17
e31f290
de46a17
 
f53fac9
de46a17
f53fac9
e31f290
de46a17
 
f53fac9
de46a17
 
 
 
 
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
 
 
 
 
 
 
e31f290
de46a17
 
 
 
 
f53fac9
de46a17
 
 
 
 
 
f53fac9
de46a17
e31f290
de46a17
 
 
 
e31f290
 
 
 
de46a17
f53fac9
de46a17
 
 
e31f290
de46a17
 
 
 
 
e31f290
 
de46a17
f53fac9
de46a17
 
f53fac9
e31f290
de46a17
 
 
 
 
 
 
 
f53fac9
de46a17
 
 
 
 
e31f290
de46a17
 
 
 
e31f290
 
 
 
 
 
 
 
f53fac9
de46a17
 
854d667
de46a17
854d667
 
de46a17
 
f53fac9
de46a17
 
 
 
f53fac9
de46a17
 
 
 
 
 
 
 
 
 
f53fac9
de46a17
 
e31f290
de46a17
 
 
 
 
 
 
 
854d667
 
 
 
 
 
 
 
 
 
 
 
 
 
f53fac9
de46a17
 
3dea7de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os
import math
from typing import List, Tuple
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Import both model architectures
from transformer import create_transformer_pii_model
from lstm import create_lstm_pii_model

# Vocabulary class for handling text encoding and decoding
class Vocabulary:
    """Two-way mapping between lower-cased tokens and integer ids.

    Four special tokens are reserved at ids 0-3 (``<pad>``, ``<unk>``,
    ``<start>``, ``<end>``). Further entries are added by :meth:`build`
    from the frequencies gathered through :meth:`add_sentence`.
    """

    def __init__(self, max_size=100000):
        # Reserved tokens occupy the first ids, in this fixed order.
        reserved = ('<pad>', '<unk>', '<start>', '<end>')
        self.word2idx = {token: position for position, token in enumerate(reserved)}
        self.idx2word = dict(enumerate(reserved))
        self.word_count = Counter()
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Add the lower-cased tokens of ``sentence`` to the frequency table."""
        self.word_count.update(token.lower() for token in sentence)

    def build(self):
        """Assign ids to the most frequent counted tokens, up to ``max_size``."""
        remaining_slots = self.max_size - len(self.word2idx)
        for token, _frequency in self.word_count.most_common(remaining_slots):
            if token in self.word2idx:
                # Already known (e.g. a reserved token that was also counted).
                continue
            new_id = len(self.word2idx)
            self.word2idx[token] = new_id
            self.idx2word[new_id] = token

    def __len__(self):
        """Return the number of tokens that currently have an id."""
        return len(self.word2idx)

    def encode(self, sentence):
        """Map each token to its id; unknown tokens map to the ``<unk>`` id."""
        ids = []
        for token in sentence:
            ids.append(self.word2idx.get(token.lower(), self.word2idx['<unk>']))
        return ids

    def decode(self, indices):
        """Map each id back to its token; unknown ids map to ``'<unk>'``."""
        words = []
        for index in indices:
            words.append(self.idx2word.get(index, '<unk>'))
        return words

# Main PII detection class with support for multiple models
class PIIDetector:
    """Token-level PII tagger backed by one or more saved models.

    Models are stored per type (``'transformer'``, ``'lstm'``) together with
    the text/label vocabularies and the configuration they were built from.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.vocabularies = {}
        self.configs = {}

        # Color for highlighting PII entities
        self.highlight_color = '#FF6B6B'

        # Load both models
        self.load_all_models()

    def load_all_models(self):
        """Try to load the Transformer and LSTM models.

        A failure to load one model is reported and skipped so the other can
        still be used.

        Raises:
            RuntimeError: if no model could be loaded at all.
        """
        print("Loading models...")

        candidates = (
            ('transformer', 'saved_transformer', 'Transformer'),
            ('lstm', 'saved_lstm', 'LSTM'),
        )
        for model_type, model_dir, display_name in candidates:
            try:
                self.load_model(model_type, model_dir)
                print(f"✓ {display_name} model loaded successfully")
            except Exception as e:
                # Best effort: one missing/broken model must not block the other.
                print(f"✗ Error loading {display_name} model: {str(e)}")

        if not self.models:
            raise RuntimeError("No models could be loaded. Please check model files.")

        print(f"Using device: {self.device}")

    def load_model(self, model_type, model_dir):
        """Load one model plus its vocabularies and config from ``model_dir``.

        Args:
            model_type: ``'transformer'`` selects the transformer factory and
                weight file; any other value is treated as the LSTM.
            model_dir: directory holding ``vocabularies.pkl``,
                ``model_config.pkl`` and the weight file.

        Raises:
            Exception: any error raised while reading or building the model is
                printed and re-raised unchanged.
        """
        try:
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file. Only point this at model directories you trust.
            vocab_path = os.path.join(model_dir, 'vocabularies.pkl')
            with open(vocab_path, 'rb') as f:
                vocabs = pickle.load(f)
            text_vocab = vocabs['text_vocab']
            label_vocab = vocabs['label_vocab']

            # Load model configuration (same trust caveat as above)
            config_path = os.path.join(model_dir, 'model_config.pkl')
            with open(config_path, 'rb') as f:
                model_config = pickle.load(f)

            # Initialize model based on type
            if model_type == 'transformer':
                model = create_transformer_pii_model(**model_config)
                model_file = 'pii_transformer_model.pt'
            else:  # lstm
                model = create_lstm_pii_model(**model_config)
                model_file = 'pii_lstm_model.pt'

            # Load model weights onto the selected device and switch to
            # inference mode.
            model_path = os.path.join(model_dir, model_file)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()

            # Register only after everything above succeeded, so a failed
            # load never leaves a half-initialised entry behind.
            self.models[model_type] = model
            self.vocabularies[model_type] = {
                'text_vocab': text_vocab,
                'label_vocab': label_vocab
            }
            self.configs[model_type] = model_config

        except Exception as e:
            print(f"Error loading {model_type} model from {model_dir}: {str(e)}")
            raise

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into word tokens and single punctuation tokens."""
        import re
        # A token is either a run of word characters or one non-space symbol.
        return re.findall(r'\w+|[^\w\s]', text)

    def predict(self, text: str, model_type: str = 'transformer') -> List[Tuple[str, str]]:
        """Tag every token of ``text`` with a PII label.

        Args:
            text: raw input text.
            model_type: key of a loaded model in ``self.models``.

        Returns:
            A list of ``(token, LABEL)`` pairs with labels upper-cased; an
            empty list when ``text`` is empty or only whitespace.

        Raises:
            ValueError: if ``model_type`` has not been loaded.
        """
        if not text.strip():
            return []

        if model_type not in self.models:
            raise ValueError(f"Model type '{model_type}' not available. Available models: {list(self.models.keys())}")

        # Get model and vocabularies for the selected type
        model = self.models[model_type]
        text_vocab = self.vocabularies[model_type]['text_vocab']
        label_vocab = self.vocabularies[model_type]['label_vocab']

        tokens = self.tokenize(text)

        # Wrap the sentence in the start/end markers before encoding.
        token_ids = text_vocab.encode(['<start>'] + tokens + ['<end>'])

        # Batch of one sequence.
        input_tensor = torch.tensor([token_ids]).to(self.device)

        with torch.no_grad():
            outputs = model(input_tensor)
            # assumes the model returns per-token label scores with the label
            # axis last -- TODO confirm against the model factories
            predictions = torch.argmax(outputs, dim=-1)

        # Drop the positions of the start/end markers, then map ids to labels.
        predicted_labels = [
            label_vocab.idx2word.get(idx.item(), 'O').upper()
            for idx in predictions[0][1:-1]
        ]

        return list(zip(tokens, predicted_labels))

    @staticmethod
    def _entity_type(label: str) -> str:
        """Return ``label`` without a leading ``B-``/``I-`` tag prefix."""
        if label.startswith(('B-', 'I-')):
            return label[2:]
        return label

    def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Render tagged tokens as HTML with PII entities highlighted.

        Consecutive tokens of one entity (an opening token followed by
        ``I-`` tokens of the same type) are merged into a single ``<mark>``
        whose tooltip shows the entity type. All token and label text is
        HTML-escaped because it originates from user input.

        Args:
            token_label_pairs: ``(token, label)`` pairs; ``'O'`` marks
                non-PII tokens.

        Returns:
            An HTML string wrapped in a single ``<div>``.
        """
        from html import escape

        # Punctuation that attaches to the preceding token without a space.
        no_space_before = frozenset('.,!?;:')

        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8; padding: 20px; background-color: white; border-radius: 8px; color: black;">']

        total = len(token_label_pairs)
        i = 0
        while i < total:
            token, label = token_label_pairs[i]

            if label == 'O':
                # Non-PII token: separate from the previous token unless this
                # is trailing punctuation or follows an opening parenthesis.
                if i > 0 and token not in no_space_before:
                    prev_token, _ = token_label_pairs[i - 1]
                    if prev_token != '(':
                        html_parts.append(' ')

                html_parts.append(f'<span style="color: black;">{escape(token)}</span>')
                i += 1
                continue

            # PII token: gather the continuation tokens of the same entity.
            # Types are compared with the tag prefix removed, so an entity the
            # model opened with an I- tag still absorbs its continuation.
            entity_type = self._entity_type(label)
            entity_tokens = [token]
            j = i + 1
            while j < total:
                next_token, next_label = token_label_pairs[j]
                if not (next_label.startswith('I-')
                        and self._entity_type(next_label) == entity_type):
                    break
                entity_tokens.append(next_token)
                j += 1

            # Join entity tokens with the same spacing rule as plain tokens.
            pieces = []
            for k, tok in enumerate(entity_tokens):
                if k > 0 and tok not in no_space_before:
                    pieces.append(' ')
                pieces.append(tok)
            entity_text = escape(''.join(pieces))

            label_display = escape(entity_type.replace('_', ' '))
            html_parts.append(
                f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
                f'border-radius: 3px; margin: 0 2px; font-weight: 500;" '
                f'title="{label_display}">{entity_text}</mark>'
            )

            i = j

        html_parts.append('</div>')

        return ''.join(html_parts)

    def get_statistics(self, token_label_pairs: List[Tuple[str, str]], model_type: str) -> str:
        """Build a Markdown summary of how many tokens were tagged as PII.

        Args:
            token_label_pairs: ``(token, label)`` pairs; ``'O'`` marks
                non-PII tokens.
            model_type: model name, shown upper-cased in the summary.

        Returns:
            Markdown text with the model name, the total token count and the
            PII token count with its percentage (0.0% for an empty input).
        """
        total_tokens = len(token_label_pairs)
        pii_tokens = sum(1 for _, label in token_label_pairs if label != 'O')

        # Guard the ratio: an empty input has no tokens to divide by.
        pii_percent = pii_tokens / total_tokens * 100 if total_tokens else 0.0

        stats_text = "### Detection Summary\n\n"
        stats_text += f"**Model Used:** {model_type.upper()}\n\n"
        stats_text += f"**Total tokens:** {total_tokens}\n\n"
        stats_text += f"**PII tokens:** {pii_tokens} ({pii_percent:.1f}%)\n\n"

        return stats_text

    def get_available_models(self):
        """Return the type names of the models that loaded successfully."""
        return list(self.models.keys())

# Build the shared detector at import time. Constructing PIIDetector loads the
# saved models (and raises RuntimeError if none can be loaded); detect_pii()
# and the UI definition below both read this module-level instance.
print("Initializing PII Detector...")
detector = PIIDetector()

def detect_pii(text, model_type):
    """Gradio callback: analyse ``text`` with the selected model.

    Args:
        text: the user's input text; empty, ``None`` or whitespace-only input
            is answered with a prompt instead of being analysed.
        model_type: model name as shown in the dropdown; it is lower-cased
            before being passed to the detector.

    Returns:
        A ``(html, markdown)`` tuple: the highlighted text and the detection
        summary, or an error box and error line if detection fails.
    """
    from html import escape

    # Whitespace-only text has no tokens, so treat it like empty input rather
    # than letting it fall through to the error path.
    if not text or not text.strip():
        return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."

    try:
        # Run PII detection with selected model
        token_label_pairs = detector.predict(text, model_type.lower())

        # Generate highlighted output
        highlighted_html = detector.create_highlighted_html(token_label_pairs)

        # Generate statistics
        stats = detector.get_statistics(token_label_pairs, model_type)

        return highlighted_html, stats

    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the app.
        # The message is escaped because it is rendered as raw HTML.
        error_html = f'<div style="color: #dc3545; padding: 20px; background-color: #f8d7da; border-radius: 8px;">Error: {escape(str(e))}</div>'
        error_stats = f"Error occurred: {str(e)}"
        return error_html, error_stats

# Create the Gradio interface
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
    # Page header. The emoji here and on the buttons below were mis-encoded
    # (UTF-8 bytes decoded with the wrong codec) and are restored.
    gr.Markdown(
        """
        # 🔒 PII Detection System
        
        Select a model and enter a sentence below to analyze it for PII content.
        """
    )
    
    with gr.Column():
        # Model selection dropdown: only models that actually loaded are
        # offered, shown upper-cased (detect_pii lower-cases the choice again).
        available_models = [m.upper() for m in detector.get_available_models()]
        model_dropdown = gr.Dropdown(
            choices=available_models,
            value=available_models[0] if available_models else None,
            label="Select Model",
        )
        
        # Input text area. The elem_id is the hook used by the paste-blocking
        # script registered at the bottom of this block.
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter a sentence to analyze for PII...",
            lines=8,
            max_lines=20,
            elem_id="no-paste-textarea"
        )
        
        # Control buttons
        with gr.Row():
            analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
            clear_btn = gr.Button("🗑️ Clear", scale=1)
        
        # Output areas
        highlighted_output = gr.HTML(
            label="Highlighted Text",
            value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
        )
        
        stats_output = gr.Markdown(
            label="Detection Statistics",
            value="*Statistics will appear here...*"
        )
    
    # Connect buttons to functions
    analyze_btn.click(
        fn=detect_pii,
        inputs=[input_text, model_dropdown],
        outputs=[highlighted_output, stats_output]
    )
    
    # Clear resets the input and restores both outputs to their placeholders.
    clear_btn.click(
        fn=lambda: ("", "<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>", "*Statistics will appear here...*"),
        outputs=[input_text, highlighted_output, stats_output]
    )

    # On page load, attach a handler that cancels paste events in the input
    # textarea. The short delay gives the textarea time to be rendered.
    demo.load(None, None, None, js="""
        () => {
            setTimeout(() => {
                const textarea = document.querySelector('#no-paste-textarea textarea');
                if (textarea) {
                    textarea.addEventListener('paste', (e) => {
                        e.preventDefault();
                        return false;
                    });
                }
            }, 100);
        }
    """)

# Launch the application
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script;
    # launch() is called with its default settings.
    print("\nLaunching Gradio interface...")
    demo.launch()