"""Gradio web app for PII (Personally Identifiable Information) detection.

Loads pre-trained LSTM and Transformer token-classification models from
``saved_lstm/`` and ``saved_transformer/`` and serves a web UI that
highlights detected PII entities in user-supplied text.
"""

import os
import math  # kept: may be used by the model modules' configs
import pickle
import re
import warnings
from collections import Counter
from typing import List, Tuple

import gradio as gr
import torch
import torch.nn as nn  # kept for parity with original imports
import torch.nn.functional as F  # kept for parity with original imports

# Project-local model factories.
from transformer import create_transformer_pii_model
from lstm import create_lstm_pii_model

warnings.filterwarnings('ignore')

# Special vocabulary tokens.
# NOTE(review): the original source had these literals stripped (they were
# mistaken for HTML tags by an earlier sanitizer). Lowercase forms are used
# because Vocabulary.encode() lower-cases every lookup key — confirm against
# the pickled vocabularies in saved_*/vocabularies.pkl.
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'


class Vocabulary:
    """Vocabulary class for encoding/decoding text and labels."""

    def __init__(self, max_size: int = 100000):
        # Reserve indices 0-3 for the special tokens.
        self.word2idx = {PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3}
        self.idx2word = {0: PAD_TOKEN, 1: UNK_TOKEN, 2: SOS_TOKEN, 3: EOS_TOKEN}
        self.word_count = Counter()
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Accumulate (lower-cased) word frequencies from a token list."""
        for word in sentence:
            self.word_count[word.lower()] += 1

    def build(self):
        """Freeze the vocabulary to the most common words, capped at max_size."""
        most_common = self.word_count.most_common(self.max_size - len(self.word2idx))
        for word, _ in most_common:
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word

    def __len__(self):
        return len(self.word2idx)

    def encode(self, sentence):
        """Map tokens to indices; out-of-vocabulary words map to <unk>."""
        return [self.word2idx.get(word.lower(), self.word2idx[UNK_TOKEN])
                for word in sentence]

    def decode(self, indices):
        """Map indices back to tokens; unknown indices map to <unk>."""
        return [self.idx2word.get(idx, UNK_TOKEN) for idx in indices]


class PIIDetector:
    """Loads the trained models and runs token-level PII prediction."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}          # model_type -> nn.Module
        self.vocabularies = {}    # model_type -> {'text_vocab', 'label_vocab'}
        self.configs = {}         # model_type -> factory kwargs
        # Background colour used when highlighting PII entities.
        self.highlight_color = '#FF6B6B'
        self.load_all_models()

    def load_all_models(self):
        """Load both LSTM and Transformer models; at least one must succeed.

        Raises:
            RuntimeError: if neither model could be loaded.
        """
        print("Loading models...")
        try:
            self.load_model('transformer', 'saved_transformer')
            print("✓ Transformer model loaded successfully")
        except Exception as e:
            print(f"✗ Error loading Transformer model: {str(e)}")
        try:
            self.load_model('lstm', 'saved_lstm')
            print("✓ LSTM model loaded successfully")
        except Exception as e:
            print(f"✗ Error loading LSTM model: {str(e)}")
        if not self.models:
            raise RuntimeError("No models could be loaded. Please check model files.")
        print(f"Using device: {self.device}")

    def load_model(self, model_type, model_dir):
        """Load one model plus its vocabularies and config from `model_dir`.

        Args:
            model_type: 'transformer' or 'lstm' (anything else is treated as lstm).
            model_dir: directory holding vocabularies.pkl, model_config.pkl
                and the .pt weights file.
        """
        try:
            # Saved vocabularies. NOTE: pickle is only safe because these are
            # trusted local training artifacts, never user-supplied data.
            vocab_path = os.path.join(model_dir, 'vocabularies.pkl')
            with open(vocab_path, 'rb') as f:
                vocabs = pickle.load(f)
            text_vocab = vocabs['text_vocab']
            label_vocab = vocabs['label_vocab']

            # Model constructor kwargs saved at training time.
            config_path = os.path.join(model_dir, 'model_config.pkl')
            with open(config_path, 'rb') as f:
                model_config = pickle.load(f)

            if model_type == 'transformer':
                model = create_transformer_pii_model(**model_config)
                model_file = 'pii_transformer_model.pt'
            else:  # lstm
                model = create_lstm_pii_model(**model_config)
                model_file = 'pii_lstm_model.pt'

            # Load weights onto the active device and switch to inference mode.
            model_path = os.path.join(model_dir, model_file)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()

            self.models[model_type] = model
            self.vocabularies[model_type] = {
                'text_vocab': text_vocab,
                'label_vocab': label_vocab,
            }
            self.configs[model_type] = model_config
        except Exception as e:
            print(f"Error loading {model_type} model from {model_dir}: {str(e)}")
            raise

    def tokenize(self, text: str) -> List[str]:
        """Simple tokenization: split into word runs and single punctuation marks."""
        return re.findall(r'\w+|[^\w\s]', text)

    def predict(self, text: str, model_type: str = 'transformer') -> List[Tuple[str, str]]:
        """Predict PII labels for input text using the specified model.

        Returns:
            A list of (token, label) pairs; labels are upper-cased BIO tags
            (e.g. 'O', 'B-NAME', 'I-NAME').

        Raises:
            ValueError: if `model_type` is not among the loaded models.
        """
        if not text.strip():
            return []
        if model_type not in self.models:
            raise ValueError(
                f"Model type '{model_type}' not available. "
                f"Available models: {list(self.models.keys())}"
            )

        model = self.models[model_type]
        text_vocab = self.vocabularies[model_type]['text_vocab']
        label_vocab = self.vocabularies[model_type]['label_vocab']

        tokens = self.tokenize(text)
        # Wrap with start/end markers, matching the training-time format.
        tokens_with_special = [SOS_TOKEN] + tokens + [EOS_TOKEN]
        token_ids = text_vocab.encode(tokens_with_special)

        input_tensor = torch.tensor([token_ids]).to(self.device)
        with torch.no_grad():
            outputs = model(input_tensor)
            predictions = torch.argmax(outputs, dim=-1)

        # Convert predicted indices to label strings, skipping the
        # <sos>/<eos> positions so labels align with `tokens`.
        predicted_labels = []
        for idx in predictions[0][1:-1]:
            label = label_vocab.idx2word.get(idx.item(), 'O')
            predicted_labels.append(label.upper())

        return list(zip(tokens, predicted_labels))

    def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Create HTML with highlighted PII entities.

        Consecutive B-/I- tagged tokens of the same entity type are merged
        into a single highlighted span, annotated with the entity type.
        NOTE(review): the original markup strings were stripped from the
        source; this HTML is a faithful reconstruction of intent, not of the
        exact original styling.
        """
        html_parts = [
            '<div style="font-size: 16px; line-height: 2; padding: 10px;">'
        ]
        i = 0
        while i < len(token_label_pairs):
            token, label = token_label_pairs[i]
            if label != 'O':
                # Collect this entity's tokens: the current token plus any
                # following I- tokens of the same entity type.
                entity_tokens = [token]
                entity_label = label
                j = i + 1
                while j < len(token_label_pairs):
                    next_token, next_label = token_label_pairs[j]
                    if (next_label.startswith('I-')
                            and next_label.replace('I-', 'B-') == entity_label):
                        entity_tokens.append(next_token)
                        j += 1
                    else:
                        break

                # Join entity tokens, not inserting a space before punctuation.
                entity_text = ''
                for k, tok in enumerate(entity_tokens):
                    if k > 0 and tok not in '.,!?;:':
                        entity_text += ' '
                    entity_text += tok

                # Highlighted span with the entity type shown alongside.
                label_display = (entity_label.replace('B-', '')
                                 .replace('I-', '').replace('_', ' '))
                html_parts.append(
                    f'<span style="background-color: {self.highlight_color}; '
                    f'padding: 2px 4px; border-radius: 4px;" '
                    f'title="{label_display}">{entity_text}'
                    f' <sub>{label_display}</sub></span>'
                )
                i = j
            else:
                # Non-PII token: prepend a space unless it is punctuation or
                # follows an opening parenthesis.
                if i > 0 and token not in '.,!?;:' and len(token_label_pairs) > i - 1:
                    prev_token, _ = token_label_pairs[i - 1]
                    if prev_token not in '(':
                        html_parts.append(' ')
                html_parts.append(f'<span>{token}</span>')
                i += 1
        html_parts.append('</div>')
        return ''.join(html_parts)

    def get_statistics(self, token_label_pairs: List[Tuple[str, str]],
                       model_type: str) -> str:
        """Generate a Markdown summary of detected PII.

        Handles the empty-input case (no tokens) without dividing by zero.
        """
        stats = {}
        total_tokens = len(token_label_pairs)
        pii_tokens = 0
        for _, label in token_label_pairs:
            if label != 'O':
                pii_tokens += 1
                label_clean = (label.replace('B-', '')
                               .replace('I-', '').replace('_', ' '))
                stats[label_clean] = stats.get(label_clean, 0) + 1

        # Guard against division by zero on whitespace-only input.
        pii_pct = (pii_tokens / total_tokens * 100) if total_tokens else 0.0

        stats_text = f"### Detection Summary\n\n"
        stats_text += f"**Model Used:** {model_type.upper()}\n\n"
        stats_text += f"**Total tokens:** {total_tokens}\n\n"
        stats_text += f"**PII tokens:** {pii_tokens} ({pii_pct:.1f}%)\n\n"
        return stats_text

    def get_available_models(self):
        """Return the list of successfully loaded model type names."""
        return list(self.models.keys())


# Initialize the detector when the script runs.
print("Initializing PII Detector...")
detector = PIIDetector()

# Reusable placeholder markup for the results panel.
# NOTE(review): reconstructed — the original HTML was stripped from source.
PLACEHOLDER_HTML = (
    '<div style="color: #888; padding: 10px;">'
    'Results will appear here after analysis...</div>'
)


def detect_pii(text, model_type):
    """Main callback for the Gradio interface.

    Returns:
        (highlighted_html, stats_markdown) tuple for the two output widgets.
    """
    if not text:
        return (
            '<div style="color: #888; padding: 10px;">'
            'Please enter some text to analyze.</div>',
            "No text provided.",
        )
    try:
        token_label_pairs = detector.predict(text, model_type.lower())
        highlighted_html = detector.create_highlighted_html(token_label_pairs)
        stats = detector.get_statistics(token_label_pairs, model_type)
        return highlighted_html, stats
    except Exception as e:
        # Surface the failure in both output panes rather than crashing the UI.
        error_html = (
            f'<div style="color: red; padding: 10px;">Error: {str(e)}</div>'
        )
        error_stats = f"Error occurred: {str(e)}"
        return error_html, error_stats


# Create the Gradio interface.
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔒 PII Detection System
        Select a model and enter a sentence below to analyze it for PII content.
        """
    )
    with gr.Column():
        # Model selection dropdown, populated from whatever loaded successfully.
        available_models = [m.upper() for m in detector.get_available_models()]
        model_dropdown = gr.Dropdown(
            choices=available_models,
            value=available_models[0] if available_models else None,
            label="Select Model",
        )
        # Input text area (paste is disabled via the JS hook below).
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter a sentence to analyze for PII...",
            lines=8,
            max_lines=20,
            elem_id="no-paste-textarea",
        )
        # Control buttons.
        with gr.Row():
            analyze_btn = gr.Button("🔍 Detect PII", variant="primary", scale=2)
            clear_btn = gr.Button("🗑️ Clear", scale=1)
        # Output areas.
        highlighted_output = gr.HTML(
            label="Highlighted Text",
            value=PLACEHOLDER_HTML,
        )
        stats_output = gr.Markdown(
            label="Detection Statistics",
            value="*Statistics will appear here...*",
        )

    # Connect buttons to functions.
    analyze_btn.click(
        fn=detect_pii,
        inputs=[input_text, model_dropdown],
        outputs=[highlighted_output, stats_output],
    )
    clear_btn.click(
        fn=lambda: ("", PLACEHOLDER_HTML, "*Statistics will appear here...*"),
        outputs=[input_text, highlighted_output, stats_output],
    )

    # Disable pasting into the input textarea once the DOM is ready.
    demo.load(None, None, None, js="""
    () => {
        setTimeout(() => {
            const textarea = document.querySelector('#no-paste-textarea textarea');
            if (textarea) {
                textarea.addEventListener('paste', (e) => {
                    e.preventDefault();
                    return false;
                });
            }
        }, 100);
    }
    """)

# Launch the application.
if __name__ == "__main__":
    print("\nLaunching Gradio interface...")
    demo.launch()