# models/resnet_encoder.py
# Just for setting up the pipeline, this will be replaced
import torch
import torch.nn as nn
from torchvision import models
from .base_encoder import BaseVisionEncoder
from transformers import CLIPVisionModel
from transformers import AutoModel


class ResnetCNNEncoder(nn.Module):
    """ResNet-50/101 backbone that emits image embeddings for a vision pipeline.

    Modes:
        vision_mode="cls":   returns a single pooled vector per image.
        vision_mode="patch": returns a token sequence per image.

    Freezing policy (mutually exclusive, checked in order):
        fine_tune_all_encoder_layers=True -> train everything.
        freeze_encoder_entirely=True      -> train nothing.
        otherwise                         -> freeze children up to ResNet
                                             layer `freeze_until` (1-4).
    """

    def __init__(self, model_name="resnet50", fine_tune_all_encoder_layers=False,
                 freeze_encoder_entirely=False, freeze_until=3, vision_mode="patch"):
        super().__init__()
        self.embed_dim = 2048  # Fixed output channel count for ResNet-50/101
        self.vision_mode = vision_mode

        if model_name == "resnet50":
            resnet = models.resnet50(
                weights=models.ResNet50_Weights.IMAGENET1K_V2
            )
        elif model_name == "resnet101":
            resnet = models.resnet101(
                weights=models.ResNet101_Weights.IMAGENET1K_V2
            )
        else:
            raise ValueError("model_name must be 'resnet50' or 'resnet101'")

        # Keep encoder layers only (no classifier head).
        # ResNet children indices: 0:conv1, 1:bn1, 2:relu, 3:maxpool,
        # 4:layer1, 5:layer2, 6:layer3, 7:layer4, 8:avgpool
        # NOTE(review): [:-1] drops only the fc head, so `avgpool` is kept and
        # the spatial map collapses to 1x1 — "patch" mode therefore yields a
        # single token (S=1). If spatial patches are wanted, use [:-2] — confirm
        # against downstream shape expectations before changing.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.model = self.features

        if fine_tune_all_encoder_layers:
            # Full fine-tuning: PyTorch defaults to requires_grad=True,
            # so nothing to do.
            print(f"[INFO] {model_name}: Fine-tuning ALL layers (1-4).")
            return

        if freeze_encoder_entirely:
            print(f"[INFO] {model_name}: Freezing ALL layers (1-4).")
            for param in self.features.parameters():
                param.requires_grad = False
            return

        # Dynamic partial freezing:
        #   freeze_until=3 (default) -> freeze L1-L3, train L4
        #   freeze_until=2           -> freeze L1-L2, train L3-L4
        print(f"[INFO] {model_name}: Dynamic partial fine-tuning "
              f"(Freezing Layers 1-{freeze_until}).")
        # Layer N (1-4) lives at child index N+3 (stem occupies indices 0-3),
        # so freezing "up to layer N" means freezing child indices 0..N+3.
        max_freeze_idx = freeze_until + 3
        for idx, layer in enumerate(self.features):
            if idx <= max_freeze_idx:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, pixel_values):
        """Encode a batch of images.

        Returns a dict with key "image_embeds":
            cls mode:   (B, 2048)       — flattened pooled features.
            patch mode: (B, S, 2048)    — token sequence (S=1 here; see
                                          NOTE in __init__ about avgpool).
        """
        x = self.features(pixel_values)  # (B, 2048, H, W); H=W=1 after avgpool
        if self.vision_mode == "cls":
            return {"image_embeds": x.flatten(1)}
        # Patch mode: flatten spatial dims into a sequence axis.
        return {"image_embeds": x.flatten(2).transpose(1, 2)}

    def get_output_dim(self):
        """Return the encoder's embedding dimensionality (2048)."""
        return self.embed_dim


# ViT Encoders
class ViTEncoder(BaseVisionEncoder):
    """HuggingFace ViT encoder with partial fine-tuning.

    Freezes everything, then unfreezes the last `train_last_n_layers`
    transformer blocks, the positional embeddings, and the final LayerNorm.
    """

    def __init__(self, model_name="google/vit-base-patch16-224",
                 train_last_n_layers=4, vision_mode="patch"):
        super().__init__(embed_dim=None)
        self.model = AutoModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")

        # Partial fine-tuning strategy: freeze first layers, train the last N
        # blocks plus embeddings and final LayerNorm.
        # (Total layers = 12 for ViT-Base.)

        # Freeze all parameters initially.
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the final N transformer blocks.
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        try:
            # The layers are typically stored in .encoder.layer
            encoder_layers = self.model.encoder.layer
            num_layers = len(encoder_layers)
            # Clamp so N > num_layers cannot produce wrap-around negative
            # indices.
            start = max(0, num_layers - NUM_LAYERS_TO_TRAIN)
            for i in range(start, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"ViT Encoder: Unfrozen the final {NUM_LAYERS_TO_TRAIN} "
                  f"blocks ({start} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard ViT layer structure for partial fine-tuning.")

        # Unfreeze positional embeddings (often gives a small boost).
        # For ViT this attribute is an nn.Parameter, so direct assignment
        # is effective.
        if hasattr(self.model.embeddings, 'position_embeddings'):
            self.model.embeddings.position_embeddings.requires_grad = True
            print("ViT Encoder: Unfrozen positional embeddings.")

        # Unfreeze the final LayerNorm (for stabilization).
        if hasattr(self.model.encoder, 'layernorm'):
            for param in self.model.encoder.layernorm.parameters():
                param.requires_grad = True
            print("ViT Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode images; returns {"image_embeds": tensor}.

        cls mode:   (B, D) pooled/CLS representation.
        patch mode: (B, S, D) full token sequence.
        """
        out = self.model(pixel_values=pixel_values)

        if self.vision_mode == "cls":
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                pooled = out.pooler_output               # (B, D)
            elif hasattr(out, 'last_hidden_state'):
                pooled = out.last_hidden_state[:, 0, :]  # CLS token (B, D)
            else:
                raise RuntimeError("Model output format not recognized.")
            return {"image_embeds": pooled}

        # Patch mode.
        return {"image_embeds": out.last_hidden_state}   # (B, S, D)

    def get_output_dim(self):
        """Return the hidden size reported by the model config."""
        return self.embed_dim


# Clip Encoders
class CLIPEncoder(BaseVisionEncoder):
    """CLIP vision tower with partial fine-tuning.

    Mirrors ViTEncoder's strategy: freeze all, then unfreeze the last
    `train_last_n_layers` blocks, positional embeddings, and the
    post-LayerNorm.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32",
                 train_last_n_layers=4, vision_mode="patch"):
        # The output dimension (hidden size) is set after loading the config.
        super().__init__(embed_dim=None)
        self.model = CLIPVisionModel.from_pretrained(model_name)
        self.vision_mode = vision_mode
        self.embed_dim = self.model.config.hidden_size
        if self.embed_dim is None:
            raise ValueError("Could not determine embed_dim from model config.")

        # Partial fine-tuning strategy: freeze first layers, train the last N
        # blocks plus embeddings and final LayerNorm.
        # (Total layers = 12 for ViT-Base.)

        # Freeze all parameters initially.
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the final N transformer blocks.
        NUM_LAYERS_TO_TRAIN = train_last_n_layers
        try:
            encoder_layers = self.model.vision_model.encoder.layers
            num_layers = len(encoder_layers)
            # Clamp so N > num_layers cannot produce wrap-around negative
            # indices.
            start = max(0, num_layers - NUM_LAYERS_TO_TRAIN)
            for i in range(start, num_layers):
                for param in encoder_layers[i].parameters():
                    param.requires_grad = True
            print(f"CLIP Encoder: Unfrozen the final {NUM_LAYERS_TO_TRAIN} "
                  f"blocks ({start} to {num_layers - 1}).")
        except AttributeError:
            print("Warning: Could not find standard CLIP layer structure for partial fine-tuning. Ensure model structure is correct.")

        # FIX: CLIP's position_embedding is an nn.Embedding *module*, not a
        # Parameter — assigning .requires_grad on the module is a no-op for
        # its weights. Unfreeze via .parameters() instead.
        if hasattr(self.model.vision_model.embeddings, 'position_embedding'):
            for param in self.model.vision_model.embeddings.position_embedding.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfrozen positional embeddings.")

        # Unfreeze the final LayerNorm (for stabilization).
        if hasattr(self.model.vision_model, 'post_layernorm'):
            for param in self.model.vision_model.post_layernorm.parameters():
                param.requires_grad = True
            print("CLIP Encoder: Unfrozen final LayerNorm.")

    def forward(self, pixel_values):
        """Encode images; returns {"image_embeds": tensor}.

        cls mode:   (B, D) CLS-token representation.
        patch mode: (B, S, D) full token sequence.
        """
        out = self.model(pixel_values=pixel_values)
        seq = out.last_hidden_state                 # (B, S, D)
        if self.vision_mode == "cls":
            return {"image_embeds": seq[:, 0, :]}   # (B, D)
        return {"image_embeds": seq}                # (B, S, D)

    def get_output_dim(self):
        """Return the hidden size reported by the model config."""
        return self.embed_dim