Ali7880 committed on
Commit 31e7458 · verified
1 Parent(s): 0870f9d

Upload 5 files

Files changed (5)
  1. README.md +57 -3
  2. config.json +14 -0
  3. model_class.py +59 -0
  4. multihead_model.pt +3 -0
  5. preprocessor_config.json +23 -0
README.md CHANGED
@@ -1,3 +1,57 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ base_model: Falconsai/nsfw_image_detection
+ tags:
+ - image-classification
+ - content-moderation
+ - violence-detection
+ - nsfw-detection
+ - multi-task-learning
+ ---
+
+ # Multi-Head Content Moderator
+
+ A multi-task image moderation model with **two classification heads**:
+ - **NSFW Detection**: Detects explicit/adult content (preserved from Falconsai)
+ - **Violence Detection**: Detects violent content (newly trained)
+
+ ## Architecture
+ - Base: ViT (Vision Transformer) from Falconsai/nsfw_image_detection
+ - Head 1: NSFW classifier (frozen, pretrained)
+ - Head 2: Violence classifier (trained on a violence dataset)
+
+ ## Categories
+
+ ### NSFW Head
+ - normal
+ - nsfw
+
+ ### Violence Head
+ - safe
+ - violence
+
+ ## Performance (Violence Detection)
+ - Accuracy: 0.9075
+ - F1 Score: 0.9076
+
+ ## Usage
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor
+ from model_class import load_multihead_model  # shipped in this repo
+
+ # Load checkpoint and processor
+ model = load_multihead_model('multihead_model.pt', device='cpu')
+ model.eval()
+ processor = AutoImageProcessor.from_pretrained('path/to/model')
+
+ # Inference
+ image = Image.open('example.jpg').convert('RGB')  # placeholder path
+ inputs = processor(images=image, return_tensors='pt')
+ with torch.no_grad():
+     # Get both predictions
+     outputs = model(inputs['pixel_values'], task='both')
+ nsfw_pred = outputs['nsfw'].argmax(-1).item()          # 0 = normal, 1 = nsfw
+ violence_pred = outputs['violence'].argmax(-1).item()  # 0 = safe, 1 = violence
+ ```
config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "base_model": "Falconsai/nsfw_image_detection",
+   "hidden_size": 768,
+   "num_violence_labels": 2,
+   "violence_id2label": {
+     "0": "safe",
+     "1": "violence"
+   },
+   "nsfw_id2label": {
+     "0": "normal",
+     "1": "nsfw"
+   },
+   "model_type": "MultiHeadContentModerator"
+ }
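A note on consuming this file: JSON object keys are always strings, so the id2label maps load with `"0"`/`"1"` keys, while `argmax` returns integer indices. A minimal sketch, assuming `config.json` sits in the working directory, that normalizes the keys:

```python
import json

# Load the head configs shipped with this commit
with open("config.json") as f:
    cfg = json.load(f)

# JSON keys are strings ("0", "1"); convert to int so they can be
# indexed directly with the result of logits.argmax(-1).item()
violence_id2label = {int(k): v for k, v in cfg["violence_id2label"].items()}
nsfw_id2label = {int(k): v for k, v in cfg["nsfw_id2label"].items()}

print(nsfw_id2label[1])      # "nsfw"
print(violence_id2label[1])  # "violence"
```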
model_class.py ADDED
@@ -0,0 +1,59 @@
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModelForImageClassification
+
+ class MultiHeadContentModerator(nn.Module):
+     """
+     Multi-task model with two classification heads:
+     - Head 1: NSFW detection (frozen, pretrained)
+     - Head 2: Violence detection (trainable)
+     """
+     def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
+         super().__init__()
+
+         # Load base model
+         original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
+         hidden_size = original_model.config.hidden_size
+
+         # ViT backbone (shared)
+         self.vit = original_model.vit
+
+         # Head 1: original NSFW classifier
+         self.nsfw_classifier = original_model.classifier
+
+         # Head 2: violence classifier
+         self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)
+
+         # Label mappings - use the actual Falconsai config
+         self.nsfw_id2label = original_model.config.id2label  # {0: 'normal', 1: 'nsfw'}
+         self.violence_id2label = {0: 'safe', 1: 'violence'}  # overwritten from checkpoint
+
+     def forward(self, pixel_values, task='both'):
+         outputs = self.vit(pixel_values=pixel_values)
+         # CLS token embedding as the pooled representation
+         pooled_output = outputs.last_hidden_state[:, 0]
+
+         if task == 'nsfw':
+             return self.nsfw_classifier(pooled_output)
+         elif task == 'violence':
+             return self.violence_classifier(pooled_output)
+         elif task == 'both':
+             return {
+                 'nsfw': self.nsfw_classifier(pooled_output),
+                 'violence': self.violence_classifier(pooled_output)
+             }
+         # Fail loudly instead of silently falling back to one head
+         raise ValueError(f"Unknown task: {task!r}")
+
+ def load_multihead_model(checkpoint_path, device='cuda'):
+     """Load the trained multi-head model from a checkpoint dict."""
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     model = MultiHeadContentModerator(
+         base_model_name=checkpoint['base_model'],
+         num_violence_labels=checkpoint['num_violence_labels']
+     )
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.violence_id2label = checkpoint['violence_id2label']
+     model.nsfw_id2label = checkpoint['nsfw_id2label']
+
+     return model.to(device)
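For single-head inference, the `task` argument of `forward` skips the head you don't need. A minimal sketch using the `load_multihead_model` helper above; the image path is a placeholder, and the processor is loaded from a local directory assumed to contain this repo's `preprocessor_config.json`:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor
from model_class import load_multihead_model

model = load_multihead_model("multihead_model.pt", device="cpu")
model.eval()
processor = AutoImageProcessor.from_pretrained(".")  # dir with preprocessor_config.json

image = Image.open("example.jpg").convert("RGB")  # placeholder path
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs["pixel_values"], task="violence")  # violence head only

probs = torch.softmax(logits, dim=-1)[0]
# Normalize key types: JSON round-trips can turn int keys into strings
id2label = {int(k): v for k, v in model.violence_id2label.items()}
print(id2label[probs.argmax().item()], round(probs.max().item(), 4))
```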
multihead_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e3ac359e78361b9a1e71071f61550c3617376d41104a9dba9cebf7fe2ad26bf
+ size 343290062
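This file is a Git LFS pointer, so a plain clone without LFS yields this small stub rather than the ~343 MB checkpoint. Either run `git lfs pull` after cloning, or fetch the file directly; a sketch with huggingface_hub, where the repo id is a placeholder for this repository's actual Hub id:

```python
from huggingface_hub import hf_hub_download

# "<user>/<repo>" is a placeholder; substitute this repository's Hub id
checkpoint_path = hf_hub_download(
    repo_id="<user>/<repo>",
    filename="multihead_model.pt",
)
print(checkpoint_path)  # local cache path of the resolved checkpoint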
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "do_convert_rgb": null,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "ViTImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "height": 224,
+     "width": 224
+   }
+ }
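For reference: `resample: 2` is PIL's bilinear filter, `rescale_factor` is 1/255, and with mean = std = 0.5 the pipeline maps pixel values into roughly [-1, 1] after resizing to 224×224. A quick sketch that checks this, assuming the config sits in the current directory:

```python
import numpy as np
from PIL import Image
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained(".")  # reads preprocessor_config.json

# Synthetic image spanning the full 0..255 range
image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))

pixel_values = processor(images=image, return_tensors="np")["pixel_values"]
print(pixel_values.shape)                      # (1, 3, 224, 224)
print(pixel_values.min(), pixel_values.max())  # close to -1.0 and 1.0
```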