from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPModel, PreTrainedModel, PretrainedConfig


class DISCOConfig(PretrainedConfig):
    """Configuration for the DISCO model.

    Args:
        clip_model_name: Identifier or path passed to
            ``CLIPModel.from_pretrained`` to build the image backbone.
        num_classes: Number of output logits produced by the linear head.
        threshold: Default decision threshold applied to the probability of
            class index 1 in ``DISCO.predict``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "clip_nsfw_detector"

    def __init__(
        self,
        clip_model_name: str = "openai/clip-vit-base-patch32",
        num_classes: int = 2,
        threshold: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.num_classes = num_classes
        self.threshold = threshold


class DISCO(PreTrainedModel):
    """DISCO model: a frozen CLIP image encoder followed by a linear classifier.

    Only ``self.classifier`` is trainable. The CLIP backbone has
    ``requires_grad`` disabled and is kept in eval mode even when ``train()``
    is called on the whole model.
    """

    config_class = DISCOConfig

    def __init__(self, config: DISCOConfig):
        super().__init__(config)
        self.clip_model = CLIPModel.from_pretrained(config.clip_model_name)
        # Freeze CLIP, only the classifier is trainable. Disabling
        # requires_grad keeps optimizers built from model.parameters() from
        # updating the backbone; eval() switches off any dropout inside it.
        self.clip_model.requires_grad_(False)
        self.clip_model.eval()

        # Size of CLIP's projected image embedding (input width of the head).
        embedding_dim = self.clip_model.config.projection_dim

        # Direct linear classifier layer (equivalent to logistic regression).
        self.classifier = nn.Linear(embedding_dim, config.num_classes)
        self.threshold = config.threshold

    def train(self, mode: bool = True):
        """Set train/eval mode on the model while keeping CLIP in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``eval()`` applied to the frozen backbone in ``__init__``.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # getattr guard: train() could be reached before clip_model is set.
        clip_model = getattr(self, "clip_model", None)
        if clip_model is not None:
            clip_model.eval()
        return self

    def forward(self, pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward pass through CLIP's image encoder and the classifier.

        Args:
            pixel_values: Preprocessed image tensors
                (batch_size, channels, height, width).
            **kwargs: Ignored; accepted so callers may pass extra keyword
                inputs without error.

        Returns:
            Logits with last dimension ``num_classes``.
        """
        # The backbone is frozen, so don't build an autograd graph for it.
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(
                pixel_values=pixel_values
            )

        # L2-normalize embeddings (CLIP uses normalized embeddings). The
        # clamp avoids 0/0 = NaN for an all-zero feature vector and leaves
        # every other vector unchanged.
        norms = image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        image_features = image_features / norms

        return self.classifier(image_features)

    def predict_proba(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over the last dim of the logits)."""
        logits = self.forward(pixel_values)
        return torch.softmax(logits, dim=-1)

    def predict(
        self, pixel_values: torch.Tensor, threshold: Optional[float] = None
    ) -> torch.Tensor:
        """Return binary predictions: 1 where P(class 1) >= threshold, else 0.

        Class index 1 is treated as the positive class, so this requires
        ``num_classes >= 2``.

        Args:
            pixel_values: Preprocessed image tensors, as for ``forward``.
            threshold: Decision threshold; ``None`` uses ``self.threshold``.

        Returns:
            Long tensor of shape (batch_size,) with values in {0, 1}.
        """
        if threshold is None:
            threshold = self.threshold
        # The integer result carries no gradient, so skip graph construction.
        with torch.no_grad():
            proba = self.predict_proba(pixel_values)
        return (proba[:, 1] >= threshold).long()