|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
|
|
|
|
class MultiHeadContentModerator(nn.Module):
    """
    Multi-task image classifier: one shared backbone, two classification heads.

    - Head 1: NSFW detection. The classifier comes from the pretrained
      ``base_model_name`` checkpoint. Its parameters are frozen
      (``requires_grad=False``) unless ``freeze_nsfw_head=False`` is passed.
    - Head 2: Violence detection. This is a freshly initialised, trainable
      ``nn.Linear`` layer.

    Both heads read the same feature: position 0 of the backbone's
    ``last_hidden_state`` (the [CLS] token in a ViT). Freezing the NSFW head
    does NOT freeze the shared backbone. If the backbone is fine-tuned, the
    NSFW predictions will change as well.
    """

    # Accepted values for ``forward(task=...)``.
    _TASKS = ('nsfw', 'violence', 'both')

    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2,
                 freeze_nsfw_head=True):
        """
        Build the model from a pretrained image-classification checkpoint.

        Args:
            base_model_name: Hugging Face model id or path. The loaded model must
                expose ``.vit`` and ``.classifier`` (ViT-style
                ``*ForImageClassification``).
            num_violence_labels: number of output logits of the violence head.
            freeze_nsfw_head: if True (default), disable gradients for the
                pretrained NSFW classifier, as the class contract describes.
        """
        super().__init__()

        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size

        # Attribute names (``vit``, ``nsfw_classifier``, ``violence_classifier``)
        # determine state_dict keys; keep them stable for saved checkpoints.
        self.vit = original_model.vit
        self.nsfw_classifier = original_model.classifier
        if freeze_nsfw_head:
            # Only the head is frozen; the shared backbone stays trainable.
            self.nsfw_classifier.requires_grad_(False)

        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)

        self.nsfw_id2label = original_model.config.id2label
        self.violence_id2label = self._default_violence_id2label(num_violence_labels)

    @staticmethod
    def _default_violence_id2label(num_labels):
        """Return a label map whose size matches the violence head's output."""
        if num_labels == 2:
            return {0: 'safe', 1: 'violence'}
        # Fall back to generic Hugging Face-style names instead of a mismatched map.
        return {i: f'LABEL_{i}' for i in range(num_labels)}

    def forward(self, pixel_values, task='both'):
        """
        Run the shared backbone followed by the requested head(s).

        Args:
            pixel_values: image tensor passed to the backbone as ``pixel_values``.
                Presumably (batch, channels, height, width), as produced by the
                matching image processor (TODO confirm).
            task: one of 'nsfw', 'violence' or 'both'.

        Returns:
            The logits tensor of the selected head. When ``task == 'both'``,
            a dict ``{'nsfw': logits, 'violence': logits}``.

        Raises:
            ValueError: if ``task`` is not a supported value. Previously,
                unknown values silently returned the violence logits.
        """
        # Validate before running the (expensive) backbone.
        if task not in self._TASKS:
            raise ValueError(f"task must be one of {self._TASKS}, got {task!r}")

        outputs = self.vit(pixel_values=pixel_values)
        # Position-0 token of the final hidden states is the pooled feature.
        pooled_output = outputs.last_hidden_state[:, 0]

        if task == 'nsfw':
            return self.nsfw_classifier(pooled_output)
        if task == 'violence':
            return self.violence_classifier(pooled_output)
        return {
            'nsfw': self.nsfw_classifier(pooled_output),
            'violence': self.violence_classifier(pooled_output),
        }
|
|
|
|
|
def load_multihead_model(checkpoint_path, device='cuda'):
    """
    Rebuild a trained MultiHeadContentModerator from a saved checkpoint.

    The checkpoint must be a mapping with the keys ``'base_model'``,
    ``'num_violence_labels'``, ``'model_state_dict'``, ``'violence_id2label'``
    and ``'nsfw_id2label'``.

    Args:
        checkpoint_path: path to the file written by ``torch.save``.
        device: device used as ``map_location`` and as the final ``.to()`` target.

    Returns:
        The model with restored weights and label maps, moved to ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True if the torch version allows).
    ckpt = torch.load(checkpoint_path, map_location=device)

    init_kwargs = {
        'base_model_name': ckpt['base_model'],
        'num_violence_labels': ckpt['num_violence_labels'],
    }
    model = MultiHeadContentModerator(**init_kwargs)
    model.load_state_dict(ckpt['model_state_dict'])

    # Label maps saved with the checkpoint override the defaults set in __init__.
    model.violence_id2label, model.nsfw_id2label = (
        ckpt['violence_id2label'],
        ckpt['nsfw_id2label'],
    )

    return model.to(device)
|
|
|