| import sys |
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
| from omegaconf import OmegaConf |
|
|
| sys.path.append("./dinov2") |
| import hubconf |
|
|
# Default location of the project configuration file.
CONFIG_PATH = './configs/pics.yaml'


def load_config(path=CONFIG_PATH):
    """Load and return the OmegaConf configuration at *path*.

    Raises:
        FileNotFoundError: if no file exists at *path*.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Config file not found at {path}")
    return OmegaConf.load(path)
|
|
# Load the project configuration once at import time.
config = load_config()

# Filesystem path to the pretrained DINOv2 checkpoint, read from the
# cond-stage section of the model config.
DINOv2_weight_path = config.model.params.cond_stage_config.weight
|
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm that casts its output back to the input's original dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normed = F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.to(input_dtype)
|
|
class AbstractEncoder(nn.Module):
    """Interface base class for conditioning encoders.

    Subclasses must override :meth:`encode`.
    """

    def encode(self, *args, **kwargs):
        """Encode the given input into a conditioning tensor."""
        raise NotImplementedError
|
|
|
|
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Bound onto a frozen module so later ``.train()`` / ``.eval()`` calls
    can no longer flip its mode; always returns the module unchanged.
    """
    return self
|
|
|
|
class FrozenDinoV2Encoder(AbstractEncoder):
    """Frozen DINOv2 ViT-g/14 image encoder.

    Produces a conditioning tensor consisting of the class token prepended
    to the patch tokens, projected from the backbone's 1536-dim embedding
    space to 1024 dims by a trainable linear layer. The backbone itself is
    frozen (eval mode, no gradients); only the projector is trainable.
    """

    def __init__(self, device="cuda", freeze=True):
        super().__init__()
        dinov2 = hubconf.dinov2_vitg14()
        # map_location avoids CUDA-deserialization failures on CPU-only
        # hosts; the model is moved to the target device just below.
        state_dict = torch.load(DINOv2_weight_path, map_location="cpu")
        # NOTE(review): strict=False silently ignores missing/unexpected
        # keys — confirm the checkpoint actually matches ViT-g/14.
        dinov2.load_state_dict(state_dict, strict=False)
        self.model = dinov2.to(device)
        self.device = device
        if freeze:
            self.freeze()
        # ImageNet normalization stats, shaped (1, 3, 1, 1) for broadcasting
        # over NCHW batches. Registered as non-persistent buffers so that
        # .to()/.cuda() on the module moves them with it, while keeping them
        # out of the state_dict (backward compatible with old checkpoints).
        self.register_buffer(
            "image_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        # Projects DINOv2 ViT-g token embeddings (1536) to the 1024-dim
        # conditioning space; this layer stays trainable.
        self.projector = nn.Linear(1536, 1024)

    def freeze(self):
        """Put the backbone in eval mode and stop gradients through it."""
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image):
        # Accept a list of image batches and merge them along the batch dim.
        if isinstance(image, list):
            image = torch.cat(image, 0)
        # Normalize with ImageNet statistics on the target device.
        image = (image.to(self.device) - self.image_mean.to(self.device)) / self.image_std.to(self.device)
        features = self.model.forward_features(image)
        tokens = features["x_norm_patchtokens"]
        # Class token gains a singleton token axis so it can be prepended
        # to the patch tokens.
        image_features = features["x_norm_clstoken"].unsqueeze(1)
        hint = torch.cat([image_features, tokens], 1)
        hint = self.projector(hint)
        return hint

    def encode(self, image):
        """Alias for ``forward``; satisfies the AbstractEncoder interface."""
        return self(image)
|
|
|
|
|
|