import torch
import torch.nn as nn

import mobileclip
import open_clip
from huggingface_hub import hf_hub_download


class MobileCLIPRanker(nn.Module):
    """Score groups of images with a frozen CLIP image encoder and a linear head.

    Depending on ``cfg.model.name``, loads either the larger open_clip
    MobileCLIP2-L-14 "teacher" (768-d features) or the MobileCLIP2-B
    "student" (512-d features). The backbone is frozen; only the single
    scoring head (``nn.Linear(dim, 1)``) is trainable.
    """

    def __init__(self, cfg):
        """Build the ranker.

        Args:
            cfg: config object; only ``cfg.model.name`` is read here. Any name
                containing "l14" or "l-14" (case-insensitive) selects the
                open_clip L-14 teacher, everything else the MobileCLIP-B student.
        """
        super().__init__()
        model_name = cfg.model.name.lower()

        if "l14" in model_name or "l-14" in model_name:
            # Teacher path: open_clip ViT-L/14 visual tower.
            self.model_type = "open_clip"
            self.backbone_dim = 768
            print("Initializing Teacher (L14)...")
            ckpt_path = hf_hub_download(
                repo_id="apple/MobileCLIP2-L-14",
                filename="mobileclip2_l14.pt",
            )
            model, _, _ = open_clip.create_model_and_transforms(
                'MobileCLIP2-L-14', pretrained=ckpt_path
            )
            self.backbone = model.visual
        else:
            # Student path: MobileCLIP-B image encoder.
            self.model_type = "mobileclip"
            arch = "mobileclip_b"
            self.backbone_dim = 512
            print(f"Initializing Student ({arch})...")
            ckpt_path = hf_hub_download(
                repo_id="apple/MobileCLIP2-B",
                filename="mobileclip2_b.pt",
            )
            model, _, _ = mobileclip.create_model_and_transforms(
                arch, pretrained=ckpt_path
            )
            self.backbone = model.image_encoder

        # Freeze the backbone so only the head receives gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False
        # NOTE(review): the frozen backbone is left in whatever train/eval mode
        # the enclosing Module is put in; if it contains dropout/batchnorm,
        # consider calling self.backbone.eval() — confirm intended behavior.

        self.head = nn.Linear(self.backbone_dim, 1)

    def forward(self, x, valid_lens=None):
        """Score every image in every group.

        Args:
            x: tensor of shape ``(batch, group, C, H, W)``.
            valid_lens: unused; kept for interface compatibility with callers.

        Returns:
            Tensor of shape ``(batch, group, 1)`` with one score per image.
        """
        b, g, c, h, w = x.shape
        x_flat = x.view(b * g, c, h, w)

        # Both backbone types take a flat (N, C, H, W) batch; some encoders
        # return a (features, extras) tuple — keep only the feature tensor.
        features = self.backbone(x_flat)
        if isinstance(features, tuple):
            features = features[0]

        features = features.view(b, g, -1)
        scores = self.head(features)
        return scores