Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import clip | |
class CLIP_ViT_B_32(nn.Module):
    """Image feature extractor wrapping the CLIP "ViT-B/32" model.

    Only the image tower is used (through ``encode_image``); the full model
    returned by ``clip.load`` is stored on ``self.model``.

    Args:
        pretrained: Chooses which ``clip.load`` call is made.
            NOTE(review): both branches load the same "ViT-B/32" checkpoint;
            ``pretrained=False`` only adds ``jit=False`` and does NOT give
            randomly initialised weights. Confirm this is intended.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        load_kwargs = {"device": "cpu"}
        if not pretrained:
            load_kwargs["jit"] = False
        # The second element of clip.load's result (the image preprocessing
        # transform in OpenAI CLIP) is not needed here and is discarded.
        self.model, _ = clip.load("ViT-B/32", **load_kwargs)
        # Width of the vectors produced by ``forward``; downstream heads read it.
        self.feature_dim = 512

    def freeze_weights(self):
        """Disable gradients for every parameter of the wrapped CLIP model."""
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Encode the image batch ``x`` and return the features as float32.

        In training mode autograd recording is enabled (even if the caller is
        inside ``torch.no_grad()``); in eval mode it is disabled.
        """
        with torch.set_grad_enabled(self.training):
            feats = self.model.encode_image(x)
        # encode_image may return half precision (e.g. on GPU); normalise to float32.
        return feats.float()
class CosineClassifier(nn.Module):
    """Classification head scoring inputs by scaled cosine similarity to class prototypes.

    Args:
        dim: Size of the last axis of the input features.
        num_classes: Number of prototype rows, i.e. number of output logits.
        scale_factor: Constant the cosine similarities are multiplied by. Since
            both operands are L2-normalised, logits lie in
            ``[-scale_factor, scale_factor]``.
    """

    def __init__(self, dim, num_classes, scale_factor=30):
        super().__init__()
        self.num_classes = num_classes
        self.scale_factor = scale_factor
        # One prototype row per class. The randn values are immediately
        # overwritten by Xavier-uniform init, but the draw still advances the
        # global RNG, so it is kept to reproduce the same seeded weights.
        prototypes = torch.randn(num_classes, dim)
        nn.init.xavier_uniform_(prototypes)
        self.weights = nn.Parameter(prototypes)

    def forward(self, x):
        """Return ``scale_factor * cos(x, prototype_c)`` for every class ``c``.

        ``x`` is normalised along its last axis, so leading (batch) axes are
        preserved and the output's last axis has size ``num_classes``.
        """
        unit_x = F.normalize(x, dim=-1)
        unit_w = F.normalize(self.weights, dim=-1)
        cosine = unit_x @ unit_w.t()
        return cosine * self.scale_factor
class FaceFairnessModel(nn.Module):
    """Frozen CLIP ViT-B/32 image encoder followed by a cosine-similarity head.

    Args:
        num_classes: Number of logits produced per input image.
        scale_factor: Multiplier the classifier applies to cosine similarities.
    """

    def __init__(self, num_classes, scale_factor=30):
        super().__init__()
        # Build the backbone and freeze all of its parameters, so the
        # classifier's prototype weights are the only trainable parameters.
        backbone = CLIP_ViT_B_32(pretrained=True)
        backbone.freeze_weights()
        self.clip_model = backbone
        self.classifier = CosineClassifier(
            num_classes=num_classes,
            dim=backbone.feature_dim,
            scale_factor=scale_factor,
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        embeddings = self.clip_model(x)
        return self.classifier(embeddings)