Upload 5 files
Browse files- README.md +57 -3
- config.json +14 -0
- model_class.py +59 -0
- multihead_model.pt +3 -0
- preprocessor_config.json +23 -0
README.md
CHANGED
|
@@ -1,3 +1,57 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
base_model: Falconsai/nsfw_image_detection
|
| 4 |
+
tags:
|
| 5 |
+
- image-classification
|
| 6 |
+
- content-moderation
|
| 7 |
+
- violence-detection
|
| 8 |
+
- nsfw-detection
|
| 9 |
+
- multi-task-learning
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Multi-Head Content Moderator
|
| 13 |
+
|
| 14 |
+
A multi-task image moderation model with **two classification heads**:
|
| 15 |
+
- **NSFW Detection**: Detects explicit/adult content (preserved from Falconsai)
|
| 16 |
+
- **Violence Detection**: Detects violent content (newly trained)
|
| 17 |
+
|
| 18 |
+
## Architecture
|
| 19 |
+
- Base: ViT (Vision Transformer) from Falconsai/nsfw_image_detection
|
| 20 |
+
- Head 1: NSFW classifier (frozen, pretrained)
|
| 21 |
+
- Head 2: Violence classifier (trained on violence dataset)
|
| 22 |
+
|
| 23 |
+
## Categories
|
| 24 |
+
|
| 25 |
+
### NSFW Head
|
| 26 |
+
- normal
|
| 27 |
+
- nsfw
|
| 28 |
+
|
| 29 |
+
### Violence Head
|
| 30 |
+
- safe
|
| 31 |
+
- violence
|
| 32 |
+
|
| 33 |
+
## Performance (Violence Detection)
|
| 34 |
+
- Accuracy: 0.9075
|
| 35 |
+
- F1 Score: 0.9076
|
| 36 |
+
|
| 37 |
+
## Usage
|
| 38 |
+
```python
|
| 39 |
+
import torch
|
| 40 |
+
from transformers import AutoImageProcessor
|
| 41 |
+
|
| 42 |
+
# Load
|
| 43 |
+
checkpoint = torch.load('multihead_model.pt', map_location='cpu')
|
| 44 |
+
processor = AutoImageProcessor.from_pretrained('path/to/model')
|
| 45 |
+
|
| 46 |
+
# Create model class (see notebook for full class definition)
|
| 47 |
+
# model = MultiHeadContentModerator(...)
|
| 48 |
+
# model.load_state_dict(checkpoint['model_state_dict'])
|
| 49 |
+
|
| 50 |
+
# Inference
|
| 51 |
+
inputs = processor(images=image, return_tensors='pt')
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
# Get both predictions
|
| 54 |
+
outputs = model(inputs['pixel_values'], task='both')
|
| 55 |
+
nsfw_pred = outputs['nsfw'].argmax(-1)
|
| 56 |
+
violence_pred = outputs['violence'].argmax(-1)
|
| 57 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "Falconsai/nsfw_image_detection",
|
| 3 |
+
"hidden_size": 768,
|
| 4 |
+
"num_violence_labels": 2,
|
| 5 |
+
"violence_id2label": {
|
| 6 |
+
"0": "safe",
|
| 7 |
+
"1": "violence"
|
| 8 |
+
},
|
| 9 |
+
"nsfw_id2label": {
|
| 10 |
+
"0": "normal",
|
| 11 |
+
"1": "nsfw"
|
| 12 |
+
},
|
| 13 |
+
"model_type": "MultiHeadContentModerator"
|
| 14 |
+
}
|
model_class.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
| 5 |
+
|
| 6 |
+
class MultiHeadContentModerator(nn.Module):
    """Multi-task image moderation model: one shared ViT backbone, two heads.

    - Head 1: NSFW detection (classifier copied from the pretrained base model)
    - Head 2: Violence detection (a fresh linear layer, trained separately)
    """

    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
        """Build the model from a pretrained image-classification checkpoint.

        Args:
            base_model_name: Hub id of the ViT classification model whose
                backbone and classifier are reused.
            num_violence_labels: output size of the new violence head.
        """
        super().__init__()

        # Load the base model once; we keep its backbone and its classifier head.
        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size

        # Shared ViT backbone.
        self.vit = original_model.vit

        # Head 1: the original NSFW classifier (pretrained weights preserved).
        self.nsfw_classifier = original_model.classifier

        # Head 2: newly initialized violence classifier.
        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)

        # Label mappings. The NSFW mapping comes straight from the base model's
        # config (for Falconsai this is {0: 'normal', 1: 'nsfw'}); the violence
        # mapping is a default that load_multihead_model overwrites from the
        # training checkpoint.
        self.nsfw_id2label = original_model.config.id2label
        self.violence_id2label = {0: 'safe', 1: 'violence'}

    def forward(self, pixel_values, task='both'):
        """Run the backbone and the requested classification head(s).

        Args:
            pixel_values: preprocessed image batch (as produced by the ViT
                image processor).
            task: 'nsfw', 'violence', or 'both'.

        Returns:
            Logits tensor for a single head, or a dict with 'nsfw' and
            'violence' logits when task == 'both'.

        Raises:
            ValueError: if ``task`` is not one of the supported values.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Use the [CLS] token embedding as the pooled representation.
        pooled_output = outputs.last_hidden_state[:, 0]

        if task == 'nsfw':
            return self.nsfw_classifier(pooled_output)
        if task == 'violence':
            return self.violence_classifier(pooled_output)
        if task == 'both':
            return {
                'nsfw': self.nsfw_classifier(pooled_output),
                'violence': self.violence_classifier(pooled_output),
            }
        # BUG FIX: previously any unrecognized task string silently fell
        # through to the violence head; fail loudly so typos are caught.
        raise ValueError(f"Unknown task {task!r}; expected 'nsfw', 'violence', or 'both'")
|
| 46 |
+
|
| 47 |
+
def load_multihead_model(checkpoint_path, device='cuda'):
    """Reconstruct a trained multi-head moderator from a saved checkpoint.

    Args:
        checkpoint_path: path to the ``.pt`` file produced at training time
            (state dict plus model/label metadata).
        device: device the weights are mapped to and the model is moved onto.

    Returns:
        A ``MultiHeadContentModerator`` on ``device`` with its weights and
        label mappings restored from the checkpoint.
    """
    ckpt = torch.load(checkpoint_path, map_location=device)

    # Rebuild the architecture from the metadata stored alongside the weights.
    model = MultiHeadContentModerator(
        base_model_name=ckpt['base_model'],
        num_violence_labels=ckpt['num_violence_labels'],
    )
    model.load_state_dict(ckpt['model_state_dict'])

    # Restore the label mappings captured at training time.
    model.violence_id2label = ckpt['violence_id2label']
    model.nsfw_id2label = ckpt['nsfw_id2label']

    return model.to(device)
|
multihead_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e3ac359e78361b9a1e71071f61550c3617376d41104a9dba9cebf7fe2ad26bf
|
| 3 |
+
size 343290062
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_convert_rgb": null,
|
| 3 |
+
"do_normalize": true,
|
| 4 |
+
"do_rescale": true,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"image_mean": [
|
| 7 |
+
0.5,
|
| 8 |
+
0.5,
|
| 9 |
+
0.5
|
| 10 |
+
],
|
| 11 |
+
"image_processor_type": "ViTImageProcessor",
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.5,
|
| 14 |
+
0.5,
|
| 15 |
+
0.5
|
| 16 |
+
],
|
| 17 |
+
"resample": 2,
|
| 18 |
+
"rescale_factor": 0.00392156862745098,
|
| 19 |
+
"size": {
|
| 20 |
+
"height": 224,
|
| 21 |
+
"width": 224
|
| 22 |
+
}
|
| 23 |
+
}
|