# NOTE: this file was extracted from a HuggingFace Space page
# (page status banner read "Spaces: Sleeping").
import torch
import torch.nn as nn
import mobileclip
import open_clip
from huggingface_hub import hf_hub_download
class MobileCLIPRanker(nn.Module):
    """Score groups of images with a frozen CLIP image encoder + linear head.

    Depending on ``cfg.model.name``, loads either the OpenCLIP
    MobileCLIP2-L-14 "teacher" (768-dim features) or the MobileCLIP2-B
    "student" (512-dim features) from the HuggingFace Hub, freezes the
    image backbone, and trains only a 1-output linear ranking head.
    """

    def __init__(self, cfg):
        """Build the ranker.

        Args:
            cfg: config object; only ``cfg.model.name`` is read. A name
                containing "l14"/"l-14" selects the L-14 teacher backbone,
                anything else selects the MobileCLIP2-B student.
        """
        super().__init__()
        model_name = cfg.model.name.lower()
        self.model_type = "mobileclip"
        if "l14" in model_name or "l-14" in model_name:
            # Teacher: OpenCLIP MobileCLIP2-L-14, 768-dim image features.
            self.model_type = "open_clip"
            repo_id = "apple/MobileCLIP2-L-14"
            filename = "mobileclip2_l14.pt"
            self.backbone_dim = 768
            # Was an f-string with no placeholders (F541); plain literal,
            # identical output.
            print("Initializing Teacher (L14)...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = open_clip.create_model_and_transforms(
                'MobileCLIP2-L-14', pretrained=ckpt_path
            )
            self.backbone = model.visual
        else:
            # Student: MobileCLIP2-B, 512-dim image features.
            repo_id = "apple/MobileCLIP2-B"
            filename = "mobileclip2_b.pt"
            arch = "mobileclip_b"
            self.backbone_dim = 512
            print(f"Initializing Student ({arch})...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = mobileclip.create_model_and_transforms(
                arch, pretrained=ckpt_path
            )
            self.backbone = model.image_encoder

        # Freeze the backbone: only self.head receives gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.head = nn.Linear(self.backbone_dim, 1)

    def forward(self, x, valid_lens=None):
        """Compute one scalar score per image in each group.

        Args:
            x: tensor of shape (batch, group, channels, height, width).
            valid_lens: unused; kept for interface compatibility with
                callers that pass per-group valid lengths.

        Returns:
            Tensor of shape (batch, group, 1).
        """
        b, g, c, h, w = x.shape
        # Fold the group dimension into the batch for the encoder.
        x_flat = x.view(b * g, c, h, w)
        # The original branched on self.model_type here, but both branches
        # were identical — collapsed into a single call.
        features = self.backbone(x_flat)
        # Some MobileCLIP encoders return (features, aux); keep features only.
        if isinstance(features, tuple):
            features = features[0]
        features = features.view(b, g, -1)
        scores = self.head(features)
        return scores