multihead-content-moderator / model_class.py
import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, AutoImageProcessor

class MultiHeadContentModerator(nn.Module):
    """
    Multi-task model with two classification heads sharing one ViT backbone:
    - Head 1: NSFW detection (frozen, pretrained)
    - Head 2: Violence detection (trainable)
    """

    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
        super().__init__()
        # Load the pretrained base model and reuse its parts
        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size

        # ViT backbone (shared by both heads)
        self.vit = original_model.vit

        # Head 1: original NSFW classifier
        self.nsfw_classifier = original_model.classifier

        # Head 2: new violence classifier
        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)

        # Label mappings - use the actual Falconsai config
        self.nsfw_id2label = original_model.config.id2label  # {0: 'normal', 1: 'nsfw'}
        self.violence_id2label = {0: 'safe', 1: 'violence'}  # overwritten from checkpoint
def forward(self, pixel_values, task='both'):
outputs = self.vit(pixel_values=pixel_values)
pooled_output = outputs.last_hidden_state[:, 0]
if task == 'nsfw':
return self.nsfw_classifier(pooled_output)
elif task == 'violence':
return self.violence_classifier(pooled_output)
elif task == 'both':
return {
'nsfw': self.nsfw_classifier(pooled_output),
'violence': self.violence_classifier(pooled_output)
}
return self.violence_classifier(pooled_output)
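

# The loader below expects a checkpoint dict with specific keys. This save
# helper is NOT part of the original upload; it is a minimal sketch inferred
# from those keys so the checkpoint schema is explicit. The default
# base_model_name mirrors the class default and is an assumption.
def save_multihead_model(model, checkpoint_path,
                         base_model_name="Falconsai/nsfw_image_detection"):
    torch.save({
        'base_model': base_model_name,
        'num_violence_labels': model.violence_classifier.out_features,
        'model_state_dict': model.state_dict(),
        'violence_id2label': model.violence_id2label,
        'nsfw_id2label': model.nsfw_id2label,
    }, checkpoint_path)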


def load_multihead_model(checkpoint_path, device='cuda'):
    """Load a trained multi-head model from a training checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = MultiHeadContentModerator(
        base_model_name=checkpoint['base_model'],
        num_violence_labels=checkpoint['num_violence_labels'],
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore the label mappings saved at training time
    model.violence_id2label = checkpoint['violence_id2label']
    model.nsfw_id2label = checkpoint['nsfw_id2label']
    return model.to(device)
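

if __name__ == "__main__":
    # Minimal inference sketch, not part of the original file. Assumes a
    # checkpoint saved with the keys load_multihead_model reads; the paths
    # 'multihead_model.pt' and 'test.jpg' are placeholders.
    from PIL import Image

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_multihead_model('multihead_model.pt', device=device)
    model.eval()

    # Preprocess with the image processor of the ViT backbone the model was built from
    processor = AutoImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    image = Image.open('test.jpg').convert('RGB')
    inputs = processor(images=image, return_tensors='pt').to(device)

    with torch.no_grad():
        logits = model(inputs['pixel_values'], task='both')

    nsfw_id = logits['nsfw'].argmax(dim=-1).item()
    violence_id = logits['violence'].argmax(dim=-1).item()
    print(f"NSFW head:     {model.nsfw_id2label[nsfw_id]}")
    print(f"Violence head: {model.violence_id2label[violence_id]}")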