import math

import torch
import torch.nn as nn
from transformers import CLIPModel
from transformers.utils import logging

logging.set_verbosity_error()
|
|
|
|
class CLIPDenseEncoder(nn.Module):
    """Dense (per-patch) feature extractor built on a Hugging Face CLIP vision tower.

    ``forward`` takes the vision encoder's ``last_hidden_state`` and drops
    token 0. It reshapes the remaining patch tokens into a square 2-D feature
    map of shape ``(batch, hidden_size, side, side)``.

    Args:
        model_name: Model id or local path given to
            ``CLIPModel.from_pretrained`` (loaded with ``use_safetensors=True``).
    """

    def __init__(self, model_name="openai/clip-vit-base-patch16"):
        super().__init__()
        # The full CLIPModel is loaded, but forward() only uses its vision
        # tower. freeze_all() iterates over self.clip, so it covers every
        # parameter of the loaded model.
        self.clip = CLIPModel.from_pretrained(str(model_name), use_safetensors=True)

        self.vision = self.clip.vision_model
        vcfg = self.clip.config.vision_config
        # Vision-tower geometry, copied from the config as plain ints.
        self.hidden_size = int(vcfg.hidden_size)
        self.patch_size = int(vcfg.patch_size)
        self.image_size = int(vcfg.image_size)

        # Per-channel mean/std used by forward() to normalize inputs. They are
        # the standard OpenAI CLIP preprocessing constants. Registering them as
        # buffers makes them follow .to(device) but keeps them non-trainable.
        self.register_buffer("_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
        self.register_buffer("_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))

    def forward(self, x01):
        """Encode an image batch into a dense patch-feature map.

        Args:
            x01: Image tensor that must broadcast against the ``(1, 3, 1, 1)``
                normalization buffers. Presumably ``(batch, 3, H, W)`` in the
                [0, 1] range, with H/W matching ``self.image_size``
                (TODO confirm against callers).

        Returns:
            Tensor of shape ``(batch, hidden_size, side, side)``, where
            ``side * side`` equals the number of patch tokens.

        Raises:
            RuntimeError: If the number of patch tokens is not a perfect square.
        """
        # Cast the buffers to the input's dtype so that, e.g., half-precision
        # inputs are normalized in half precision rather than upcast.
        x = (x01 - self._mean.to(dtype=x01.dtype)) / self._std.to(dtype=x01.dtype)
        out = self.vision(pixel_values=x)
        # Drop token 0 (the class token in CLIP's layout); keep the patch tokens.
        tokens = out.last_hidden_state[:, 1:, :]
        b, n, c = tokens.shape
        # Exact integer square root; does not depend on float pow() rounding.
        side = math.isqrt(n)
        if side * side != n:
            raise RuntimeError(f"CLIP patch tokens N={n} not a square; cannot reshape to 2D grid.")

        # (b, n, c) -> (b, c, n) -> (b, c, side, side). This assumes the patch
        # tokens are in row-major grid order (TODO confirm for non-default models).
        fmap = tokens.transpose(1, 2).contiguous().view(b, c, side, side)
        return fmap

    def freeze_all(self):
        """Set ``requires_grad = False`` on every parameter of the wrapped CLIP model."""
        for p in self.clip.parameters():
            p.requires_grad = False

    def unfreeze_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
        """Freeze all of CLIP, then re-enable gradients for the tail of the vision tower.

        Args:
            n_blocks: Number of trailing ``self.vision.encoder.layers`` to
                unfreeze. A negative value, or any value
                ``>= len(encoder.layers)``, unfreezes every parameter of
                ``self.vision``, not only the encoder layers. ``0`` unfreezes
                no encoder layers. Values in between are truncated with ``int()``.
            also_unfreeze_ln: If True, also unfreeze ``pre_layrnorm`` and
                ``post_layernorm`` when the vision module has them.

        NOTE(review): forward() consumes ``last_hidden_state``. In Hugging Face's
        CLIP vision transformer, ``post_layernorm`` appears to be applied only
        to the pooled output, so its parameters may get no gradient from
        forward()'s result. Confirm this; it matters for
        DistributedDataParallel with ``find_unused_parameters=False``.
        """
        self.freeze_all()
        layers = self.vision.encoder.layers
        n_total = len(layers)

        if n_blocks < 0 or n_blocks >= n_total:
            # Out-of-range request: fine-tune every parameter of the vision module.
            for p in self.vision.parameters():
                p.requires_grad = True
        else:
            # Here 0 <= n_blocks < n_total, so int() alone yields a valid count.
            k = int(n_blocks)
            for layer in layers[n_total - k:]:
                for p in layer.parameters():
                    p.requires_grad = True

        if also_unfreeze_ln:
            # "pre_layrnorm" is spelled as in the Hugging Face CLIP implementation.
            # hasattr() keeps this tolerant of vision modules lacking either norm.
            if hasattr(self.vision, "pre_layrnorm"):
                for p in self.vision.pre_layrnorm.parameters():
                    p.requires_grad = True
            if hasattr(self.vision, "post_layernorm"):
                for p in self.vision.post_layernorm.parameters():
                    p.requires_grad = True
|
|