# clipick/model.py
# Initial commit (2ecc846) by Nightfury16 — header reconstructed from
# file-viewer page residue that was accidentally pasted into the source.
import torch
import torch.nn as nn
import mobileclip
import open_clip
from huggingface_hub import hf_hub_download
class MobileCLIPRanker(nn.Module):
    """Score groups of images with a frozen CLIP image encoder + linear head.

    Depending on ``cfg.model.name``, loads either the MobileCLIP2-L-14
    "teacher" visual tower (via ``open_clip``) or the MobileCLIP2-B
    "student" image encoder (via ``mobileclip``). Checkpoints are fetched
    from the Hugging Face Hub. The backbone is frozen; only the 1-dim
    scoring head is trainable.
    """

    def __init__(self, cfg):
        """Build the ranker.

        Args:
            cfg: config object; only ``cfg.model.name`` (str) is read.
                 Names containing "l14"/"l-14" select the L-14 teacher,
                 anything else the MobileCLIP-B student.
        """
        super().__init__()
        model_name = cfg.model.name.lower()
        self.model_type = "mobileclip"
        if "l14" in model_name or "l-14" in model_name:
            # Teacher: L-14 visual tower loaded through open_clip.
            self.model_type = "open_clip"
            repo_id = "apple/MobileCLIP2-L-14"
            filename = "mobileclip2_l14.pt"
            self.backbone_dim = 768
            print("Initializing Teacher (L14)...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = open_clip.create_model_and_transforms(
                'MobileCLIP2-L-14', pretrained=ckpt_path)
            self.backbone = model.visual
        else:
            # Student: MobileCLIP-B loaded through the mobileclip package.
            repo_id = "apple/MobileCLIP2-B"
            filename = "mobileclip2_b.pt"
            arch = "mobileclip_b"
            self.backbone_dim = 512
            print(f"Initializing Student ({arch})...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = mobileclip.create_model_and_transforms(
                arch, pretrained=ckpt_path)
            self.backbone = model.image_encoder
        # Freeze the backbone: only self.head receives gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.head = nn.Linear(self.backbone_dim, 1)

    def forward(self, x, valid_lens=None):
        """Score every image in every group.

        Args:
            x: tensor of shape (batch, group, C, H, W).
            valid_lens: unused; kept for caller compatibility.

        Returns:
            Scores of shape (batch, group, 1).
        """
        b, g, c, h, w = x.shape
        # Fold the group dimension into the batch for the encoder.
        x_flat = x.view(b * g, c, h, w)
        # NOTE: the original branched on self.model_type here, but both
        # branches were identical — a single call covers both backbones.
        features = self.backbone(x_flat)
        # Some encoders return (features, aux); keep only the features.
        if isinstance(features, tuple):
            features = features[0]
        features = features.view(b, g, -1)
        scores = self.head(features)
        return scores