# DISCO-v0.1 / src/model.py
# Uploaded by younissk via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 9894d76 (verified).
from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPModel, PreTrainedModel, PretrainedConfig
class DISCOConfig(PretrainedConfig):
    """Configuration for the DISCO model.

    Args:
        clip_model_name: Name or path handed to ``CLIPModel.from_pretrained``
            to obtain the image encoder backbone.
        num_classes: Number of logits produced by the linear classifier head.
        threshold: Default decision threshold that ``DISCO.predict`` applies
            to the class-1 probability when no explicit threshold is given.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "clip_nsfw_detector"

    def __init__(
        self,
        clip_model_name: str = "openai/clip-vit-base-patch32",
        num_classes: int = 2,
        threshold: float = 0.5,
        **kwargs,
    ):
        # Let the base class consume the generic config keys first, then
        # record the DISCO-specific fields on top of them.
        super().__init__(**kwargs)
        self.clip_model_name, self.num_classes, self.threshold = (
            clip_model_name,
            num_classes,
            threshold,
        )
class DISCO(PreTrainedModel):
    """
    DISCO model: a frozen CLIP image encoder followed by a linear classifier.

    The CLIP model named by ``config.clip_model_name`` is loaded and frozen;
    only ``self.classifier`` is intended to be trained.
    """

    config_class = DISCOConfig

    def __init__(self, config: DISCOConfig):
        super().__init__(config)
        self.clip_model = CLIPModel.from_pretrained(config.clip_model_name)
        # Freeze CLIP. ``eval()`` only changes layer behaviour (dropout etc.);
        # it does not stop the parameters from being trainable, so gradients
        # are disabled explicitly as well. ``train()`` below keeps CLIP in
        # eval mode even when the wrapper is switched to training mode.
        self.clip_model.requires_grad_(False)
        self.clip_model.eval()
        # Dimensionality of the image embeddings, taken from the CLIP config.
        embedding_dim = self.clip_model.config.projection_dim
        # Direct linear classifier layer (with the softmax in predict_proba,
        # equivalent to multinomial logistic regression on the embeddings).
        self.classifier = nn.Linear(embedding_dim, config.num_classes)
        self.threshold = config.threshold

    def train(self, mode: bool = True) -> "DISCO":
        """
        Set training mode for the classifier while keeping CLIP in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override ``model.train()`` would put the frozen CLIP encoder back
        into training mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.clip_model.eval()
        return self

    def forward(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Forward pass through the frozen CLIP image encoder and the classifier.

        Args:
            pixel_values: Preprocessed image tensors
                (batch_size, channels, height, width).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        # CLIP is frozen, so no autograd graph is built for the encoder.
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(
                pixel_values=pixel_values
            )
        # L2-normalize embeddings (CLIP uses normalized embeddings).
        # ``normalize`` clamps the norm at a tiny epsilon, so an all-zero
        # feature vector yields zeros instead of NaNs; otherwise the result
        # equals ``x / x.norm(dim=-1, keepdim=True)``.
        image_features = nn.functional.normalize(image_features, p=2.0, dim=-1)
        return self.classifier(image_features)

    def predict_proba(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Get class probability predictions.

        Args:
            pixel_values: Preprocessed image tensors, as accepted by ``forward``.

        Returns:
            Softmax over the logits, shape (batch_size, num_classes).
        """
        # Call the module (not ``.forward``) so registered hooks still run.
        logits = self(pixel_values)
        return torch.softmax(logits, dim=-1)

    def predict(
        self, pixel_values: torch.Tensor, threshold: Optional[float] = None
    ) -> torch.Tensor:
        """
        Get binary predictions.

        A sample is labelled 1 when the probability of class index 1 is at
        least ``threshold``, and 0 otherwise. Column 1 of the probabilities
        is read, so this assumes ``num_classes >= 2`` (binary setup).

        Args:
            pixel_values: Preprocessed image tensors, as accepted by ``forward``.
            threshold: Decision threshold; when ``None``, ``self.threshold``
                (taken from ``config.threshold``) is used.

        Returns:
            Long tensor of shape (batch_size,) containing 0 or 1.
        """
        if threshold is None:
            threshold = self.threshold
        proba = self.predict_proba(pixel_values)
        return (proba[:, 1] >= threshold).long()