# models/resnet_encoder.py
# Placeholder encoders for setting up the pipeline; this will be replaced.
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel, CLIPVisionModel

from .base_encoder import BaseVisionEncoder

class ResnetCNNEncoder(nn.Module):
    def __init__(self, model_name="resnet50", fine_tune_all_encoder_layers=False,
                 freeze_encoder_entirely=False, freeze_until=3, vision_mode="patch"):
        super().__init__()
        self.embed_dim = 2048  # Fixed output dimension for ResNet-50/101
        self.vision_mode = vision_mode
        if model_name == "resnet50":
            resnet = models.resnet50(
                weights=models.ResNet50_Weights.IMAGENET1K_V2
            )
        elif model_name == "resnet101":
            resnet = models.resnet101(
                weights=models.ResNet101_Weights.IMAGENET1K_V2
            )
        else:
            raise ValueError("model_name must be 'resnet50' or 'resnet101'")
        # Keep the convolutional trunk only (drop avgpool and the fc classifier head)
        # so that "patch" mode can return a spatial grid of features; "cls" mode
        # uses a separate adaptive pool so its output matches embed_dim.
        # ResNet children indices: 0:conv1, 1:bn1, 2:relu, 3:maxpool, 4:layer1, 5:layer2, 6:layer3, 7:layer4
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)  # used only in "cls" mode
        self.model = self.features  # alias kept for code that expects .model
        # Full Fine-Tuning Mode
        if fine_tune_all_encoder_layers:
            print(f"[INFO] {model_name}: Fine-tuning ALL layers (1-4).")
            # PyTorch defaults to requires_grad=True, so no action is needed here.
            return
        # Full Freezing Mode
        elif freeze_encoder_entirely:
            print(f"[INFO] {model_name}: Freezing ALL layers (1-4).")
            for param in self.features.parameters():
                param.requires_grad = False
            return
        # Dynamic Partial Freezing Mode
        else:
            # freeze_until=3 is the default behavior (freeze L1-L3, train L4);
            # freeze_until=2 means freeze L1-L2, train L3-L4.
            print(f"[INFO] {model_name}: Dynamic partial fine-tuning (freezing layers 1-{freeze_until}).")
            # To freeze up to residual stage N (L1-L4), freeze all child indices from 0 up to N+3,
            # since the stem modules conv1/bn1/relu/maxpool occupy indices 0-3.
            max_freeze_idx = freeze_until + 3
            # Indices to freeze: from 0 up to max_freeze_idx, inclusive
            freeze_indices = set(range(max_freeze_idx + 1))
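            # Worked example: freeze_until=2 -> max_freeze_idx=5, so indices 0-5
            # (conv1, bn1, relu, maxpool, layer1, layer2) are frozen, while
            # layer3 (idx 6) and layer4 (idx 7) remain trainable.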
            for idx, layer in enumerate(self.features):
                # Only touch layers up to the target index
                if idx in freeze_indices:
                    for param in layer.parameters():
                        param.requires_grad = False
    def forward(self, pixel_values):
        x = self.features(pixel_values)  # (B, 2048, H, W), e.g. 7x7 for 224x224 inputs
        if self.vision_mode == "cls":
            pooled = self.pool(x).flatten(1)  # global average pool -> (B, 2048)
            return {"image_embeds": pooled}
        tokens = x.flatten(2).transpose(1, 2)  # (B, S, 2048) with S = H*W
        return {"image_embeds": tokens}
    def get_output_dim(self):
        return self.embed_dim
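
# Illustrative usage (a sketch; assumes 224x224 RGB inputs):
#   enc = ResnetCNNEncoder()                                          # default "patch" mode
#   enc(torch.randn(2, 3, 224, 224))["image_embeds"].shape           # -> (2, 49, 2048)
#   ResnetCNNEncoder(vision_mode="cls")(torch.randn(2, 3, 224, 224))["image_embeds"].shape  # -> (2, 2048)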

# ViT Encoders
class ViTEncoder(BaseVisionEncoder):
    def __init__(self, model_name="google/vit-base-patch16-224", train_last_n_layers=4, vision_mode="patch"):
        super().__init__(embed_dim=None)
        self.model = AutoModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        # Partial Fine-Tuning Strategy
        # Default (train_last_n_layers=4): freeze the first 8 blocks (0-7), train the
        # last 4 blocks (8-11), plus positional embeddings and the final LayerNorm.
        # (ViT-Base has 12 transformer blocks in total.)
        # Freeze all parameters initially
        for param in self.model.parameters():
            param.requires_grad = False
        # Unfreeze the final N transformer blocks
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        try:
            # The transformer blocks are typically stored in .encoder.layer
            encoder_layers = self.model.encoder.layer
            num_layers = len(encoder_layers)
            # Unfreeze the last NUM_LAYERS_TO_TRAIN blocks
            for i in range(num_layers - NUM_LAYERS_TO_TRAIN, num_layers):
                layer = encoder_layers[i]
                for param in layer.parameters():
                    param.requires_grad = True
            print(f"ViT Encoder: Unfroze the final {NUM_LAYERS_TO_TRAIN} blocks ({num_layers - NUM_LAYERS_TO_TRAIN} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard ViT layer structure for partial fine-tuning.")
        # Unfreeze positional embeddings (often gives a small boost)
        if hasattr(self.model.embeddings, 'position_embeddings'):
            self.model.embeddings.position_embeddings.requires_grad = True
            print("ViT Encoder: Unfroze positional embeddings.")
        # Unfreeze the final LayerNorm (for stabilization). In HF ViTModel this module
        # lives on the model itself (self.model.layernorm), not on the encoder.
        final_ln = getattr(self.model, 'layernorm', None)
        if final_ln is None:
            final_ln = getattr(self.model.encoder, 'layernorm', None)
        if final_ln is not None:
            for param in final_ln.parameters():
                param.requires_grad = True
            print("ViT Encoder: Unfroze final LayerNorm.")
    def forward(self, pixel_values):
        out = self.model(pixel_values=pixel_values)
        # CLS MODE
        if self.vision_mode == "cls":
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                pooled = out.pooler_output  # (B, D)
            elif hasattr(out, 'last_hidden_state'):
                pooled = out.last_hidden_state[:, 0, :]  # CLS token (B, D)
            else:
                raise RuntimeError("Model output format not recognized.")
            return {"image_embeds": pooled}
        # PATCH MODE
        seq = out.last_hidden_state  # (B, S, D)
        return {"image_embeds": seq}
    def get_output_dim(self):
        return self.embed_dim
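
# Illustrative usage (a sketch; assumes 224x224 inputs and ViT-Base/16, i.e. 196 patches + CLS):
#   ViTEncoder(vision_mode="patch")(x)["image_embeds"]  # -> (B, 197, 768)
#   ViTEncoder(vision_mode="cls")(x)["image_embeds"]    # -> (B, 768)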

# CLIP Encoders
class CLIPEncoder(BaseVisionEncoder):
    def __init__(self, model_name="openai/clip-vit-base-patch32", train_last_n_layers=4, vision_mode="patch"):
        # The output dimension (hidden size) will be set after loading the model config
        super().__init__(embed_dim=None)
        self.model = CLIPVisionModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        # Partial Fine-Tuning Strategy
        # Default (train_last_n_layers=4): freeze the first 8 blocks (0-7), train the
        # last 4 blocks (8-11), plus positional embeddings and the final LayerNorm.
        # (The CLIP ViT-Base vision tower has 12 transformer blocks in total.)
        # Freeze all parameters initially
        for param in self.model.parameters():
            param.requires_grad = False
        # Unfreeze the final N transformer blocks
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        try:
            encoder_layers = self.model.vision_model.encoder.layers
            num_layers = len(encoder_layers)
            for i in range(num_layers - NUM_LAYERS_TO_TRAIN, num_layers):
                layer = encoder_layers[i]
                for param in layer.parameters():
                    param.requires_grad = True
            print(f"CLIP Encoder: Unfroze the final {NUM_LAYERS_TO_TRAIN} blocks ({num_layers - NUM_LAYERS_TO_TRAIN} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard CLIP layer structure for partial fine-tuning. Ensure model structure is correct.")
        # In CLIP the positional embedding is an nn.Embedding module, so unfreeze its
        # parameters rather than setting requires_grad on the module object itself.
        if hasattr(self.model.vision_model.embeddings, 'position_embedding'):
            for param in self.model.vision_model.embeddings.position_embedding.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfroze positional embeddings.")
        if hasattr(self.model.vision_model, 'post_layernorm'):
            for param in self.model.vision_model.post_layernorm.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfroze final LayerNorm.")
    def forward(self, pixel_values):
        out = self.model(pixel_values=pixel_values)
        seq = out.last_hidden_state  # (B, S, D)
        if self.vision_mode == "cls":
            # Prefer pooler_output (the CLS token with post_layernorm applied),
            # mirroring ViTEncoder; fall back to the raw CLS token otherwise.
            if getattr(out, 'pooler_output', None) is not None:
                return {"image_embeds": out.pooler_output}  # (B, D)
            return {"image_embeds": seq[:, 0, :]}  # (B, D)
        return {"image_embeds": seq}  # (B, S, D)
    def get_output_dim(self):
        return self.embed_dim
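

# Minimal smoke test (a sketch, not part of the training pipeline): pushes a dummy
# batch through each encoder and prints the resulting embedding shapes. Assumes the
# pretrained weights can be downloaded and that BaseVisionEncoder accepts
# embed_dim=None as used above. Because of the relative import, run it as a module,
# e.g. `python -m models.resnet_encoder`.
if __name__ == "__main__":
    dummy = torch.randn(2, 3, 224, 224)
    encoders = [
        ResnetCNNEncoder(model_name="resnet50", vision_mode="patch"),
        ViTEncoder(vision_mode="cls"),
        CLIPEncoder(vision_mode="patch"),
    ]
    for enc in encoders:
        enc.eval()
        with torch.no_grad():
            embeds = enc(dummy)["image_embeds"]
        print(type(enc).__name__, tuple(embeds.shape), "output dim:", enc.get_output_dim())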