|
|
from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPModel, PreTrainedModel, PretrainedConfig
|
|
|
|
|
|
|
|
class DISCOConfig(PretrainedConfig):
    """Settings for the DISCO image classifier.

    Args:
        clip_model_name: Checkpoint identifier that ``DISCO`` hands to
            ``CLIPModel.from_pretrained`` to build its image encoder.
        num_classes: Number of outputs of the linear classification head.
        threshold: Default cut-off that ``DISCO.predict`` applies to the
            probability of class index 1.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "clip_nsfw_detector"

    def __init__(
        self,
        clip_model_name: str = "openai/clip-vit-base-patch32",
        num_classes: int = 2,
        threshold: float = 0.5,
        **kwargs,
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Record the DISCO-specific settings on the config instance.
        self.clip_model_name = clip_model_name
        self.num_classes = num_classes
        self.threshold = threshold
|
|
|
|
|
|
|
|
class DISCO(PreTrainedModel):
    """CLIP image encoder followed by a linear classification head.

    The CLIP model named by ``config.clip_model_name`` is used purely as a
    feature extractor: ``forward`` runs it under ``torch.no_grad()``, so only
    the ``classifier`` layer can receive gradients.
    """

    config_class = DISCOConfig

    def __init__(self, config: DISCOConfig):
        """Load the CLIP backbone and build the classification head.

        Args:
            config: Supplies the CLIP checkpoint name, the number of output
                classes and the default decision threshold.
        """
        super().__init__(config)
        self.clip_model = CLIPModel.from_pretrained(config.clip_model_name)
        self.clip_model.eval()

        # The head consumes CLIP's projected image embedding.
        embedding_dim = self.clip_model.config.projection_dim

        self.classifier = nn.Linear(embedding_dim, config.num_classes)
        self.threshold = config.threshold

    def train(self, mode: bool = True) -> "DISCO":
        """Switch training mode for the head while CLIP stays in eval mode.

        ``forward`` never back-propagates into CLIP, so the backbone is kept
        in inference mode even after ``model.train()`` is called.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.clip_model.eval()
        return self

    def forward(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run images through CLIP and the classification head.

        Args:
            pixel_values: Preprocessed image tensors
                (batch_size, channels, height, width).
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Unnormalised class scores, one row per image and
            ``config.num_classes`` columns.
        """
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(
                pixel_values=pixel_values
            )
            # L2-normalise each embedding. ``normalize`` clamps the norm from
            # below, so an all-zero embedding yields zeros instead of NaN.
            image_features = nn.functional.normalize(image_features, p=2, dim=-1)

        logits = self.classifier(image_features)
        return logits

    def predict_proba(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return per-class probabilities (softmax over the last dimension).

        Args:
            pixel_values: Preprocessed image tensors, as for ``forward``.

        Returns:
            Tensor with the same shape as the logits; each row sums to 1.
        """
        # Call the module rather than ``forward`` so registered hooks run.
        logits = self(pixel_values)
        return torch.softmax(logits, dim=-1)

    def predict(
        self, pixel_values: torch.Tensor, threshold: Optional[float] = None
    ) -> torch.Tensor:
        """Return 0/1 predictions by thresholding the class-1 probability.

        Args:
            pixel_values: Preprocessed image tensors, as for ``forward``.
            threshold: Cut-off applied to the probability of class index 1.
                Defaults to the threshold taken from the config.

        Returns:
            Integer (``long``) tensor with one entry per image: 1 where the
            class-1 probability is at least ``threshold``, else 0.
        """
        if threshold is None:
            threshold = self.threshold
        # The result is an integer tensor, so no autograd graph is needed.
        with torch.no_grad():
            proba = self.predict_proba(pixel_values)
        return (proba[:, 1] >= threshold).long()
|
|
|