Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import mobileclip | |
| import open_clip | |
| from huggingface_hub import hf_hub_download | |
| class MobileCLIPRanker(nn.Module): | |
| def __init__(self, cfg): | |
| super().__init__() | |
| self.model_type = "mobileclip" | |
| if "l14" in cfg.model.name.lower() or "l-14" in cfg.model.name.lower(): | |
| self.model_type = "open_clip" | |
| repo_id = "apple/MobileCLIP2-L-14" | |
| filename = "mobileclip2_l14.pt" | |
| print(f"Initializing OpenCLIP L14 backbone...") | |
| ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) | |
| model, _, _ = open_clip.create_model_and_transforms( | |
| 'MobileCLIP2-L-14', | |
| pretrained=ckpt_path | |
| ) | |
| self.backbone = model.visual | |
| self.backbone_dim = 768 | |
| else: | |
| if "s3" in cfg.model.name.lower(): | |
| repo_id = "apple/MobileCLIP2-S3" | |
| filename = "mobileclip2_s3.pt" | |
| arch = "mobileclip2_s3" | |
| elif "s0" in cfg.model.name.lower(): | |
| repo_id = "apple/MobileCLIP2-S0" | |
| filename = "mobileclip2_s0.pt" | |
| arch = "mobileclip2_s0" | |
| else: | |
| repo_id = "apple/MobileCLIP2-B" | |
| filename = "mobileclip2_b.pt" | |
| arch = "mobileclip_b" | |
| print(f"Initializing MobileCLIP {arch} backbone...") | |
| ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) | |
| model, _, _ = mobileclip.create_model_and_transforms(arch, pretrained=ckpt_path) | |
| self.backbone = model.image_encoder | |
| self.backbone_dim = 512 | |
| self.backbone.eval() | |
| for param in self.backbone.parameters(): | |
| param.requires_grad = False | |
| params_to_train = list(self.backbone.named_parameters())[-60:] | |
| for name, param in params_to_train: | |
| param.requires_grad = True | |
| self.head = nn.Linear(self.backbone_dim, 1) | |
| def train(self, mode=True): | |
| super().train(mode) | |
| self.backbone.eval() | |
| return self | |
| def forward(self, x, valid_lens=None): | |
| b, g, c, h, w = x.shape | |
| x_flat = x.view(b * g, c, h, w) | |
| with torch.no_grad(): | |
| if self.model_type == "open_clip": | |
| features = self.backbone(x_flat) | |
| else: | |
| features = self.backbone(x_flat) | |
| features = features.view(b, g, -1) | |
| scores = self.head(features) | |
| return scores |