# clipick/model.py
import torch
import torch.nn as nn
import mobileclip
import open_clip
from huggingface_hub import hf_hub_download


class MobileCLIPRanker(nn.Module):
    """Ranks grouped images: a linear head on a partially fine-tuned
    (Mobile)CLIP image encoder produces one score per image."""

    def __init__(self, cfg):
        super().__init__()
        self.model_type = "mobileclip"

        if "l14" in cfg.model.name.lower() or "l-14" in cfg.model.name.lower():
            # The L-14 variant is loaded through open_clip rather than the
            # mobileclip package.
            self.model_type = "open_clip"
            repo_id = "apple/MobileCLIP2-L-14"
            filename = "mobileclip2_l14.pt"
            print("Initializing OpenCLIP L14 backbone...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = open_clip.create_model_and_transforms(
                'MobileCLIP2-L-14',
                pretrained=ckpt_path
            )
            self.backbone = model.visual
            self.backbone_dim = 768
        else:
            if "s3" in cfg.model.name.lower():
                repo_id = "apple/MobileCLIP2-S3"
                filename = "mobileclip2_s3.pt"
                arch = "mobileclip2_s3"
            elif "s0" in cfg.model.name.lower():
                repo_id = "apple/MobileCLIP2-S0"
                filename = "mobileclip2_s0.pt"
                arch = "mobileclip2_s0"
            else:
                repo_id = "apple/MobileCLIP2-B"
                filename = "mobileclip2_b.pt"
                arch = "mobileclip_b"
            print(f"Initializing MobileCLIP {arch} backbone...")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            model, _, _ = mobileclip.create_model_and_transforms(arch, pretrained=ckpt_path)
            self.backbone = model.image_encoder
            self.backbone_dim = 512
        # Freeze the whole backbone, then unfreeze only the last 60 parameter
        # tensors for light fine-tuning.
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False
        params_to_train = list(self.backbone.named_parameters())[-60:]
        for _name, param in params_to_train:
            param.requires_grad = True

        self.head = nn.Linear(self.backbone_dim, 1)
    def train(self, mode=True):
        # Keep the backbone in eval mode even while training so its
        # normalization statistics stay frozen; only the unfrozen backbone
        # tensors and the head receive gradient updates.
        super().train(mode)
        self.backbone.eval()
        return self
    def forward(self, x, valid_lens=None):
        # x: (batch, group, channels, height, width). `valid_lens` is accepted
        # for API compatibility but currently unused.
        b, g, c, h, w = x.shape
        x_flat = x.view(b * g, c, h, w)
        # Both backbone types share the same call signature, so no branch is
        # needed. The backbone runs with gradients enabled so the parameter
        # tensors deliberately unfrozen in __init__ can actually be trained.
        features = self.backbone(x_flat)
        features = features.view(b, g, -1)
        scores = self.head(features)  # one score per image: (b, g, 1)
        return scores
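

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API). It
# assumes the ranker reads only `cfg.model.name` (true for the class above),
# that the S0 checkpoint is reachable on the Hugging Face Hub, and that
# MobileCLIP2-S0 expects 256x256 inputs; adjust the resolution if your
# checkpoint differs.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(name="mobileclip2_s0"))
    ranker = MobileCLIPRanker(cfg)

    # Confirm that only the unfrozen tail of the backbone plus the head train.
    n_trainable = sum(p.numel() for p in ranker.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable}")

    # A batch of 2 groups with 4 candidate images each.
    dummy = torch.randn(2, 4, 3, 256, 256)
    ranker.eval()
    with torch.no_grad():
        scores = ranker(dummy)
    print(scores.shape)  # expected: torch.Size([2, 4, 1])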