# PII-Detection / app.py
# Source: AmitHirpara's Hugging Face Space (commit e31f290, "update app")
import html
import math
import os
import pickle
import re
import warnings
from collections import Counter
from typing import List, Tuple

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

# Import both model architectures
from transformer import create_transformer_pii_model
from lstm import create_lstm_pii_model
# Vocabulary class for handling text encoding and decoding
class Vocabulary:
    """Maps tokens to integer ids and back.

    Four special tokens are reserved up front (<pad>=0, <unk>=1, <start>=2,
    <end>=3); the remaining slots, up to ``max_size`` entries total, are
    assigned to the most frequent lowercased words observed via
    ``add_sentence`` once ``build`` is called.
    """

    def __init__(self, max_size=100000):
        # Special tokens occupy the first four ids, in this fixed order.
        specials = ('<pad>', '<unk>', '<start>', '<end>')
        self.word2idx = {tok: i for i, tok in enumerate(specials)}
        self.idx2word = {i: tok for i, tok in enumerate(specials)}
        self.word_count = Counter()
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Accumulate lowercased word frequencies from an iterable of tokens."""
        self.word_count.update(word.lower() for word in sentence)

    def build(self):
        """Assign ids to the most common words, capped at ``max_size`` entries."""
        budget = self.max_size - len(self.word2idx)
        for word, _ in self.word_count.most_common(budget):
            if word in self.word2idx:
                continue
            next_id = len(self.word2idx)
            self.word2idx[word] = next_id
            self.idx2word[next_id] = word

    def __len__(self):
        return len(self.word2idx)

    def encode(self, sentence):
        """Map each token (lowercased) to its id, falling back to <unk>."""
        unk = self.word2idx['<unk>']
        return [self.word2idx.get(word.lower(), unk) for word in sentence]

    def decode(self, indices):
        """Map each id back to its token, falling back to the '<unk>' string."""
        return [self.idx2word.get(idx, '<unk>') for idx in indices]
# Main PII detection class with support for multiple models
class PIIDetector:
    """Token-level PII detector serving pre-trained Transformer and LSTM models.

    Each model is loaded from its saved directory (weights, vocabularies and
    constructor config).  Inference tags every token of the input text with a
    PII label; helper methods render the result as highlighted HTML plus a
    markdown summary for the Gradio UI.
    """

    # Compiled once at class creation: a token is a run of word characters,
    # or a single character that is neither a word character nor whitespace.
    _TOKEN_RE = re.compile(r'\w+|[^\w\s]')

    def __init__(self):
        # Prefer GPU when available; models and inputs are moved here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}         # model_type -> nn.Module in eval mode
        self.vocabularies = {}   # model_type -> {'text_vocab', 'label_vocab'}
        self.configs = {}        # model_type -> constructor kwargs
        # Color for highlighting PII entities
        self.highlight_color = '#FF6B6B'
        # Load both models eagerly so the UI can list what is available.
        self.load_all_models()

    def load_all_models(self):
        """Load both LSTM and Transformer models.

        A failure to load one model is logged and tolerated; a RuntimeError
        is raised only if *no* model could be loaded.
        """
        print("Loading models...")
        # Load Transformer model
        try:
            self.load_model('transformer', 'saved_transformer')
            print("βœ“ Transformer model loaded successfully")
        except Exception as e:
            print(f"βœ— Error loading Transformer model: {str(e)}")
        # Load LSTM model
        try:
            self.load_model('lstm', 'saved_lstm')
            print("βœ“ LSTM model loaded successfully")
        except Exception as e:
            print(f"βœ— Error loading LSTM model: {str(e)}")
        if not self.models:
            raise RuntimeError("No models could be loaded. Please check model files.")
        print(f"Using device: {self.device}")

    def load_model(self, model_type, model_dir):
        """Load one model plus its vocabularies and config from ``model_dir``.

        Args:
            model_type: 'transformer' or 'lstm'; selects the factory and
                weight filename.
            model_dir: directory containing vocabularies.pkl,
                model_config.pkl and the .pt weight file.

        Raises:
            Exception: re-raised after logging when any artifact fails to load.
        """
        try:
            # NOTE(review): pickle.load is acceptable here only because these
            # are locally produced artifacts; never load untrusted pickles.
            vocab_path = os.path.join(model_dir, 'vocabularies.pkl')
            with open(vocab_path, 'rb') as f:
                vocabs = pickle.load(f)
            text_vocab = vocabs['text_vocab']
            label_vocab = vocabs['label_vocab']
            # Load the constructor kwargs used to rebuild the architecture.
            config_path = os.path.join(model_dir, 'model_config.pkl')
            with open(config_path, 'rb') as f:
                model_config = pickle.load(f)
            # Initialize model based on type
            if model_type == 'transformer':
                model = create_transformer_pii_model(**model_config)
                model_file = 'pii_transformer_model.pt'
            else:  # lstm
                model = create_lstm_pii_model(**model_config)
                model_file = 'pii_lstm_model.pt'
            # Load model weights onto the active device and freeze for inference.
            model_path = os.path.join(model_dir, model_file)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()
            # Store model and associated data
            self.models[model_type] = model
            self.vocabularies[model_type] = {
                'text_vocab': text_vocab,
                'label_vocab': label_vocab
            }
            self.configs[model_type] = model_config
        except Exception as e:
            print(f"Error loading {model_type} model from {model_dir}: {str(e)}")
            raise

    def tokenize(self, text: str) -> List[str]:
        """Split text into word tokens and single-character punctuation tokens."""
        # Uses the class-level compiled pattern instead of re-importing and
        # re-looking-up the pattern on every call.
        return self._TOKEN_RE.findall(text)

    def predict(self, text: str, model_type: str = 'transformer') -> List[Tuple[str, str]]:
        """Tag each token of ``text`` with a PII label using ``model_type``.

        Returns:
            List of (token, LABEL) pairs; LABEL is upper-cased and 'O' marks
            a non-PII token.  Empty/whitespace-only text yields [].

        Raises:
            ValueError: if ``model_type`` was not successfully loaded.
        """
        if not text.strip():
            return []
        if model_type not in self.models:
            raise ValueError(f"Model type '{model_type}' not available. Available models: {list(self.models.keys())}")
        # Get model and vocabularies for the selected type
        model = self.models[model_type]
        text_vocab = self.vocabularies[model_type]['text_vocab']
        label_vocab = self.vocabularies[model_type]['label_vocab']
        # Tokenize and wrap with the sentinel tokens the model was trained with.
        tokens = self.tokenize(text)
        tokens_with_special = ['<start>'] + tokens + ['<end>']
        token_ids = text_vocab.encode(tokens_with_special)
        # Batch of one sentence.
        input_tensor = torch.tensor([token_ids]).to(self.device)
        # Greedy per-token decoding: argmax over the label dimension.
        with torch.no_grad():
            outputs = model(input_tensor)
            predictions = torch.argmax(outputs, dim=-1)
        # Map predicted ids back to labels, skipping the <start>/<end> slots.
        predicted_labels = []
        for idx in predictions[0][1:-1]:
            label = label_vocab.idx2word.get(idx.item(), 'O')
            predicted_labels.append(label.upper())
        # Return token-label pairs
        return list(zip(tokens, predicted_labels))

    def create_highlighted_html(self, token_label_pairs: List[Tuple[str, str]]) -> str:
        """Render (token, label) pairs as HTML with PII entities highlighted.

        A B- tag followed by matching I- tags is merged into one <mark> span.
        All user-derived text is HTML-escaped before being embedded, since
        this string is rendered verbatim by gr.HTML (markup-injection fix).
        """
        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8; padding: 20px; background-color: white; border-radius: 8px; color: black;">']
        i = 0
        while i < len(token_label_pairs):
            token, label = token_label_pairs[i]
            # Check if token is part of PII entity
            if label != 'O':
                # Collect the entity: this token plus following I- tokens of
                # the same type.  (If a sequence starts with an I- tag the
                # entity stays single-token, since 'I-X' != 'B-X'.)
                entity_tokens = [token]
                entity_label = label
                j = i + 1
                while j < len(token_label_pairs):
                    next_token, next_label = token_label_pairs[j]
                    if next_label.startswith('I-') and next_label.replace('I-', 'B-') == entity_label:
                        entity_tokens.append(next_token)
                        j += 1
                    else:
                        break
                # Join entity tokens, omitting the space before punctuation.
                entity_text = ''
                for k, tok in enumerate(entity_tokens):
                    if k > 0 and tok not in '.,!?;:':
                        entity_text += ' '
                    entity_text += tok
                # Human-readable entity type for the highlight tooltip.
                label_display = entity_label.replace('B-', '').replace('I-', '').replace('_', ' ')
                html_parts.append(
                    f'<mark style="background-color: {self.highlight_color}; padding: 2px 4px; '
                    f'border-radius: 3px; margin: 0 2px; font-weight: 500;" '
                    f'title="{html.escape(label_display)}">{html.escape(entity_text)}</mark>'
                )
                i = j
            else:
                # Space before the token, except for punctuation, the very
                # first token, or a token following an opening parenthesis.
                # (Dropped the original tautological bounds check on i-1.)
                if i > 0 and token not in '.,!?;:':
                    prev_token, _ = token_label_pairs[i - 1]
                    if prev_token not in '(':
                        html_parts.append(' ')
                html_parts.append(f'<span style="color: black;">{html.escape(token)}</span>')
                i += 1
        html_parts.append('</div>')
        return ''.join(html_parts)

    def get_statistics(self, token_label_pairs: List[Tuple[str, str]], model_type: str) -> str:
        """Build a markdown summary of how many tokens were flagged as PII.

        Guards against an empty token list so the percentage computation
        never divides by zero (previously a ZeroDivisionError).
        """
        stats = {}  # per-entity-type counts (collected but not yet displayed)
        total_tokens = len(token_label_pairs)
        pii_tokens = 0
        # Count PII tokens by type
        for _, label in token_label_pairs:
            if label != 'O':
                pii_tokens += 1
                label_clean = label.replace('B-', '').replace('I-', '').replace('_', ' ')
                stats[label_clean] = stats.get(label_clean, 0) + 1
        # Avoid ZeroDivisionError when there are no tokens at all.
        pii_percentage = (pii_tokens / total_tokens * 100) if total_tokens else 0.0
        # Format statistics text
        stats_text = f"### Detection Summary\n\n"
        stats_text += f"**Model Used:** {model_type.upper()}\n\n"
        stats_text += f"**Total tokens:** {total_tokens}\n\n"
        stats_text += f"**PII tokens:** {pii_tokens} ({pii_percentage:.1f}%)\n\n"
        return stats_text

    def get_available_models(self):
        """Return the names of the successfully loaded model types."""
        return list(self.models.keys())
# Initialize the detector when the script runs
print("Initializing PII Detector...")
# Module-level singleton shared by the Gradio callbacks below.
# NOTE(review): this runs at import time and raises RuntimeError if neither
# saved model directory can be loaded — the app fails fast without models.
detector = PIIDetector()
def detect_pii(text, model_type):
    """Gradio callback: run PII detection and return the two output values.

    Args:
        text: raw user input from the textbox (may be None or empty).
        model_type: display name of the selected model ('TRANSFORMER'/'LSTM').

    Returns:
        Tuple of (highlighted HTML string, markdown statistics string).
        Errors are rendered into the outputs instead of propagating, so the
        UI never crashes.
    """
    # Treat whitespace-only input the same as empty input; previously such
    # input slipped past `if not text` and crashed downstream on an empty
    # token list.
    if not text or not text.strip():
        return "<p style='color: #6c757d; padding: 20px;'>Please enter some text to analyze.</p>", "No text provided."
    try:
        # Run PII detection with selected model (detector keys are lowercase).
        token_label_pairs = detector.predict(text, model_type.lower())
        # Generate highlighted output
        highlighted_html = detector.create_highlighted_html(token_label_pairs)
        # Generate statistics
        stats = detector.get_statistics(token_label_pairs, model_type)
        return highlighted_html, stats
    except Exception as e:
        # Surface the failure in both output panes rather than raising.
        error_html = f'<div style="color: #dc3545; padding: 20px; background-color: #f8d7da; border-radius: 8px;">Error: {str(e)}</div>'
        error_stats = f"Error occurred: {str(e)}"
        return error_html, error_stats
# Create the Gradio interface
with gr.Blocks(title="PII Detection System", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πŸ”’ PII Detection System
        Select a model and enter a sentence below to analyze it for PII content.
        """
    )
    with gr.Column():
        # Model selection dropdown, populated from whichever models loaded.
        available_models = [m.upper() for m in detector.get_available_models()]
        model_dropdown = gr.Dropdown(
            choices=available_models,
            value=available_models[0] if available_models else None,
            label="Select Model",
        )
        # Input text area (elem_id is targeted by the paste-blocking JS below)
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter a sentence to analyze for PII...",
            lines=8,
            max_lines=20,
            elem_id="no-paste-textarea"
        )
        # Control buttons
        with gr.Row():
            analyze_btn = gr.Button("πŸ” Detect PII", variant="primary", scale=2)
            clear_btn = gr.Button("πŸ—‘οΈ Clear", scale=1)
        # Output areas
        highlighted_output = gr.HTML(
            label="Highlighted Text",
            value="<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>"
        )
        stats_output = gr.Markdown(
            label="Detection Statistics",
            value="*Statistics will appear here...*"
        )
    # Connect buttons to functions
    analyze_btn.click(
        fn=detect_pii,
        inputs=[input_text, model_dropdown],
        outputs=[highlighted_output, stats_output]
    )
    # Clear resets all three components to their initial placeholder values.
    clear_btn.click(
        fn=lambda: ("", "<p style='color: #6c757d; padding: 20px;'>Results will appear here after analysis...</p>", "*Statistics will appear here...*"),
        outputs=[input_text, highlighted_output, stats_output]
    )
    # On page load, attach a JS handler that blocks paste events in the
    # input textarea (identified by elem_id above).
    demo.load(None, None, None, js="""
    () => {
        setTimeout(() => {
            const textarea = document.querySelector('#no-paste-textarea textarea');
            if (textarea) {
                textarea.addEventListener('paste', (e) => {
                    e.preventDefault();
                    return false;
                });
            }
        }, 100);
    }
    """)
# Launch the application
if __name__ == "__main__":
    print("\nLaunching Gradio interface...")
    # Default launch settings: local server, no public share link.
    demo.launch()