# coco-demo / models/encoders.py
# (uploaded via HF Hub: "Upload 5 files", commit 94a0812 verified)
# models/encoders.py — ResNet, ViT and CLIP vision encoders
# Just for setting up the pipeline, this will be replaced
import torch
import torch.nn as nn
from torchvision import models
from .base_encoder import BaseVisionEncoder
from transformers import CLIPVisionModel
from transformers import AutoModel
class ResnetCNNEncoder(nn.Module):
def __init__(self, model_name="resnet50", fine_tune_all_encoder_layers=False, freeze_encoder_entirely=False, freeze_until=3, vision_mode="patch"):
super().__init__()
self.embed_dim = 2048 # Fixed output dimension for ResNet-50/101
self.vision_mode = vision_mode
if model_name == "resnet50":
resnet = models.resnet50(
weights=models.ResNet50_Weights.IMAGENET1K_V2
)
elif model_name == "resnet101":
resnet = models.resnet101(
weights=models.ResNet101_Weights.IMAGENET1K_V2
)
else:
raise ValueError("model_name must be 'resnet50' or 'resnet101'")
# Keep encoder layers only (no classifier head)
# ResNet children indices: 0:conv1, 1:bn1, 2:relu, 3:maxpool, 4:layer1, 5:layer2, 6:layer3, 7:layer4
self.features = nn.Sequential(*list(resnet.children())[:-1])
self.model = self.features
# Full Fine-Tuning Mode
if fine_tune_all_encoder_layers:
print("[INFO] ResNet-50: Fine-tuning ALL layers (1-4).")
# PyTorch defaults to requires_grad=True, so no action is needed here.
return
# Full Freezing Mode
elif freeze_encoder_entirely:
print("[INFO] ResNet-50: Freezing ALL layers (1-4).")
for param in self.features.parameters():
param.requires_grad = False
return
# Dynamic Partial Freezing Mode
else:
# freeze_until=3 is the default behavior (freeze L1-L3, train L4)
# freeze_until=2 means freeze L1-L2, train L3-L4
print(f"[INFO] ResNet-50: Dynamic partial fine-tuning (Freezing Layers 1-{freeze_until}).")
# To freeze up to layer N (L1, L2, L3, or L4), we freeze all indices from 0 up to N+3.
max_freeze_idx = freeze_until + 3
# Create a set of indices to freeze (from 0 up to max_freeze_idx, inclusive)
freeze_indices = set(range(max_freeze_idx + 1))
for idx, layer in enumerate(self.features):
# Ensure we only process layers up to the target index
if idx in freeze_indices:
for param in layer.parameters():
param.requires_grad = False
def forward(self, pixel_values):
x = self.features(pixel_values) # (B, 2048, H, W)
if self.vision_mode == "cls":
x_flat = x.flatten(1) # (B, 2048*H*W)
return {"image_embeds": x_flat}
tokens = x.flatten(2).transpose(1, 2) # (B, S, 2048)
return {"image_embeds": tokens}
def get_output_dim(self):
return self.embed_dim
# ViT Encoders
class ViTEncoder(BaseVisionEncoder):
    """HuggingFace ViT vision encoder with partial fine-tuning.

    Freezes the whole backbone, then unfreezes the last
    ``train_last_n_layers`` transformer blocks, the positional embeddings,
    and the final LayerNorm.

    Args:
        model_name: HF hub id of a ViT-style model.
        train_last_n_layers: number of trailing encoder blocks to train.
        vision_mode: "patch" -> (B, S, D) token sequence;
                     "cls"   -> (B, D) pooled / CLS embedding.

    Raises:
        ValueError: if the model config exposes no hidden_size.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", train_last_n_layers=4, vision_mode="patch"):
        super().__init__(embed_dim=None)
        self.model = AutoModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        # Partial fine-tuning: freeze everything, then selectively unfreeze.
        # (Total layers = 12 for ViT-Base; default trains blocks 8-11.)
        for param in self.model.parameters():
            param.requires_grad = False
        num_to_train = train_last_n_layers
        try:
            # HF ViT stores transformer blocks under .encoder.layer
            encoder_layers = self.model.encoder.layer
            num_layers = len(encoder_layers)
            # Unfreeze the last num_to_train blocks
            for i in range(num_layers - num_to_train, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"ViT Encoder: Unfrozen the final {num_to_train} blocks ({num_layers - num_to_train} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard ViT layer structure for partial fine-tuning.")
        # Unfreeze positional embeddings (often a small boost). Handles both
        # an nn.Parameter (HF ViT) and a module wrapping parameters.
        pos = getattr(getattr(self.model, 'embeddings', None), 'position_embeddings', None)
        if pos is not None:
            if isinstance(pos, nn.Parameter):
                pos.requires_grad = True
            else:
                for param in pos.parameters():
                    param.requires_grad = True
            print("ViT Encoder: Unfrozen positional embeddings.")
        # Unfreeze the final LayerNorm (stabilizes the trained blocks).
        # BUGFIX: HF ViTModel keeps it at model.layernorm, not
        # model.encoder.layernorm, so the old hasattr check never fired and
        # the LayerNorm stayed frozen. Old location kept as a fallback.
        final_ln = getattr(self.model, 'layernorm', None)
        if final_ln is None:
            final_ln = getattr(getattr(self.model, 'encoder', None), 'layernorm', None)
        if final_ln is not None:
            for param in final_ln.parameters():
                param.requires_grad = True
            print("ViT Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode images; returns {"image_embeds": ...} shaped per vision_mode."""
        out = self.model(pixel_values=pixel_values)
        # CLS MODE
        if self.vision_mode == "cls":
            # Prefer the model's pooled output; fall back to the CLS token.
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                pooled = out.pooler_output  # (B, D)
            elif hasattr(out, 'last_hidden_state'):
                pooled = out.last_hidden_state[:, 0, :]  # CLS token (B, D)
            else:
                raise RuntimeError("Model output format not recognized.")
            return {"image_embeds": pooled}
        # PATCH mode: full token sequence (CLS position included at index 0).
        return {"image_embeds": out.last_hidden_state}  # (B, S, D)

    def get_output_dim(self):
        """Return the hidden size of the wrapped ViT."""
        return self.embed_dim
# Clip Encoders
class CLIPEncoder(BaseVisionEncoder):
    """HuggingFace CLIP vision tower with partial fine-tuning.

    Freezes the whole tower, then unfreezes the last ``train_last_n_layers``
    transformer blocks, the positional embeddings, and the final
    post-LayerNorm.

    Args:
        model_name: HF hub id of a CLIP checkpoint.
        train_last_n_layers: number of trailing encoder blocks to train.
        vision_mode: "patch" -> (B, S, D) token sequence;
                     "cls"   -> (B, D) CLS-token embedding.

    Raises:
        ValueError: if the model config exposes no hidden_size.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", train_last_n_layers=4, vision_mode="patch"):
        # embed_dim is resolved from the loaded model config below.
        super().__init__(embed_dim=None)
        self.model = CLIPVisionModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")
        # Partial fine-tuning: freeze everything, then selectively unfreeze.
        # (Total layers = 12 for ViT-Base; default trains blocks 8-11.)
        for param in self.model.parameters():
            param.requires_grad = False
        num_to_train = train_last_n_layers
        try:
            # CLIP stores transformer blocks under .vision_model.encoder.layers
            encoder_layers = self.model.vision_model.encoder.layers
            num_layers = len(encoder_layers)
            for i in range(num_layers - num_to_train, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"CLIP Encoder: Unfrozen the final {num_to_train} blocks ({num_layers - num_to_train} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard CLIP layer structure for partial fine-tuning. Ensure model structure is correct.")
        # BUGFIX: CLIP's position_embedding is an nn.Embedding MODULE; the old
        # code set .requires_grad on the module itself, which only creates an
        # inert attribute and left the embeddings frozen despite the success
        # message. Unfreeze its parameters instead.
        pos = getattr(self.model.vision_model.embeddings, 'position_embedding', None)
        if pos is not None:
            if isinstance(pos, nn.Parameter):
                pos.requires_grad = True
            else:
                for param in pos.parameters():
                    param.requires_grad = True
            print("CLIP Encoder: Unfrozen positional embeddings.")
        # Unfreeze the final post-LayerNorm for stabilization.
        if hasattr(self.model.vision_model, 'post_layernorm'):
            for param in self.model.vision_model.post_layernorm.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode images; returns {"image_embeds": ...} shaped per vision_mode."""
        out = self.model(pixel_values=pixel_values)
        seq = out.last_hidden_state  # (B, S, D)
        if self.vision_mode == "cls":
            return {"image_embeds": seq[:, 0, :]}  # CLS token (B, D)
        return {"image_embeds": seq}  # (B, S, D)

    def get_output_dim(self):
        """Return the hidden size of the wrapped CLIP vision tower."""
        return self.embed_dim