File size: 1,791 Bytes
2ecc846
 
 
 
 
3aa023a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ecc846
3aa023a
2ecc846
3aa023a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ecc846
 
3aa023a
 
 
 
2ecc846
3aa023a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
import open_clip
from huggingface_hub import hf_hub_download


class RankingHead(nn.Module):
    """Small MLP that turns a feature vector into one scalar score.

    The submodule indices are load-bearing: they must line up with the
    training checkpoint, whose keys look like

        head.net.0  Linear(in_dim, hidden_dim)
        head.net.1  GELU
        head.net.2  Dropout
        head.net.3  Linear(hidden_dim, 1)

    Args:
        in_dim: Width of the incoming features (last dimension of the input).
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, in_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        # Built as an ordered list first; order == checkpoint index, so do not
        # insert or reorder entries.
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP over the last dimension; output has trailing size 1."""
        scores = self.net(x)
        return scores


class MobileCLIPRanker(nn.Module):
    """Frozen MobileCLIP2-L-14 visual encoder followed by a trainable RankingHead.

    ``__init__`` downloads the MobileCLIP2-L-14 checkpoint from the Hugging Face
    Hub and builds the open_clip model. It keeps only the visual tower, puts
    that tower in eval mode and freezes all of its parameters. Only ``self.head``
    receives gradients.

    Args:
        backbone_dim: Width of the backbone's output features. This is also the
            head's input width. The 768 default is assumed to match the
            MobileCLIP2-L-14 visual embedding — TODO confirm against the model
            config.
        head_hidden_dim: Hidden width of the ranking MLP.
        head_dropout: Dropout probability inside the ranking MLP.
    """

    def __init__(self, backbone_dim=768, head_hidden_dim=256, head_dropout=0.1):
        super().__init__()
        self.backbone_dim = backbone_dim

        print("Initializing MobileCLIP2-L14 backbone...")
        ckpt_path = hf_hub_download(
            repo_id="apple/MobileCLIP2-L-14",
            filename="mobileclip2_l14.pt",
        )
        model, _, _ = open_clip.create_model_and_transforms(
            "MobileCLIP2-L-14", pretrained=ckpt_path
        )
        # Only the image tower is used; the text tower and transforms are dropped.
        self.backbone = model.visual

        # Freeze the backbone: eval mode plus no gradients for its parameters.
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad = False

        self.head = RankingHead(backbone_dim, head_hidden_dim, head_dropout)

    def train(self, mode=True):
        """Set train/eval mode on this module but keep the backbone in eval mode.

        ``nn.Module.train`` recurses into every child. Without this override,
        ``model.train()`` would put the frozen backbone back into training mode.
        That would affect any dropout or normalization layers the backbone has.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x, valid_lens=None):
        """Score images (or precomputed features) with the ranking head.

        Args:
            x: Either a 5-D tensor laid out as (batch, group, C, H, W), or any
                other tensor. A 5-D tensor is flattened to (batch * group, C, H, W),
                encoded by the backbone, and regrouped to (batch, group, features).
                Any other tensor is treated as precomputed features and passed
                straight to the head; its last dimension must equal ``backbone_dim``.
            valid_lens: Accepted for interface compatibility; not used by this
                method.

        Returns:
            Head output with a trailing dimension of size 1, e.g.
            (batch, group, 1) for 5-D input.
        """
        if x.dim() == 5:
            b, g, c, h, w = x.shape
            # reshape instead of view: view raises on non-contiguous inputs
            # (e.g. transposed or sliced group batches); reshape returns a view
            # whenever one is possible and copies only when necessary.
            flat = x.reshape(b * g, c, h, w)
            features = self.backbone(flat).reshape(b, g, -1)
        else:
            features = x
        return self.head(features)