import torch
import torch.nn as nn
import torch.nn.functional as F


class CLIP_ViT_B_32(nn.Module):
    """Frozen CLIP ViT-B/32 image encoder used as a feature extractor."""

    def __init__(self, pretrained=True):
        super().__init__()
        # Lazy import: `clip` is a heavy optional third-party dependency
        # needed only when this backbone is actually constructed; keeping
        # it here lets the rest of the module import without it.
        import clip

        # NOTE(review): clip.load() always downloads the pretrained
        # checkpoint; jit=False only disables the TorchScript variant and
        # does NOT produce randomly initialized weights. The `pretrained`
        # flag is therefore effectively a no-op — confirm intent.
        if pretrained:
            self.model, _ = clip.load("ViT-B/32", device="cpu")
        else:
            self.model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
        self.feature_dim = 512  # CLIP ViT-B/32 output embedding size

    def freeze_weights(self):
        """Freeze all CLIP weights so they are excluded from optimization."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Encode a batch of images into CLIP embeddings.

        Args:
            x: image tensor already preprocessed for CLIP
               (assumed (batch, 3, 224, 224) — TODO confirm with caller).

        Returns:
            Float32 feature tensor of shape (batch, feature_dim). CLIP
            natively returns half precision, hence the .float() cast.
        """
        # Skip autograd bookkeeping in eval mode; in train mode gradients
        # are still blocked by freeze_weights() when it has been called.
        with torch.no_grad() if not self.training else torch.enable_grad():
            features = self.model.encode_image(x)
        return features.float()


class CosineClassifier(nn.Module):
    """Linear classifier on the unit hypersphere.

    Logits are scaled cosine similarities between L2-normalized input
    features and L2-normalized class weight vectors, which removes the
    feature/weight norm bias that plain linear layers exhibit.

    Args:
        dim: dimensionality of the input feature vectors.
        num_classes: number of output classes (one prototype each).
        scale_factor: multiplier applied to cosine similarities so the
            logits span a range usable by cross-entropy (default 30).
    """

    def __init__(self, dim, num_classes, scale_factor=30):
        super().__init__()
        self.num_classes = num_classes
        self.scale_factor = scale_factor
        # torch.empty instead of torch.randn: xavier_uniform_ fully
        # overwrites the tensor in-place, so the randn fill was wasted work.
        self.weights = nn.Parameter(torch.empty(num_classes, dim))
        nn.init.xavier_uniform_(self.weights)

    def forward(self, x):
        """Return (batch, num_classes) logits in [-scale_factor, +scale_factor]."""
        x = F.normalize(x, dim=-1)             # unit-norm features
        w = F.normalize(self.weights, dim=-1)  # unit-norm class prototypes
        return torch.matmul(x, w.T) * self.scale_factor


class FaceFairnessModel(nn.Module):
    """Frozen CLIP backbone + trainable cosine classifier (linear probe).

    Args:
        num_classes: number of target classes.
        scale_factor: cosine-logit scale forwarded to CosineClassifier.
    """

    def __init__(self, num_classes, scale_factor=30):
        super().__init__()
        # CLIP feature extractor, frozen so only the head trains.
        self.clip_model = CLIP_ViT_B_32(pretrained=True)
        self.clip_model.freeze_weights()
        # Cosine classifier head sized to the CLIP embedding dimension.
        self.classifier = CosineClassifier(
            dim=self.clip_model.feature_dim,
            num_classes=num_classes,
            scale_factor=scale_factor,
        )

    def forward(self, x):
        """Map preprocessed images to class logits."""
        features = self.clip_model(x)
        return self.classifier(features)