# Provenance: metythorn — initial push (commit 0f3fa05)
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
class CLIP_ViT_B_32(nn.Module):
    """Wrapper around OpenAI CLIP's ViT-B/32 image encoder.

    Exposes ``encode_image`` as ``forward`` and reports the encoder's
    output dimensionality via ``feature_dim``.

    Args:
        pretrained: Kept for backward compatibility. ``clip.load`` always
            returns pretrained weights, so this flag has no effect on the
            weights; the original code's if/else only toggled ``jit`` and
            loaded pretrained weights in BOTH branches.
            # NOTE(review): confirm whether a random-init path is actually needed.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # jit=False yields the plain nn.Module graph, which is required for
        # parameter access / gradient flow; the TorchScript (jit=True) model
        # is not fine-tunable. This replaces the redundant if/else in which
        # both branches loaded the same pretrained checkpoint.
        self.model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
        # CLIP ViT-B/32 projects images to 512-dim embeddings.
        self.feature_dim = 512

    def freeze_weights(self):
        """Disable gradients on every CLIP parameter (feature-extractor mode)."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return CLIP image embeddings for a batch of images.

        Gradients are tracked only in training mode; in eval mode the
        encoder runs under ``no_grad`` to save memory.
        """
        ctx = torch.enable_grad() if self.training else torch.no_grad()
        with ctx:
            features = self.model.encode_image(x)
        # CLIP may compute in fp16; cast to fp32 for downstream layers.
        return features.float()
class CosineClassifier(nn.Module):
def __init__(self, dim, num_classes, scale_factor=30):
super().__init__()
self.num_classes = num_classes
self.scale_factor = scale_factor
self.weights = nn.Parameter(torch.randn(num_classes, dim))
nn.init.xavier_uniform_(self.weights)
def forward(self, x):
# Normalize input features
x = F.normalize(x, dim=-1)
# Normalize weights
w = F.normalize(self.weights, dim=-1)
# Compute cosine similarity and scale
logits = torch.matmul(x, w.T) * self.scale_factor
return logits
class FaceFairnessModel(nn.Module):
    """Frozen CLIP ViT-B/32 backbone followed by a cosine classifier head."""

    def __init__(self, num_classes, scale_factor=30):
        super().__init__()
        # Backbone: pretrained CLIP image encoder, frozen so that only the
        # classifier head receives gradient updates.
        self.clip_model = CLIP_ViT_B_32(pretrained=True)
        self.clip_model.freeze_weights()
        # Head: cosine-similarity classifier on top of the CLIP embeddings.
        self.classifier = CosineClassifier(
            dim=self.clip_model.feature_dim,
            num_classes=num_classes,
            scale_factor=scale_factor,
        )

    def forward(self, x):
        """Map a batch of images to class logits via CLIP features."""
        return self.classifier(self.clip_model(x))