|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models |
|
|
from .base_encoder import BaseVisionEncoder |
|
|
from transformers import CLIPVisionModel |
|
|
from transformers import AutoModel |
|
|
|
|
|
|
|
|
class ResnetCNNEncoder(nn.Module):
    """ResNet-based CNN image encoder.

    Wraps a torchvision ResNet (50 or 101) with the final fc classifier
    removed and applies one of three freezing policies to the backbone:
    train everything, freeze everything, or freeze stages 1..`freeze_until`.
    """

    def __init__(self, model_name="resnet50", fine_tune_all_encoder_layers=False,
                 freeze_encoder_entirely=False, freeze_until=3, vision_mode="patch"):
        """
        Args:
            model_name: "resnet50" or "resnet101".
            fine_tune_all_encoder_layers: if True, leave every backbone layer trainable.
            freeze_encoder_entirely: if True (and the flag above is False), freeze all layers.
            freeze_until: highest ResNet stage (1-4) to freeze in partial mode; the stem
                (conv1/bn1/relu/maxpool) is always frozen along with stage 1.
            vision_mode: "cls" -> one pooled vector per image; anything else ->
                token-sequence output (see forward()).

        Raises:
            ValueError: if model_name is not a supported ResNet variant.
        """
        super().__init__()
        # Both resnet50 and resnet101 produce 2048-channel final feature maps.
        self.embed_dim = 2048
        self.vision_mode = vision_mode

        if model_name == "resnet50":
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        elif model_name == "resnet101":
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        else:
            raise ValueError("model_name must be 'resnet50' or 'resnet101'")

        # Drop the final fc head; keep stem + layer1-4 + avgpool.
        # NOTE(review): avgpool is kept inside `features`, so "patch" mode below
        # yields a single 1x1 token rather than an HxW grid — confirm intended.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        # Alias kept for backward compatibility with callers that use `.model`.
        self.model = self.features

        if fine_tune_all_encoder_layers:
            # Fix: messages previously hardcoded "ResNet-50" even for resnet101.
            print(f"[INFO] {model_name}: Fine-tuning ALL layers (1-4).")
            return
        elif freeze_encoder_entirely:
            print(f"[INFO] {model_name}: Freezing ALL layers (1-4).")
            for param in self.features.parameters():
                param.requires_grad = False
            return
        else:
            print(f"[INFO] {model_name}: Dynamic partial fine-tuning (Freezing Layers 1-{freeze_until}).")
            # Children of self.features: 0=conv1, 1=bn1, 2=relu, 3=maxpool,
            # 4=layer1, 5=layer2, 6=layer3, 7=layer4, 8=avgpool.
            # Freezing stages 1..freeze_until therefore means indices 0..freeze_until+3.
            max_freeze_idx = freeze_until + 3
            freeze_indices = set(range(max_freeze_idx + 1))
            for idx, layer in enumerate(self.features):
                if idx in freeze_indices:
                    for param in layer.parameters():
                        param.requires_grad = False

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: float tensor of shape (B, 3, H, W) — assumes ImageNet-style
                preprocessing matching the pretrained weights; TODO confirm at call site.

        Returns:
            dict with "image_embeds":
                (B, 2048) when vision_mode == "cls" (flattened pooled features),
                (B, h*w, 2048) token sequence otherwise (h = w = 1 after avgpool).
        """
        x = self.features(pixel_values)

        if self.vision_mode == "cls":
            x_flat = x.flatten(1)
            return {"image_embeds": x_flat}

        # (B, C, h, w) -> (B, h*w, C) token sequence.
        tokens = x.flatten(2).transpose(1, 2)
        return {"image_embeds": tokens}

    def get_output_dim(self):
        """Return the channel dimension of the produced embeddings (2048)."""
        return self.embed_dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ViTEncoder(BaseVisionEncoder):
    """HuggingFace ViT image encoder with partial fine-tuning.

    Freezes the whole backbone, then unfreezes the last `train_last_n_layers`
    transformer blocks, the positional embeddings, and the final LayerNorm.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", train_last_n_layers=4, vision_mode="patch"):
        """
        Args:
            model_name: HF hub id of the ViT checkpoint to load.
            train_last_n_layers: number of final transformer blocks to keep
                trainable (clamped to [0, model depth]).
            vision_mode: "cls" -> one pooled vector per image; anything else ->
                full hidden-state token sequence.

        Raises:
            ValueError: if the loaded config exposes no hidden_size.
        """
        super().__init__(embed_dim=None)

        self.model = AutoModel.from_pretrained(model_name)
        self.vision_mode = vision_mode

        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")

        # Freeze everything first; selectively unfreeze below.
        for param in self.model.parameters():
            param.requires_grad = False

        try:
            encoder_layers = self.model.encoder.layer
            num_layers = len(encoder_layers)
            # Clamp so a request larger than the model depth cannot produce a
            # negative range start (which would silently index from the end).
            n_train = min(max(train_last_n_layers, 0), num_layers)
            for i in range(num_layers - n_train, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"ViT Encoder: Unfrozen the final {n_train} blocks ({num_layers - n_train} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard ViT layer structure for partial fine-tuning.")

        # HF ViTEmbeddings stores position_embeddings as an nn.Parameter, so
        # flipping requires_grad directly is correct here.
        if hasattr(self.model.embeddings, 'position_embeddings'):
            self.model.embeddings.position_embeddings.requires_grad = True
            print("ViT Encoder: Unfrozen positional embeddings.")

        # Bug fix: HF ViTModel keeps its final LayerNorm at model.layernorm,
        # not model.encoder.layernorm, so the original hasattr check never
        # fired and the final norm silently stayed frozen. Check both spots.
        final_ln = getattr(self.model, 'layernorm', None)
        if final_ln is None:
            final_ln = getattr(self.model.encoder, 'layernorm', None)
        if final_ln is not None:
            for param in final_ln.parameters():
                param.requires_grad = True
            print("ViT Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: float tensor of shape (B, 3, H, W) — preprocessing is
                assumed to match the loaded checkpoint; TODO confirm at call site.

        Returns:
            dict with "image_embeds": (B, hidden) in "cls" mode, otherwise the
            full hidden-state sequence (B, tokens, hidden).
        """
        out = self.model(pixel_values=pixel_values)

        if self.vision_mode == "cls":
            # Prefer the model's pooled output; fall back to the CLS token.
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                pooled = out.pooler_output
            elif hasattr(out, 'last_hidden_state'):
                pooled = out.last_hidden_state[:, 0, :]
            else:
                raise RuntimeError("Model output format not recognized.")
            return {"image_embeds": pooled}

        seq = out.last_hidden_state
        return {"image_embeds": seq}

    def get_output_dim(self):
        """Return the encoder's hidden size."""
        return self.embed_dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEncoder(BaseVisionEncoder):
    """CLIP vision-tower encoder with partial fine-tuning.

    Freezes the whole backbone, then unfreezes the last `train_last_n_layers`
    transformer blocks, the positional embeddings, and the post-LayerNorm.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", train_last_n_layers=4, vision_mode="patch"):
        """
        Args:
            model_name: HF hub id of the CLIP checkpoint to load.
            train_last_n_layers: number of final transformer blocks to keep
                trainable (clamped to [0, model depth]).
            vision_mode: "cls" -> CLS-token vector per image; anything else ->
                full hidden-state token sequence.

        Raises:
            ValueError: if the loaded config exposes no hidden_size.
        """
        super().__init__(embed_dim=None)

        self.model = CLIPVisionModel.from_pretrained(model_name)
        self.vision_mode = vision_mode

        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")

        # Freeze everything first; selectively unfreeze below.
        for param in self.model.parameters():
            param.requires_grad = False

        try:
            encoder_layers = self.model.vision_model.encoder.layers
            num_layers = len(encoder_layers)
            # Clamp so a request larger than the model depth cannot produce a
            # negative range start (which would silently index from the end).
            n_train = min(max(train_last_n_layers, 0), num_layers)
            for i in range(num_layers - n_train, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"CLIP Encoder: Unfrozen the final {n_train} blocks ({num_layers - n_train} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard CLIP layer structure for partial fine-tuning. Ensure model structure is correct.")

        if hasattr(self.model.vision_model.embeddings, 'position_embedding'):
            pos_emb = self.model.vision_model.embeddings.position_embedding
            # Bug fix: in HF CLIP, position_embedding is an nn.Embedding module.
            # Setting .requires_grad on the module object only created a plain
            # attribute and left its weights frozen; unfreeze its parameters.
            if isinstance(pos_emb, nn.Parameter):
                pos_emb.requires_grad = True
            else:
                for param in pos_emb.parameters():
                    param.requires_grad = True
            print("CLIP Encoder: Unfrozen positional embeddings.")

        if hasattr(self.model.vision_model, 'post_layernorm'):
            for param in self.model.vision_model.post_layernorm.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: float tensor of shape (B, 3, H, W) — preprocessing is
                assumed to match the loaded checkpoint; TODO confirm at call site.

        Returns:
            dict with "image_embeds": (B, hidden) CLS token in "cls" mode,
            otherwise the full hidden-state sequence (B, tokens, hidden).
        """
        out = self.model(pixel_values=pixel_values)
        seq = out.last_hidden_state

        if self.vision_mode == "cls":
            return {"image_embeds": seq[:, 0, :]}

        return {"image_embeds": seq}

    def get_output_dim(self):
        """Return the encoder's hidden size."""
        return self.embed_dim