import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, AutoImageProcessor


class MultiHeadContentModerator(nn.Module):
    """
    Multi-task model with two classification heads on a shared ViT backbone.

    - Head 1: NSFW detection, reusing the pretrained classifier of the base model.
    - Head 2: Violence detection, a newly initialised ``nn.Linear`` head.

    Note: this class does not change ``requires_grad`` on any parameter. If the
    backbone and/or NSFW head are meant to stay frozen during training, the
    training code has to freeze them.
    """

    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
        """
        Args:
            base_model_name: Hugging Face model id/path passed to
                ``AutoModelForImageClassification.from_pretrained``. The loaded
                model must expose ``.vit`` and ``.classifier`` attributes.
            num_violence_labels: output width of the violence head.
        """
        super().__init__()
        # Load the pretrained single-head model; its parts are reused below.
        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size

        # ViT backbone (shared by both heads)
        self.vit = original_model.vit

        # Head 1: original NSFW classifier
        self.nsfw_classifier = original_model.classifier

        # Head 2: violence classifier over the same pooled embedding
        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)

        # Label mappings. NSFW labels come from the base model's config
        # (for the default Falconsai model, expected to be {0: 'normal', 1: 'nsfw'}).
        self.nsfw_id2label = original_model.config.id2label
        # Placeholder violence labels; load_multihead_model overwrites them from
        # the checkpoint. Keep the map the same size as the head's output so
        # every predicted index has a label.
        if num_violence_labels == 2:
            self.violence_id2label = {0: 'safe', 1: 'violence'}
        else:
            self.violence_id2label = {i: f'LABEL_{i}' for i in range(num_violence_labels)}

    def forward(self, pixel_values, task='both'):
        """
        Run the shared backbone and the requested head(s).

        Args:
            pixel_values: image tensor passed straight to the ViT backbone
                (presumably (batch, channels, height, width) as produced by the
                matching image processor -- confirm against callers).
            task: 'nsfw', 'violence', or 'both'.

        Returns:
            The logits tensor of the selected head, or a dict
            ``{'nsfw': logits, 'violence': logits}`` when ``task == 'both'``.

        Raises:
            ValueError: if ``task`` is not one of the supported values.
        """
        # Validate before the (expensive) backbone pass; an unknown task used to
        # silently fall through to the violence head, hiding caller typos.
        if task not in ('nsfw', 'violence', 'both'):
            raise ValueError(
                f"Unknown task {task!r}; expected 'nsfw', 'violence' or 'both'"
            )

        outputs = self.vit(pixel_values=pixel_values)
        # Embedding at sequence position 0 (the ViT [CLS] token) feeds both heads.
        pooled_output = outputs.last_hidden_state[:, 0]

        if task == 'nsfw':
            return self.nsfw_classifier(pooled_output)
        if task == 'violence':
            return self.violence_classifier(pooled_output)
        return {
            'nsfw': self.nsfw_classifier(pooled_output),
            'violence': self.violence_classifier(pooled_output),
        }


def load_multihead_model(checkpoint_path, device='cuda'):
    """
    Load a trained multi-head model from a ``torch.save``-d checkpoint dict.

    The checkpoint must contain the keys ``'base_model'``,
    ``'num_violence_labels'``, ``'model_state_dict'``, ``'violence_id2label'``
    and ``'nsfw_id2label'``.

    Args:
        checkpoint_path: path to the checkpoint file.
        device: device used both as ``map_location`` and as the final
            ``.to(device)`` target.

    Returns:
        The ``MultiHeadContentModerator`` on ``device``. It is not switched to
        eval mode here; call ``.eval()`` before inference.
    """
    # NOTE(security): depending on the torch version / weights_only setting,
    # torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Rebuild the architecture (this loads the base model's pretrained weights,
    # which are then replaced by the checkpoint's state dict).
    model = MultiHeadContentModerator(
        base_model_name=checkpoint['base_model'],
        num_violence_labels=checkpoint['num_violence_labels'],
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    # Label maps are plain attributes, not part of the state dict.
    model.violence_id2label = checkpoint['violence_id2label']
    model.nsfw_id2label = checkpoint['nsfw_id2label']
    return model.to(device)