|
|
""" |
|
|
LoRA Adapter Implementation for PHI Detection in Vision Tokens. |
|
|
This module implements a LoRA adapter for DeepSeek-OCR to detect PHI at the vision token level. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
from dataclasses import dataclass |
|
|
import numpy as np |
|
|
from peft import LoraConfig, get_peft_model, TaskType, PeftModel |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
import yaml |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
@dataclass
class PHIDetectorConfig:
    """Configuration for the PHI-detection LoRA adapter.

    Groups the LoRA hyperparameters, PHI-detection settings, vision-encoder
    dimensions, and training options used throughout this module.
    """

    # --- LoRA hyperparameters ---
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    # Attention projections to adapt. Annotated Optional because the default
    # is the None sentinel (a bare ``List[str] = None`` is a type error);
    # ``__post_init__`` substitutes the standard projection list.
    target_modules: Optional[List[str]] = None

    # --- PHI detection settings ---
    # Number of PHI categories; the classifier head adds one extra class
    # (index 0) for "non-PHI".
    num_phi_categories: int = 18
    confidence_threshold: float = 0.85

    # --- Vision encoder dimensions ---
    vision_hidden_size: int = 1024
    num_vision_tokens: int = 256

    # --- Training options ---
    learning_rate: float = 2e-4
    warmup_steps: int = 500
    gradient_checkpointing: bool = True

    def __post_init__(self) -> None:
        # Default to adapting all four attention projections.
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
|
|
|
|
|
|
|
class PHITokenClassifier(nn.Module):
    """Per-token heads for PHI detection on vision features.

    Three parallel heads operate on each vision token:
      * ``classifier``        -> PHI-category logits (categories + 1 non-PHI class)
      * ``confidence_head``   -> scalar confidence in [0, 1]
      * ``importance_scorer`` -> scalar importance in [0, 1]
    """

    def __init__(self, config: PHIDetectorConfig):
        super().__init__()
        self.config = config

        hidden = config.vision_hidden_size
        half, quarter = hidden // 2, hidden // 4

        # Two-stage MLP mapping each token to category logits; LayerNorm and
        # dropout (reusing the LoRA dropout rate) between stages.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Dropout(config.lora_dropout),
            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.ReLU(),
            nn.Dropout(config.lora_dropout),
            nn.Linear(quarter, config.num_phi_categories + 1),
        )

        # Small sigmoid-bounded regression head for per-token confidence.
        self.confidence_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Sigmoid-bounded head scoring how important each token is.
        self.importance_scorer = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        vision_features: torch.Tensor,
        return_importance: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Run PHI detection over a batch of vision tokens.

        Args:
            vision_features: Vision token features [batch_size, num_tokens, hidden_size]
            return_importance: Whether to also compute token importance scores

        Returns:
            Dictionary containing:
                - logits: PHI category logits [batch_size, num_tokens, num_categories + 1]
                - confidence: Confidence scores [batch_size, num_tokens, 1]
                - importance: Token importance scores (only when requested)
        """
        outputs = {
            'logits': self.classifier(vision_features),
            'confidence': self.confidence_head(vision_features),
        }

        # The importance head is optional work; skip it unless asked for.
        if return_importance:
            outputs['importance'] = self.importance_scorer(vision_features)

        return outputs
|
|
|
|
|
|
|
|
class PHIDetectorLoRA(nn.Module):
    """LoRA adapter for PHI detection in DeepSeek-OCR.

    Wraps a base model with PEFT LoRA adapters on its attention projections
    and attaches a per-token PHI classifier (``PHITokenClassifier``) plus a
    masking post-processor (``VisionTokenProcessor``).
    """

    def __init__(
        self,
        base_model: nn.Module,
        config: PHIDetectorConfig,
        device: str = 'cuda'
    ):
        # base_model is assumed to expose get_vision_features(images)
        # returning a dict with 'last_hidden_state' (see forward()) —
        # TODO confirm against the actual DeepSeek-OCR API.
        super().__init__()
        self.config = config
        self.device = device

        # LoRA settings come from the shared config; FEATURE_EXTRACTION task
        # type because we consume hidden states, not generation outputs.
        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.target_modules,
            task_type=TaskType.FEATURE_EXTRACTION,
        )

        # Inject trainable low-rank adapters into the (otherwise frozen) base model.
        self.base_model = get_peft_model(base_model, lora_config)

        # Per-token PHI classification / confidence / importance heads.
        self.phi_detector = PHITokenClassifier(config)

        # Applies the masking strategies to flagged tokens.
        self.token_processor = VisionTokenProcessor(config)

        self.to(device)

    def forward(
        self,
        images: torch.Tensor,
        return_masked: bool = True,
        masking_strategy: str = 'selective_attention'
    ) -> Dict[str, Any]:
        """
        Forward pass with PHI detection and optional masking.

        Args:
            images: Input images [batch_size, channels, height, width]
            return_masked: Whether to return masked vision tokens
            masking_strategy: Strategy for masking ('selective_attention' or 'token_replacement')

        Returns:
            Dictionary containing:
                - vision_features: Original vision features
                - phi_predictions: PHI detection results
                - masked_features: Masked vision features (if requested)
                - attention_mask: Attention mask for PHI tokens
        """
        # Skip autograd for the vision encoder at inference time; during
        # training gradients must flow into the LoRA adapters.
        with torch.no_grad() if not self.training else torch.enable_grad():
            vision_outputs = self.base_model.get_vision_features(images)

        vision_features = vision_outputs['last_hidden_state']

        # Importance scores are always requested so that masking strategies
        # which use them (selective attention) have them available.
        phi_predictions = self.phi_detector(vision_features, return_importance=True)

        outputs = {
            'vision_features': vision_features,
            'phi_predictions': phi_predictions,
        }

        if return_masked:
            masked_features, attention_mask = self.token_processor.apply_masking(
                vision_features,
                phi_predictions,
                strategy=masking_strategy
            )
            outputs['masked_features'] = masked_features
            outputs['attention_mask'] = attention_mask

        return outputs

    def detect_phi(
        self,
        vision_features: torch.Tensor,
        threshold: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Detect PHI in vision features.

        Args:
            vision_features: Vision token features
            threshold: Confidence threshold (uses config default if None)

        Returns:
            Tuple of (phi_mask, phi_categories)
        """
        if threshold is None:
            threshold = self.config.confidence_threshold

        predictions = self.phi_detector(vision_features)

        # Category 0 is the non-PHI class, so any argmax > 0 is a PHI hit.
        phi_probs = F.softmax(predictions['logits'], dim=-1)
        phi_categories = torch.argmax(phi_probs, dim=-1)

        is_phi = phi_categories > 0
        confident = predictions['confidence'].squeeze(-1) > threshold

        # Only confident PHI predictions make it into the final mask.
        phi_mask = is_phi & confident

        return phi_mask, phi_categories

    def save_adapter(self, save_path: str):
        """Save LoRA adapter weights, detector head, and config to a directory."""
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # PEFT serializes only the LoRA adapter weights, not the base model.
        self.base_model.save_pretrained(save_dir)

        # The classification heads are regular modules, saved separately.
        torch.save(self.phi_detector.state_dict(), save_dir / 'phi_detector.pt')

        # Persist the subset of config needed to reconstruct the adapter.
        config_dict = {
            'lora_rank': self.config.lora_rank,
            'lora_alpha': self.config.lora_alpha,
            'lora_dropout': self.config.lora_dropout,
            'target_modules': self.config.target_modules,
            'num_phi_categories': self.config.num_phi_categories,
            'confidence_threshold': self.config.confidence_threshold,
            'vision_hidden_size': self.config.vision_hidden_size,
            'num_vision_tokens': self.config.num_vision_tokens,
        }

        with open(save_dir / 'config.yaml', 'w') as f:
            yaml.dump(config_dict, f)

        print(f"✓ Adapter saved to {save_dir}")

    @classmethod
    def load_adapter(cls, base_model: nn.Module, load_path: str, device: str = 'cuda'):
        """Load LoRA adapter from saved weights.

        Args:
            base_model: The (unwrapped) base model to attach the adapter to.
            load_path: Directory previously written by :meth:`save_adapter`.
            device: Device to place the restored adapter on.
        """
        load_dir = Path(load_path)

        # Rebuild the config from the YAML written by save_adapter.
        with open(load_dir / 'config.yaml', 'r') as f:
            config_dict = yaml.safe_load(f)

        config = PHIDetectorConfig(**config_dict)

        # Restore the LoRA adapter weights onto the base model.
        model = PeftModel.from_pretrained(base_model, load_dir)

        # NOTE(review): cls(...) runs __init__, which calls get_peft_model on
        # `model` even though it is already a PeftModel — verify this
        # double-wrapping is intended and that the restored weights survive it.
        detector = cls(model, config, device)

        # Restore the separately-saved classification heads.
        detector.phi_detector.load_state_dict(
            torch.load(load_dir / 'phi_detector.pt', map_location=device)
        )

        print(f"✓ Adapter loaded from {load_dir}")
        return detector
|
|
|
|
|
|
|
|
class VisionTokenProcessor(nn.Module):
    """Process and mask vision tokens based on PHI detection.

    Subclasses ``nn.Module`` (fix) so that ``privacy_embeddings`` is a
    registered parameter: when a parent module (e.g. ``PHIDetectorLoRA``)
    holds an instance, the embeddings are trained, included in
    ``state_dict()``, and moved by ``.to(device)``. As a plain attribute on
    a non-Module they were silently left behind.
    """

    def __init__(self, config: "PHIDetectorConfig"):
        super().__init__()
        self.config = config

        # One learned replacement embedding per class; index 0 corresponds
        # to the non-PHI class and is never selected for replacement.
        self.privacy_embeddings = nn.Parameter(
            torch.randn(config.num_phi_categories + 1, config.vision_hidden_size)
        )

    def apply_masking(
        self,
        vision_features: torch.Tensor,
        phi_predictions: Dict[str, torch.Tensor],
        strategy: str = 'selective_attention'
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masking to vision features based on PHI predictions.

        Args:
            vision_features: Original vision features [batch, tokens, hidden]
            phi_predictions: PHI detection results (needs 'logits'; 'confidence'
                for 'hybrid'; optionally 'importance' for 'selective_attention')
            strategy: One of 'token_replacement', 'selective_attention', 'hybrid'

        Returns:
            Tuple of (masked_features, attention_mask)

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        # Predicted category per token; class 0 means "not PHI".
        phi_categories = torch.argmax(phi_predictions['logits'], dim=-1)
        is_phi = phi_categories > 0

        if strategy == 'token_replacement':
            # Swap every PHI token for its category's privacy embedding.
            # Vectorized boolean-index assignment replaces the original
            # O(batch * tokens) Python loop with identical results.
            masked_features = vision_features.clone()
            masked_features[is_phi] = self.privacy_embeddings[phi_categories[is_phi]]
            attention_mask = torch.ones_like(is_phi, dtype=torch.float32)

        elif strategy == 'selective_attention':
            # Features stay intact; downstream attention is steered away
            # from PHI tokens via a zeroed mask.
            attention_mask = (~is_phi).float()
            masked_features = vision_features

            # Scale the mask by per-token importance when available.
            if 'importance' in phi_predictions:
                importance = phi_predictions['importance'].squeeze(-1)
                attention_mask = attention_mask * importance

        elif strategy == 'hybrid':
            masked_features = vision_features.clone()
            attention_mask = torch.ones_like(is_phi, dtype=torch.float32)

            # High-confidence PHI is replaced outright...
            high_confidence = phi_predictions['confidence'].squeeze(-1) > 0.95
            replace_mask = is_phi & high_confidence
            masked_features[replace_mask] = self.privacy_embeddings[phi_categories[replace_mask]]

            # ...while lower-confidence PHI is merely down-weighted.
            attention_mask[is_phi & ~high_confidence] = 0.1

        else:
            raise ValueError(f"Unknown masking strategy: {strategy}")

        return masked_features, attention_mask
|
|
|
|
|
|
|
|
class PHILoss(nn.Module):
    """Composite loss for PHI detection training.

    Combines a focal classification loss with optional confidence-regression
    and importance-smoothness terms; ``forward`` returns each component plus
    their sum under 'total'.
    """

    def __init__(self, config: "PHIDetectorConfig", class_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.config = config

        # Focal-loss hyperparameters: alpha rebalances classes, gamma
        # down-weights easy (high-probability) examples.
        self.focal_alpha = 0.25
        self.focal_gamma = 2.0

        # Registered as a buffer (fix) rather than a plain attribute so the
        # weights follow the module through .to(device) and state_dict;
        # register_buffer accepts None when no weighting is used.
        self.register_buffer('class_weights', class_weights)

    def forward(
        self,
        predictions: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Calculate PHI detection losses.

        Args:
            predictions: Model predictions ('logits' required; 'confidence'
                and 'importance' optional)
            targets: Ground truth ('labels' required; 'confidence_targets'
                optional)

        Returns:
            Dictionary of losses, including their sum under 'total'
        """
        losses = {}

        logits = predictions['logits']
        labels = targets['labels']

        # Per-token cross entropy, left unreduced so the focal modulation
        # can be applied element-wise before averaging.
        ce_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            weight=self.class_weights,
            reduction='none'
        )

        # pt is the model's probability of the true class; hard examples
        # (small pt) dominate the focal term.
        pt = torch.exp(-ce_loss)
        focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss
        losses['classification'] = focal_loss.mean()

        # Optional confidence regression against provided targets.
        if 'confidence' in predictions and 'confidence_targets' in targets:
            losses['confidence'] = F.mse_loss(
                predictions['confidence'].squeeze(-1),
                targets['confidence_targets']
            )

        # Smoothness regularizer over neighbouring token importances.
        # Guarded (fix): with a single token, torch.diff yields an empty
        # tensor and its mean would be NaN, poisoning 'total'.
        if 'importance' in predictions:
            importance = predictions['importance'].squeeze(-1)
            if importance.size(1) > 1:
                importance_diff = torch.diff(importance, dim=1)
                losses['consistency'] = torch.mean(importance_diff ** 2) * 0.1

        losses['total'] = sum(losses.values())

        return losses
|
|
|
|
|
|
|
|
def create_phi_detector(
    model_name: str = "deepseek-ai/DeepSeek-OCR",
    config_path: Optional[str] = None,
    device: str = 'cuda'
) -> PHIDetectorLoRA:
    """
    Create a PHI detector with LoRA adapter.

    Args:
        model_name: Hugging Face model id of the base OCR model
        config_path: Optional path to a YAML config file; defaults are used
            when omitted
        device: Device string, e.g. 'cuda', 'cuda:0', or 'cpu'

    Returns:
        PHIDetectorLoRA instance
    """
    # Load configuration from YAML when provided, else fall back to defaults.
    if config_path:
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)
        config = PHIDetectorConfig(**config_dict)
    else:
        config = PHIDetectorConfig()

    # Half precision on any CUDA device. startswith (fix) also matches
    # indexed devices like 'cuda:0', which a plain equality check would
    # silently load in float32.
    base_model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16 if device.startswith('cuda') else torch.float32,
    )

    return PHIDetectorLoRA(base_model, config, device)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a default configuration and print the PHI taxonomy.
    print("Creating PHI Detector with LoRA...")

    config = PHIDetectorConfig()
    print(f"Config: {config}")

    print("\nPHI Categories:")
    categories = [
        "Non-PHI", "Name", "Date", "Address", "Phone", "Email",
        "SSN", "MRN", "Insurance ID", "Account", "License",
        "Vehicle", "Device ID", "URL", "IP", "Biometric",
        "Unique ID", "Geo Small", "Institution"
    ]
    # One numbered line per category (index 0 is the non-PHI class).
    print("\n".join(f" {index:2d}: {label}" for index, label in enumerate(categories)))