import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base
from src.options import opts


def freeze_model(m):
    """Disable gradient tracking for every parameter of module ``m``."""
    m.requires_grad_(False)


def freeze_all_but_bn(m):
    """Freeze the ``weight`` and ``bias`` of ``m`` unless it is a LayerNorm.

    Despite the ``bn`` in the name, the module type that is left trainable
    is ``torch.nn.LayerNorm``. Only the attributes of ``m`` itself are
    touched; child modules are not visited by this function.
    """
    if isinstance(m, torch.nn.LayerNorm):
        print("LayerNorm")
        return

    for attr_name in ('weight', 'bias'):
        param = getattr(m, attr_name, None)
        if param is not None:
            param.requires_grad_(False)


class Model(pl.LightningModule):
    """ViT-Base backbone with separate learnable prompts for sketches and images.

    Configuration is read from the module-level ``opts`` object:
    ``n_prompts`` and ``prompt_dim`` size the prompts, ``clip_LN_lr`` is the
    backbone learning rate, and ``prompt_lr`` (optional) is the prompt
    learning rate.
    """

    def __init__(self):
        super().__init__()
        self.opts = opts
        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)

        # Prompt Engineering: one randomly initialised prompt table per
        # modality, each of shape (n_prompts, prompt_dim).
        prompt_shape = (self.opts.n_prompts, self.opts.prompt_dim)
        self.sk_prompt = nn.Parameter(torch.randn(*prompt_shape))
        self.img_prompt = nn.Parameter(torch.randn(*prompt_shape))

    def configure_optimizers(self):
        """Build an Adam optimizer over the backbone and both prompts.

        The prompts are ``nn.Parameter`` objects used in ``forward``; if they
        are not handed to the optimizer they receive gradients but are never
        updated, so they get their own parameter group here.

        Returns:
            A ``torch.optim.Adam`` instance with two parameter groups.
        """
        backbone_params = list(self.dino.parameters())
        prompt_params = [self.sk_prompt, self.img_prompt]

        # NOTE(review): uses ``opts.prompt_lr`` when that option is defined;
        # otherwise the prompts share the backbone learning rate. Confirm the
        # intended prompt learning rate against src/options.py.
        prompt_lr = getattr(self.opts, 'prompt_lr', self.opts.clip_LN_lr)

        return torch.optim.Adam([
            {'params': backbone_params, 'lr': self.opts.clip_LN_lr},
            {'params': prompt_params, 'lr': prompt_lr},
        ])

    def forward(self, data, dtype='image'):
        """Run the backbone on ``data`` with the prompt matching ``dtype``.

        Args:
            data: Input batch; its first dimension is taken as the batch size.
            dtype: ``'image'`` selects the image prompt; any other value
                selects the sketch prompt.

        Returns:
            Whatever ``self.dino`` returns for the given input and prompt.
        """
        prompt = self.img_prompt if dtype == 'image' else self.sk_prompt
        # Broadcast the (n_prompts, prompt_dim) table to one copy per sample
        # without allocating new memory.
        batched_prompt = prompt.expand(data.shape[0], -1, -1)
        return self.dino(data, prompt=batched_prompt)