import os
import warnings

import torch
import torch.nn as nn

from src.modeling import ImageEncoder
from src.utils import DotDict


class AttentionOnlyFinetuneEncoder(ImageEncoder):
    """ImageEncoder variant that fine-tunes only the attention projection weights in the ViT.

    Corresponds to the method described in Jin et al. (2025).

    On construction every parameter of ``self.model`` is frozen; afterwards the
    ``in_proj_weight`` and ``out_proj.weight`` tensors of each block in
    ``self.model.visual.transformer.resblocks`` are made trainable again.
    """

    def __init__(self, args, keep_lang=False):
        """Build the base encoder, freeze it, then unfreeze the attention weights.

        Args:
            args: Configuration object forwarded unchanged to ``ImageEncoder``.
            keep_lang: Forwarded unchanged to ``ImageEncoder``.
        """
        super().__init__(args, keep_lang=keep_lang)
        self.args = args

        # Freeze everything first so that only the tensors explicitly
        # re-enabled below receive gradients.
        for param in self.model.parameters():
            param.requires_grad = False

        self._unfreeze_attention_weights(self.model.visual)

    def _unfreeze_attention_weights(self, vit_model):
        """Mark the attention projection weights of every Transformer block as trainable.

        Args:
            vit_model: Object exposing ``transformer.resblocks``; each block must
                provide ``attn.in_proj_weight`` and ``attn.out_proj.weight``.
        """
        for block in vit_model.transformer.resblocks:
            attn = block.attn
            # Only the weight tensors are re-enabled; no bias is touched here.
            attn.in_proj_weight.requires_grad = True
            attn.out_proj.weight.requires_grad = True

    def _verify_trainable_params(self):
        """Print all trainable parameters for debugging and verification."""
        print("="*80)
        print("Trainable parameters in AttentionOnlyFinetuneEncoder:")
        trainable_params_count = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print(f" - {name}")
                trainable_params_count += param.numel()
        print(f"Total trainable parameters: {trainable_params_count / 1e6:.2f}M")
        print("="*80)

    @staticmethod
    def _ortho_penalty(delta_W):
        """Return the Frobenius norm of (Gram matrix of ``delta_W`` minus identity).

        The Gram matrix is formed along the smaller dimension of ``delta_W``
        (``delta_W @ delta_W.T`` when it has fewer rows than columns, otherwise
        ``delta_W.T @ delta_W``).
        """
        rows, cols = delta_W.shape
        if rows < cols:
            gram = delta_W @ delta_W.T
        else:
            gram = delta_W.T @ delta_W
        identity = torch.eye(min(rows, cols), device=delta_W.device)
        return torch.norm(gram - identity, p='fro')

    def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Encode images and optionally compute the orthogonal loss.

        Consistent with the logic implemented for standard_ortho.

        Args:
            images: Input passed to ``self.model.encode_image``.
            calculate_ortho_loss: When false, only the features are returned.
            pretrained_state_dict: Mapping from parameter name (as yielded by
                ``self.model.named_parameters()``) to the pre-trained tensor.
                Required when ``calculate_ortho_loss`` is true.

        Returns:
            ``features`` when ``calculate_ortho_loss`` is false, otherwise the
            tuple ``(features, ortho_loss)``. ``ortho_loss`` is the sum, over
            every trainable 2-D parameter whose name is present in
            ``pretrained_state_dict``, of the penalty computed by
            ``_ortho_penalty`` on ``finetuned - pretrained``. It is always a
            tensor; it is a zero tensor when no parameter matched.

        Raises:
            ValueError: If ``calculate_ortho_loss`` is true and
                ``pretrained_state_dict`` is ``None``.
        """
        features = self.model.encode_image(images)

        if not calculate_ortho_loss:
            return features

        if pretrained_state_dict is None:
            raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")

        penalties = []
        for name, p_finetuned in self.model.named_parameters():
            if not p_finetuned.requires_grad or p_finetuned.dim() != 2:
                continue
            if name not in pretrained_state_dict:
                continue
            p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
            penalties.append(self._ortho_penalty(p_finetuned - p_pretrained))

        if not penalties:
            # Names here are relative to ``self.model``; a state dict keyed
            # differently matches nothing and would silently disable the
            # regularizer, so make that visible.
            warnings.warn(
                "calculate_ortho_loss=True but no trainable 2-D parameter name "
                "was found in pretrained_state_dict; the orthogonal loss is zero.",
                stacklevel=2,
            )
            return features, features.new_zeros(())

        return features, sum(penalties)

    def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Call ``forward`` directly, accepting the extra orthogonal-loss arguments.

        NOTE(review): if ``ImageEncoder`` is a ``torch.nn.Module``, this override
        skips the hook dispatch done by ``nn.Module.__call__`` -- confirm that
        is intended.
        """
        return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)

    def save(self, filename):
        """Save model weights.

        Writes ``self.state_dict()`` to ``filename`` with ``torch.save``,
        creating the parent directory first when the path has one.
        """
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        torch.save(self.state_dict(), filename)

    @classmethod
    def load(cls, filename, args, keep_lang=False):
        """Load model from a state_dict.

        Args:
            filename: Path to a file written by ``save``.
            args: Configuration object used to construct the encoder.
            keep_lang: Forwarded to the constructor; should match the value
                the saved encoder was built with so the state-dict keys agree.

        Returns:
            A new encoder instance with the stored weights loaded.
        """
        encoder = cls(args, keep_lang=keep_lang)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        state_dict = torch.load(filename, map_location='cpu')
        encoder.load_state_dict(state_dict)
        return encoder