# Spaces: Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pickle | |
| import os | |
| import math | |
| from typing import List, Tuple | |
| from collections import Counter | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Import both model architectures | |
| from transformer import create_transformer_pii_model | |
| from lstm import create_lstm_pii_model | |
# Vocabulary class for handling text encoding and decoding
class Vocabulary:
    """Case-insensitive, frequency-built mapping between words and integer ids.

    Ids 0-3 are always the special tokens ``<pad>``, ``<unk>``, ``<start>`` and
    ``<end>``. Words are lower-cased both when counted and when encoded.

    NOTE(review): this file restores vocabularies with ``pickle``. Presumably
    the saved objects are instances of this class, so the attribute names
    (``word2idx``, ``idx2word``, ``word_count``, ``max_size``) must stay
    stable. TODO confirm against the training script.
    """

    def __init__(self, max_size=100000):
        # Reserve the first ids for the special tokens, in this fixed order.
        specials = ['<pad>', '<unk>', '<start>', '<end>']
        self.word2idx = {tok: pos for pos, tok in enumerate(specials)}
        self.idx2word = {pos: tok for pos, tok in enumerate(specials)}
        self.word_count = Counter()
        # Upper bound on the total vocabulary size, special tokens included.
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Accumulate lower-cased frequencies for every word in ``sentence``."""
        self.word_count.update(w.lower() for w in sentence)

    def build(self):
        """Assign ids to the most frequent words until ``max_size`` is reached."""
        budget = self.max_size - len(self.word2idx)
        for candidate, _freq in self.word_count.most_common(budget):
            if candidate in self.word2idx:
                # Already present (e.g. a special token); keep its existing id.
                continue
            new_id = len(self.word2idx)
            self.word2idx[candidate] = new_id
            self.idx2word[new_id] = candidate

    def __len__(self):
        return len(self.word2idx)

    def encode(self, sentence):
        """Map words to ids; words not in the vocabulary get the ``<unk>`` id."""
        lookup = self.word2idx
        return [lookup.get(token.lower(), lookup['<unk>']) for token in sentence]

    def decode(self, indices):
        """Map ids back to words; unknown ids become ``'<unk>'``."""
        lookup = self.idx2word
        return [lookup.get(i, '<unk>') for i in indices]
# Main PII detection class with support for multiple models
class PIIDetector:
    """Token-level PII detector backed by one or more trained models.

    On construction it tries to load both the Transformer and the LSTM model
    (see ``load_all_models``). It keeps the following dicts, all keyed by
    model type ('transformer' / 'lstm'):

    * ``self.models``: the model, moved to ``self.device`` and put in eval mode.
    * ``self.vocabularies``: ``{'text_vocab': ..., 'label_vocab': ...}``.
    * ``self.configs``: the keyword arguments the model was built from.
    """

    # Labels reserved by the label vocabulary for special tokens. They are
    # upper-cased because ``predict`` upper-cases every label. They are not
    # PII classes, so a prediction that lands on one of them is reported as
    # 'O' rather than being highlighted as an entity.
    _SPECIAL_LABELS = frozenset({'<PAD>', '<UNK>', '<START>', '<END>'})

    # Punctuation tokens that are rendered directly after the previous token,
    # with no separating space.
    _NO_SPACE_BEFORE = frozenset('.,!?;:')

    # Weights file expected inside each model directory, per model type.
    _WEIGHT_FILES = {
        'transformer': 'pii_transformer_model.pt',
        'lstm': 'pii_lstm_model.pt',
    }

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.vocabularies = {}
        self.configs = {}
        # Color for highlighting PII entities
        self.highlight_color = '#FF6B6B'
        # Load both models; raises RuntimeError if neither can be loaded.
        self.load_all_models()

    def load_all_models(self):
        """Load the Transformer and LSTM models on a best-effort basis.

        A failure to load one model is printed and skipped.

        Raises:
            RuntimeError: if no model at all could be loaded.
        """
        print("Loading models...")
        # Load Transformer model
        try:
            self.load_model('transformer', 'saved_transformer')
            print("β Transformer model loaded successfully")
        except Exception as e:
            print(f"β Error loading Transformer model: {str(e)}")
        # Load LSTM model
        try:
            self.load_model('lstm', 'saved_lstm')
            print("β LSTM model loaded successfully")
        except Exception as e:
            print(f"β Error loading LSTM model: {str(e)}")
        if not self.models:
            raise RuntimeError("No models could be loaded. Please check model files.")
        print(f"Using device: {self.device}")

    def load_model(self, model_type, model_dir):
        """Load one model plus its vocabularies and config from ``model_dir``.

        ``model_dir`` must contain the following files:

        * ``vocabularies.pkl``: a dict with ``'text_vocab'`` and ``'label_vocab'``.
        * ``model_config.pkl``: keyword arguments for the model factory.
        * the weights file listed in ``_WEIGHT_FILES`` for ``model_type``
          (a state dict).

        Args:
            model_type: 'transformer' or 'lstm'.
            model_dir: directory holding the files above.

        Raises:
            ValueError: if ``model_type`` is not a known model type.
            Exception: any I/O, unpickling, or state-dict error is printed
                and re-raised.
        """
        try:
            # Resolve the factory first, so an unknown type fails before any
            # file is opened instead of silently being built as an LSTM.
            if model_type == 'transformer':
                factory = create_transformer_pii_model
            elif model_type == 'lstm':
                factory = create_lstm_pii_model
            else:
                raise ValueError(
                    f"Unknown model type '{model_type}'; expected one of "
                    f"{sorted(self._WEIGHT_FILES)}"
                )
            model_file = self._WEIGHT_FILES[model_type]

            # SECURITY: pickle.load can execute arbitrary code. These files
            # must come from a trusted source (the files shipped with the app).
            vocab_path = os.path.join(model_dir, 'vocabularies.pkl')
            with open(vocab_path, 'rb') as f:
                vocabs = pickle.load(f)
            text_vocab = vocabs['text_vocab']
            label_vocab = vocabs['label_vocab']

            config_path = os.path.join(model_dir, 'model_config.pkl')
            with open(config_path, 'rb') as f:
                model_config = pickle.load(f)

            model = factory(**model_config)

            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a state dict needs.
            model_path = os.path.join(model_dir, model_file)
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()

            # Store model and associated data
            self.models[model_type] = model
            self.vocabularies[model_type] = {
                'text_vocab': text_vocab,
                'label_vocab': label_vocab,
            }
            self.configs[model_type] = model_config
        except Exception as e:
            print(f"Error loading {model_type} model from {model_dir}: {str(e)}")
            raise

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into word runs (``\\w+``) and single punctuation chars."""
        import re
        return re.findall(r'\w+|[^\w\s]', text)

    def predict(self, text: str, model_type: str = 'transformer') -> List[Tuple[str, str]]:
        """Predict an upper-cased label for each token of ``text``.

        Args:
            text: raw input text.
            model_type: key of a loaded model ('transformer' or 'lstm').

        Returns:
            ``(token, label)`` pairs. A token that the model assigns to a
            special label-vocabulary entry (``<pad>``, ``<unk>``, and so on)
            gets 'O'. Blank text returns ``[]``.

        Raises:
            ValueError: if ``model_type`` is not loaded.
        """
        if not text.strip():
            return []
        if model_type not in self.models:
            raise ValueError(f"Model type '{model_type}' not available. Available models: {list(self.models.keys())}")
        model = self.models[model_type]
        text_vocab = self.vocabularies[model_type]['text_vocab']
        label_vocab = self.vocabularies[model_type]['label_vocab']

        tokens = self.tokenize(text)
        # The model was fed <start> ... <end> wrapped sequences.
        token_ids = text_vocab.encode(['<start>'] + tokens + ['<end>'])
        input_tensor = torch.tensor([token_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            # Assumes outputs are (batch=1, seq_len, num_labels) logits. TODO confirm
            # against the model factories.
            outputs = model(input_tensor)
            predictions = torch.argmax(outputs, dim=-1)

        predicted_labels = []
        for idx in predictions[0][1:-1].tolist():  # drop <start>/<end> positions
            label = label_vocab.idx2word.get(idx, 'O').upper()
            if label in self._SPECIAL_LABELS:
                label = 'O'
            predicted_labels.append(label)
        return list(zip(tokens, predicted_labels))

    @staticmethod
    def _entity_type(label: str) -> str:
        """Return ``label`` without a leading BIO prefix ('B-' / 'I-')."""
        if label.startswith(('B-', 'I-')):
            return label[2:]
        return label

    @classmethod
    def _join_tokens(cls, tokens: List[str]) -> str:
        """Join tokens with single spaces, except before punctuation."""
        parts = []
        for k, tok in enumerate(tokens):
            if k > 0 and tok not in cls._NO_SPACE_BEFORE:
                parts.append(' ')
            parts.append(tok)
        return ''.join(parts)

    def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Render tokens as HTML with PII entities wrapped in ``<mark>`` tags.

        A non-'O' token starts an entity. Following ``I-`` tokens of the same
        entity type are merged into it, whether the entity began with ``B-``
        or with an orphan ``I-``. All text is HTML-escaped.
        """
        import html

        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8; padding: 20px; background-color: white; border-radius: 8px; color: black;">']
        n = len(token_label_pairs)
        i = 0
        while i < n:
            token, label = token_label_pairs[i]
            if label != 'O':
                entity_type = self._entity_type(label)
                entity_tokens = [token]
                j = i + 1
                # Absorb continuation tokens of the same entity type.
                while j < n:
                    next_token, next_label = token_label_pairs[j]
                    if next_label.startswith('I-') and self._entity_type(next_label) == entity_type:
                        entity_tokens.append(next_token)
                        j += 1
                    else:
                        break
                entity_text = html.escape(self._join_tokens(entity_tokens))
                label_display = html.escape(entity_type.replace('_', ' '))
                # No leading space: the mark's horizontal margin provides the gap.
                html_parts.append(
                    f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
                    f'border-radius: 3px; margin: 0 2px; font-weight: 500;" '
                    f'title="{label_display}">{entity_text}</mark>'
                )
                i = j
            else:
                # Space before a plain token, except before punctuation or
                # right after an opening parenthesis.
                if i > 0 and token not in self._NO_SPACE_BEFORE and token_label_pairs[i - 1][0] != '(':
                    html_parts.append(' ')
                html_parts.append(f'<span style="color: black;">{html.escape(token)}</span>')
                i += 1
        html_parts.append('</div>')
        return ''.join(html_parts)

    def get_statistics(self, token_label_pairs: List[Tuple[str, str]], model_type: str) -> str:
        """Build a Markdown summary of detected PII.

        The summary lists the model used, the total token count, the number
        and percentage of PII tokens, and the PII token count per entity type.
        An empty token list gives 0% instead of dividing by zero.
        """
        total_tokens = len(token_label_pairs)
        per_type = {}
        pii_tokens = 0
        for _, label in token_label_pairs:
            if label != 'O':
                pii_tokens += 1
                label_clean = self._entity_type(label).replace('_', ' ')
                per_type[label_clean] = per_type.get(label_clean, 0) + 1

        pii_pct = pii_tokens / total_tokens * 100 if total_tokens else 0.0
        lines = [
            "### Detection Summary\n\n",
            f"**Model Used:** {model_type.upper()}\n\n",
            f"**Total tokens:** {total_tokens}\n\n",
            f"**PII tokens:** {pii_tokens} ({pii_pct:.1f}%)\n\n",
        ]
        if per_type:
            lines.append("**PII tokens by type:**\n\n")
            # Most frequent type first; ties broken alphabetically.
            for label_clean, count in sorted(per_type.items(), key=lambda kv: (-kv[1], kv[0])):
                lines.append(f"- {label_clean}: {count}\n")
        return ''.join(lines)

    def get_available_models(self):
        """Return the model types that loaded successfully."""
        return list(self.models.keys())
# Initialize the detector when the script runs.
# Module-level instance shared by the Gradio callback below (detect_pii reads
# the global ``detector``). Construction loads the models from disk and raises
# RuntimeError if none could be loaded, so a broken deployment fails at import.
print("Initializing PII Detector...")
detector = PIIDetector()
def detect_pii(text, model_type):
    """Gradio callback: run PII detection on ``text`` with the chosen model.

    Args:
        text: contents of the input textbox (may be empty or whitespace only).
        model_type: dropdown value, e.g. 'TRANSFORMER'. It is lower-cased
            before lookup.

    Returns:
        ``(html, markdown)``: the highlighted text for the HTML pane and the
        statistics for the Markdown pane. Errors are reported in both panes
        instead of being raised.
    """
    import html

    # Whitespace-only input has no tokens; treat it like empty input instead
    # of letting it reach the statistics step.
    if not text or not text.strip():
        return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
    try:
        # Run PII detection with selected model
        token_label_pairs = detector.predict(text, model_type.lower())
        # Generate highlighted output
        highlighted_html = detector.create_highlighted_html(token_label_pairs)
        # Generate statistics
        stats = detector.get_statistics(token_label_pairs, model_type)
        return highlighted_html, stats
    except Exception as e:
        # UI boundary: show the failure to the user rather than crash the
        # callback. Escape the message because it is rendered as HTML.
        error_html = f'<div style="color: #dc3545; padding: 20px; background-color: #f8d7da; border-radius: 8px;">Error: {html.escape(str(e))}</div>'
        error_stats = f"Error occurred: {str(e)}"
        return error_html, error_stats
# Create the Gradio interface.
# Placeholder content for the two output panes. Used both as the initial
# component values and by the Clear button, so they stay in sync.
_RESULTS_PLACEHOLDER_HTML = "<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
_STATS_PLACEHOLDER_MD = "*Statistics will appear here...*"

# Browser-side script run on page load. After a 100 ms delay it looks up the
# input textarea (elem_id "no-paste-textarea") and cancels paste events on it.
_DISABLE_PASTE_JS = """
() => {
setTimeout(() => {
const textarea = document.querySelector('#no-paste-textarea textarea');
if (textarea) {
textarea.addEventListener('paste', (e) => {
e.preventDefault();
return false;
});
}
}, 100);
}
"""

# Dropdown choices are displayed upper-cased; detect_pii lower-cases the
# selected value again before calling the detector.
available_models = [name.upper() for name in detector.get_available_models()]

with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# π PII Detection System
Select a model and enter a sentence below to analyze it for PII content.
"""
    )

    with gr.Column():
        model_dropdown = gr.Dropdown(
            choices=available_models,
            value=available_models[0] if available_models else None,
            label="Select Model",
        )

        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter a sentence to analyze for PII...",
            lines=8,
            max_lines=20,
            elem_id="no-paste-textarea",  # target of _DISABLE_PASTE_JS
        )

        with gr.Row():
            analyze_btn = gr.Button("π Detect PII", variant="primary", scale=2)
            clear_btn = gr.Button("ποΈ Clear", scale=1)

        highlighted_output = gr.HTML(
            label="Highlighted Text",
            value=_RESULTS_PLACEHOLDER_HTML,
        )
        stats_output = gr.Markdown(
            label="Detection Statistics",
            value=_STATS_PLACEHOLDER_MD,
        )

    # Wire events: Detect runs the model; Clear resets the input and restores
    # both placeholders. The lambda is kept so the auto-derived API name is unchanged.
    analyze_btn.click(
        fn=detect_pii,
        inputs=[input_text, model_dropdown],
        outputs=[highlighted_output, stats_output],
    )
    clear_btn.click(
        fn=lambda: ("", _RESULTS_PLACEHOLDER_HTML, _STATS_PLACEHOLDER_MD),
        outputs=[input_text, highlighted_output, stats_output],
    )
    demo.load(None, None, None, js=_DISABLE_PASTE_JS)

# Launch the application
if __name__ == "__main__":
    print("\nLaunching Gradio interface...")
    demo.launch()