"""
HF-compatible wrapper that turns the FastViT backbone into a pure *image encoder*.
Output: a single (B, embed_dim) vector obtained with the built-in GlobalPool2D head.
"""
import torch
from transformers import PreTrainedModel, PretrainedConfig
from .mci import fastvithd, GlobalPool2D     # imports your backbone factory


# -----------------------  Config  -----------------------
class FastViTImageConfig(PretrainedConfig):
    """Minimal config so HF knows the image size & embed dim."""
    model_type = "fastvit_image_encoder"

    def __init__(
        self,
        image_size: int = 1024,
        embed_dim:  int = 3072,   # channels after conv_exp
        patch_size: int = 16,
        **kwargs
    ):
        self.image_size = image_size
        self.embed_dim  = embed_dim
        self.patch_size = patch_size
        super().__init__(**kwargs)


# -----------------------  Model  ------------------------
class FastViTImageEncoder(PreTrainedModel):
    """
    Wraps FastViT-HD and returns a dict with an `"embeddings"` entry;
    no text tower, no CLIP logits, only a pooled image embedding.
    `forward_images` additionally exposes the unpooled patch features.
    """
    config_class    = FastViTImageConfig
    main_input_name = "pixel_values"

    def __init__(self, config: FastViTImageConfig):
        super().__init__(config)

        # Build the backbone without a classifier (`num_classes=0`), then attach
        # a GlobalPool2D head that maps the `embed_dim`-channel conv_exp feature
        # map (3072 for FastViT-HD) to a single 768-d embedding.
        self.backbone = fastvithd(num_classes=0)
        self.backbone.head = GlobalPool2D(
            in_dim  = config.embed_dim,
            out_dim = 768,
        )

        # HF hook: runs weight initialization and other final setup steps.
        self.post_init()

    # ------------------------------------------
    def forward(self, pixel_values, return_dict=True, **unused):
        """
        Args:
            pixel_values: (B, 3, H, W) tensor (already resized/normalized).
        Returns:
            Dict with a single key `"embeddings"` of shape (B, embed_dim).
        """
        # FastViT forward returns the pooled tensor directly because
        # `num_classes == embed_dim` and head == GlobalPool2D.
        embeddings = self.backbone(pixel_values)     # (B, embed_dim)

        if not return_dict:
            return (embeddings,)

        return {"embeddings": embeddings}

    # ------------------------------------------
    # Unpooled path: conv_exp feature map flattened to (B, H*W, embed_dim) tokens.
    def feature_select(self, image_forward_outs):
        # Spatial feature map from conv_exp, taken before the pooling head.
        image_features = image_forward_outs["image_embeddings"]   # (B, C, H, W)

        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C)
        B, C, H, W = image_features.shape
        image_features = image_features.reshape(B, C, H * W)
        image_features = image_features.transpose(1, 2)
        return image_features

    def forward_images(self, images):
        """Return patch features: (B, H*W, C), or a list of (1, H*W, C) tensors for list input."""
        if isinstance(images, list):
            image_features = []
            for image in images:
                # A list is processed one image at a time, each as a batch of 1.
                image_forward_out = self.backbone(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    return_image_embeddings=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.backbone(
                images.to(device=self.device, dtype=self.dtype),
                return_image_embeddings=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features
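

# ------------------------------------------
# Illustrative shape sketch (not used by the model). `_sketch_output_shapes` is a
# hypothetical helper that mimics both output paths on random tensors, ASSUMING
# GlobalPool2D in mci.py mean-pools over H and W and then multiplies by a learned
# (in_dim, out_dim) projection; verify this against mci.py before relying on it.
def _sketch_output_shapes(B=2, C=3072, H=4, W=4, out_dim=768):
    feat = torch.randn(B, C, H, W)                # stand-in for the conv_exp map
    proj = torch.randn(C, out_dim) * C ** -0.5    # stand-in for the head's projection
    pooled = feat.mean(dim=(-2, -1)) @ proj       # (B, out_dim): pooled "embeddings" path
    tokens = feat.flatten(2).transpose(1, 2)      # (B, H*W, C): forward_images path
    return pooled, tokens

# Example: _sketch_output_shapes() gives shapes (2, 768) and (2, 16, 3072).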