# models/resnet_encoder.py
# Just for setting up the pipeline; this will be replaced

import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel, CLIPVisionModel

from .base_encoder import BaseVisionEncoder


class ResnetCNNEncoder(nn.Module): 
    def __init__(self, model_name="resnet50", fine_tune_all_encoder_layers=False, freeze_encoder_entirely=False, freeze_until=3, vision_mode="patch"):
        super().__init__()
        self.embed_dim = 2048 # Fixed output dimension for ResNet-50/101
        self.vision_mode = vision_mode

        if model_name == "resnet50":
            resnet = models.resnet50(
                weights=models.ResNet50_Weights.IMAGENET1K_V2
            )
        elif model_name == "resnet101":
            resnet = models.resnet101(
                weights=models.ResNet101_Weights.IMAGENET1K_V2
            )
        else:
            raise ValueError("model_name must be 'resnet50' or 'resnet101'")

        # Keep the convolutional trunk only (drop avgpool and the fc classifier head)
        # ResNet children indices: 0:conv1, 1:bn1, 2:relu, 3:maxpool, 4:layer1, 5:layer2, 6:layer3, 7:layer4
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.model = self.features  # alias kept for callers that expect .model


        # Full Fine-Tuning Mode
        if fine_tune_all_encoder_layers:
            print("[INFO] ResNet-50: Fine-tuning ALL layers (1-4).")
            # PyTorch defaults to requires_grad=True, so no action is needed here.
            return

        # Full Freezing Mode
        elif freeze_encoder_entirely:
            print("[INFO] ResNet-50: Freezing ALL layers (1-4).")
            for param in self.features.parameters():
                param.requires_grad = False
            return
            
        # Dynamic Partial Freezing Mode
        else:
            # freeze_until=3 is the default behavior (freeze layer1-layer3, train layer4)
            # freeze_until=2 means freeze layer1-layer2, train layer3-layer4
            print(f"[INFO] {model_name}: Dynamic partial fine-tuning (freezing layers 1-{freeze_until}).")

            # The stem (conv1, bn1, relu, maxpool) occupies children indices 0-3,
            # so freezing through layerN means freezing child indices 0..N+3.
            # e.g. freeze_until=3 -> freeze indices 0-6 (stem + layer1-layer3).
            max_freeze_idx = freeze_until + 3

            for idx, child in enumerate(self.features):
                if idx <= max_freeze_idx:
                    for param in child.parameters():
                        param.requires_grad = False
                        

    def forward(self, pixel_values):
        x = self.features(pixel_values)        # (B, 2048, H', W')

        if self.vision_mode == "cls":
            pooled = x.mean(dim=(2, 3))        # global average pool -> (B, 2048)
            return {"image_embeds": pooled}

        tokens = x.flatten(2).transpose(1, 2)  # (B, H'*W', 2048)
        return {"image_embeds": tokens}


    def get_output_dim(self):
        return self.embed_dim
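
# Usage sketch (illustrative only, not part of the pipeline): with a 224x224
# input the ResNet trunk downsamples by 32x, so patch mode yields 7*7 = 49
# tokens of dim 2048 and cls mode yields a single pooled 2048-d vector.
#
#   enc = ResnetCNNEncoder("resnet50", vision_mode="patch")
#   enc(torch.randn(2, 3, 224, 224))["image_embeds"].shape  # torch.Size([2, 49, 2048])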



# ViT Encoders

class ViTEncoder(BaseVisionEncoder):

    def __init__(self, model_name="google/vit-base-patch16-224", train_last_n_layers=4, vision_mode="patch"):
        super().__init__(embed_dim=None) 

        self.model = AutoModel.from_pretrained(model_name)
        self.vision_mode = vision_mode

        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        
        # Partial Fine-Tuning Strategy
        # With the default train_last_n_layers=4 on ViT-Base (12 layers), this
        # freezes layers 0-7 and trains layers 8-11, plus the positional
        # embeddings and the final LayerNorm.
        
        # Freeze all parameters initially
        for param in self.model.parameters():
            param.requires_grad = False
            
        # Unfreeze the final N transformer blocks
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        
        try:
            # The layers are typically stored in .encoder.layer
            encoder_layers = self.model.encoder.layer
            num_layers = len(encoder_layers)
            
            # Unfreeze the last NUM_LAYERS_TO_TRAIN blocks
            for i in range(num_layers - NUM_LAYERS_TO_TRAIN, num_layers): 
                layer = encoder_layers[i] 
                for param in layer.parameters():
                    param.requires_grad = True
            
            print(f"ViT Encoder: Unfrozen the final {NUM_LAYERS_TO_TRAIN} blocks ({num_layers - NUM_LAYERS_TO_TRAIN} to {num_layers - 1}).")
            
        except AttributeError:
            print("Warning: Could not find standard ViT layer structure for partial fine-tuning.")
            
        # Unfreeze positional embeddings (often gives a small boost).
        # In ViTModel these are stored as a single nn.Parameter.
        if hasattr(self.model.embeddings, 'position_embeddings'):
            self.model.embeddings.position_embeddings.requires_grad = True
            print("ViT Encoder: Unfroze positional embeddings.")

        # Unfreeze the final LayerNorm (for stabilization).
        # Note: ViTModel keeps it at model.layernorm, not model.encoder.layernorm.
        if hasattr(self.model, 'layernorm'):
            for param in self.model.layernorm.parameters():
                param.requires_grad = True
            print("ViT Encoder: Unfroze final LayerNorm.")


    def forward(self, pixel_values):
        out = self.model(pixel_values=pixel_values)

        # CLS MODE
        if self.vision_mode == "cls":
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                pooled = out.pooler_output             # (B, D)
            elif hasattr(out, 'last_hidden_state'):
                pooled = out.last_hidden_state[:, 0, :]  # CLS token (B, D)
            else:
                raise RuntimeError("Model output format not recognized.")

            return {"image_embeds": pooled}

        # PATCH
        seq = out.last_hidden_state       # (B, S, D)
        return {"image_embeds": seq}


    def get_output_dim(self):
        return self.embed_dim
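
# Usage sketch (illustrative only; assumes the default google/vit-base-patch16-224
# checkpoint): a 224x224 input gives 14*14 = 196 patch tokens plus the CLS token,
# so patch mode returns (B, 197, 768) and cls mode returns (B, 768).
#
#   enc = ViTEncoder(vision_mode="cls")
#   enc(torch.randn(2, 3, 224, 224))["image_embeds"].shape  # torch.Size([2, 768])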


# CLIP Encoders

class CLIPEncoder(BaseVisionEncoder):

    def __init__(self, model_name="openai/clip-vit-base-patch32", train_last_n_layers=4, vision_mode="patch"):
        # The output dimension (hidden size) will be set after loading the model config
        super().__init__(embed_dim=None) 

        self.model = CLIPVisionModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        
        # Partial Fine-Tuning Strategy
        # With the default train_last_n_layers=4 on CLIP ViT-B/32 (12 layers),
        # this freezes layers 0-7 and trains layers 8-11, plus the positional
        # embeddings and the final LayerNorm.
        
        # Freeze all parameters initially
        for param in self.model.parameters():
            param.requires_grad = False
            
        # Unfreeze the final N transformer blocks
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        
        try:
            encoder_layers = self.model.vision_model.encoder.layers
            num_layers = len(encoder_layers)
            
            for i in range(num_layers - NUM_LAYERS_TO_TRAIN, num_layers): 
                layer = encoder_layers[i] 
                for param in layer.parameters():
                    param.requires_grad = True
            
            print(f"CLIP Encoder: Unfrozen the final {NUM_LAYERS_TO_TRAIN} blocks ({num_layers - NUM_LAYERS_TO_TRAIN} to {num_layers - 1}).")
            
        except AttributeError:
            print("Warning: Could not find standard CLIP layer structure for partial fine-tuning. Ensure model structure is correct.")
            

        # Unfreeze positional embeddings (often gives a small boost).
        # Note: in CLIP this is an nn.Embedding module, so unfreeze its
        # parameters; setting requires_grad on the module itself has no effect.
        if hasattr(self.model.vision_model.embeddings, 'position_embedding'):
            for param in self.model.vision_model.embeddings.position_embedding.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfroze positional embeddings.")

        if hasattr(self.model.vision_model, 'post_layernorm'):
            for param in self.model.vision_model.post_layernorm.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfrozen final LayerNorm.")

 

    def forward(self, pixel_values):
        out = self.model(pixel_values=pixel_values)
        seq = out.last_hidden_state  # (B, S, D)

        if self.vision_mode == "cls":
            return {"image_embeds": seq[:, 0, :]}  # (B, D)

        return {"image_embeds": seq}               # (B, S, D)



    def get_output_dim(self):
        return self.embed_dim
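

# Minimal smoke test (a sketch: assumes the pretrained checkpoints above are
# downloadable, and that BaseVisionEncoder needs nothing beyond embed_dim).
# Because this module uses a relative import, run it as a package module:
#   python -m models.resnet_encoder
if __name__ == "__main__":
    dummy = torch.randn(2, 3, 224, 224)
    encoders = [
        ("resnet", ResnetCNNEncoder("resnet50", freeze_encoder_entirely=True)),
        ("vit", ViTEncoder()),
        ("clip", CLIPEncoder()),
    ]
    for name, enc in encoders:
        with torch.no_grad():
            out = enc(dummy)["image_embeds"]
        print(f"{name}: image_embeds {tuple(out.shape)}, embed_dim={enc.get_output_dim()}")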