| """ |
| Image Encoder Components for Persistence Images |
| |
| This module contains various image encoders optimized for persistence images: |
| - SimpleCNN: Lightweight CNN |
| - SmallViT: Compact Vision Transformer |
| - PersistenceViT: Topologically-aware Vision Transformer with persistence-weighted positional encoding |
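| - ResNet18/ResNet50/ResNet101: Standard ImageNet-pretrained ResNet encoders |
| - ConvNeXt-Small: ImageNet-pretrained ConvNeXt encoder |
| - DINOv2: Self-supervised visual encoders |
| - Hybrid ResNet + PersistenceViT encoders (selected through ImageEncoder) |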
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.models as models |
| from torchvision.models import ResNet18_Weights, ResNet50_Weights, ResNet101_Weights |
| from typing import Tuple, Optional |
|
|
|
|
| |
| |
| |
|
|
class SimpleCNN(nn.Module):
    """Lightweight three-stage CNN encoder for persistence images.

    Takes 3-channel images; two 2x2 max-pool stages shrink the feature maps and a
    global average pool collapses them to a 128-d vector, which is linearly
    projected to ``output_dim``.

    Args:
        output_dim: Size of the returned feature vector.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        stages = [
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Dropout2d(0.3),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        # Kept as one Sequential so state_dict keys stay ``conv.<index>.*``.
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        """Encode a (B, 3, H, W) batch into (B, output_dim) features."""
        pooled = self.conv(x)  # (B, 128, 1, 1) after the adaptive pool
        return self.fc(torch.flatten(pooled, 1))
|
|
|
|
| |
| |
| |
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention for the Vision Transformer.

    Args:
        dim: Token embedding size, split evenly across the heads.
        heads: Number of attention heads.
        dropout: Dropout on the attention probabilities and on the output projection.
    """

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # fused Q/K/V projection
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the tokens of ``x`` (B, N, C) and return a tensor of the same shape."""
        batch, tokens, channels = x.shape
        per_head = channels // self.heads
        # (B, N, 3C) -> (3, B, heads, N, per_head), then split into Q, K, V.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.heads, per_head)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(probs, value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Applies LayerNorm -> self-attention and LayerNorm -> MLP, each wrapped in a
    residual connection.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout used in the attention and in the MLP.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(dim, heads, dropout)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run one attention + MLP step on tokens ``x``; the output has the same shape."""
        attended = x + self.attention(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
|
|
|
|
class SmallViT(nn.Module):
    """Lightweight Vision Transformer for persistence images.

    A strided convolution cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` patches. A learnable CLS token is prepended and
    learnable positional embeddings are added. ``depth`` transformer blocks are
    applied, and the CLS token is projected to ``output_dim``.

    Args:
        image_size: Side length of the (square) input the positional table is sized for.
        patch_size: Patch side length; must divide ``image_size``.
        output_dim: Size of the returned feature vector.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        channels: Number of input image channels.
        dropout: Dropout inside the blocks and before the output projection.
        emb_dropout: Dropout applied to the token embeddings.

    Raises:
        ValueError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=14,
        output_dim=128,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256,
        channels=3,
        dropout=0.2,
        emb_dropout=0.2
    ):
        super().__init__()
        # Validate with a real exception: ``assert`` statements vanish under ``python -O``.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )
        num_patches = (image_size // patch_size) ** 2

        # Conv with kernel == stride == patch_size embeds each patch; Flatten(2) -> (B, dim, num_patches).
        self.patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2),
        )
        # One extra positional slot for the CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = nn.ModuleList([
            TransformerBlock(dim, heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Dropout(dropout),
            nn.Linear(dim, output_dim)
        )

    def forward(self, img):
        """Encode ``img`` (B, channels, H, W) into (B, output_dim) CLS features."""
        x = self.patch_embedding(img).transpose(1, 2)  # (B, n_patches, dim)
        b, n, _ = x.shape
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        # Slice so inputs smaller than image_size (fewer patches) still line up.
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        for transformer_block in self.transformer:
            x = transformer_block(x)

        return self.mlp_head(x[:, 0])
|
|
|
|
| |
| |
| |
|
|
class TopologicalPatchEmbedding(nn.Module):
    """Patch embedding whose tokens are rescaled by a learned per-patch importance score.

    Args:
        in_channels: Number of input image channels.
        dim: Token embedding size.
        patch_size: Patch side length (used as conv kernel and stride).
        image_size: Side length of the square input; only used for ``num_patches``.
    """

    def __init__(self, in_channels, dim, patch_size, image_size):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

        # Small patch-wise head producing one sigmoid score per patch.
        self.importance_estimator = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=patch_size, stride=patch_size),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Embed a (B, C, H, W) batch.

        Returns:
            Tuple of the importance-scaled tokens (B, n_patches, dim) and the raw
            importance scores (B, n_patches).
        """
        # Unpacking doubles as a guard that the input is 4-D.
        _batch, _channels, _height, _width = x.shape

        def to_tokens(feature_map):
            # (B, D, h, w) -> (B, h*w, D)
            return feature_map.flatten(2).transpose(1, 2)

        tokens = to_tokens(self.proj(x))
        scores = to_tokens(self.importance_estimator(x))
        return tokens * (1 + scores), scores.squeeze(-1)
|
|
|
|
class BirthDeathAttention(nn.Module):
    """Multi-head self-attention whose logits are biased toward important key tokens.

    When per-token importance weights are supplied, every query's attention logits
    receive ``(0.1 + persistence_bias[h]) * importance[key]``. Tokens with higher
    importance therefore attract more attention, with a learnable per-head strength.

    Args:
        dim: Token embedding size, split evenly across heads.
        heads: Number of attention heads.
        dropout: Dropout on the attention probabilities and on the output projection.
    """

    # Base strength of the importance bias (the learnable per-head term is added to it).
    _BASE_IMPORTANCE_SCALE = 0.1

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        # Per-head adjustment of the importance-bias strength. It is zero-initialised,
        # so the effective strength starts at the base value of 0.1.
        self.persistence_bias = nn.Parameter(torch.zeros(1, heads, 1, 1))

    def forward(self, x, importance_weights=None):
        """
        Args:
            x: (B, N, C) tokens.
            importance_weights: Optional (B, N) tensor with one weight per token of ``x``.

        Returns:
            (B, N, C) attended tokens.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N_query, N_key)

        if importance_weights is not None:
            # The bias must vary along the KEY axis (last dim). Softmax over keys is
            # invariant to any per-query-row constant. Broadcasting importance along
            # the query axis, or adding a per-head constant, would have no effect.
            key_importance = importance_weights[:, None, None, :]  # (B, 1, 1, N)
            strength = self._BASE_IMPORTANCE_SCALE + self.persistence_bias  # (1, heads, 1, 1)
            attn = attn + strength * key_importance

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
class BirthDeathTransformerBlock(nn.Module):
    """Pre-norm transformer block whose attention accepts optional per-token importance.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout used in the attention and in the MLP.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = BirthDeathAttention(dim, heads, dropout)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, importance_weights=None):
        """Residual attention step (importance forwarded to the attention) followed by a residual MLP step."""
        hidden = x + self.attention(self.norm1(x), importance_weights)
        return hidden + self.mlp(self.norm2(hidden))
|
|
|
|
class TopologicalFeatureAggregation(nn.Module):
    """Aggregate the CLS token with the homology-specific tokens into one vector.

    The CLS token cross-attends to the homology tokens. The layer-normalised result
    is added to a learned softmax-weighted sum over all input tokens, CLS included.

    Args:
        dim: Token embedding size (must be divisible by ``num_heads``).
        num_homology_dims: Number of homology tokens that follow the CLS token.
        num_heads: Heads of the cross-attention (default 4, the previous fixed value).
        dropout: Dropout inside the cross-attention (default 0.1, the previous fixed value).
    """

    def __init__(self, dim, num_homology_dims, num_heads=4, dropout=0.1):
        super().__init__()
        self.num_homology = num_homology_dims
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(dim)
        # One mixing weight per token (CLS + homology tokens), normalised with softmax.
        self.aggregation_weights = nn.Parameter(torch.ones(1, num_homology_dims + 1, 1))

    def forward(self, tokens):
        """
        Args:
            tokens: (B, 1 + num_homology_dims, dim), with the CLS token first and
                then the homology tokens.

        Returns:
            (B, dim) aggregated representation.
        """
        cls_token = tokens[:, 0:1]
        homology_tokens = tokens[:, 1:]

        aggregated, _ = self.cross_attn(cls_token, homology_tokens, homology_tokens)
        aggregated = self.norm(aggregated)  # (B, 1, dim)

        weights = F.softmax(self.aggregation_weights, dim=1)
        pooled = (tokens * weights).sum(dim=1, keepdim=True)  # (B, 1, dim)

        return (aggregated + pooled).squeeze(1)
|
|
|
|
class PersistenceWeightedPositionalEncoding(nn.Module):
    """
    NOVEL CONTRIBUTION: Positional encoding based on topological persistence
    rather than just spatial grid location.

    Key insight: Patches corresponding to high-persistence features should have
    similar positional encodings regardless of their spatial location, because
    they represent similar topological significance.

    The encoding concatenates three parts:
    - a learned spatial embedding (``dim // 2`` channels);
    - an MLP encoding of each patch's mean birth (``dim // 4`` channels);
    - an MLP encoding of each patch's mean persistence (``dim // 4`` channels).
    The parts are fused with Linear -> LayerNorm -> Tanh.

    NOTE: the persistence path only yields exactly ``dim`` channels when ``dim``
    is a multiple of 4; otherwise ``fusion`` receives a mismatched width.
    NOTE(review): without coordinates, the spatial embedding is zero-padded and
    returned WITHOUT passing through ``fusion``. The two paths therefore produce
    differently scaled encodings; confirm this is intended.
    """

    def __init__(self, dim, image_size=224, patch_size=14):
        super().__init__()
        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Learned grid-position embedding (first half of the channels).
        self.spatial_pos = nn.Parameter(torch.randn(1, self.num_patches, dim // 2))

        # Scalar -> dim//4 encoders for the per-patch mean birth and persistence.
        self.birth_encoder = nn.Sequential(
            nn.Linear(1, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, dim // 4)
        )

        self.persistence_encoder = nn.Sequential(
            nn.Linear(1, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, dim // 4)
        )

        # Mixes the spatial and persistence parts into the final encoding.
        self.fusion = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.Tanh()
        )

    def _patch_statistics(
        self,
        persistence_coords: torch.Tensor,
        pixel_coords: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Vectorised per-patch mean birth / persistence for a whole batch.

        Args:
            persistence_coords: (B, N, >=3) tensor; column 0 is birth, column 2 persistence.
            pixel_coords: (B, N, 2) tensor of (x, y) pixel locations.

        Returns:
            Two (B, num_patches, 1) tensors holding, for each patch in row-major
            (y, x) order, the mean birth and mean persistence of the features that
            fall in it. Empty patches hold 0.
        """
        device = persistence_coords.device
        pixel_coords = pixel_coords.to(device)
        grid = self.image_size // self.patch_size
        num_patches = grid * grid
        batch = persistence_coords.shape[0]

        # Rows whose pixel AND persistence coordinates are all zero are padding.
        is_padding = (pixel_coords == 0).all(dim=-1) & (persistence_coords == 0).all(dim=-1)
        valid = ~is_padding

        # Clamp into the grid, then truncate (same as int(clamp(...)) since values are >= 0).
        patch_x = torch.clamp(pixel_coords[..., 0] / self.patch_size, 0, grid - 1).long()
        patch_y = torch.clamp(pixel_coords[..., 1] / self.patch_size, 0, grid - 1).long()
        sample = torch.arange(batch, device=device).unsqueeze(1).expand_as(patch_x)
        # Offset by sample so a single index_add_ accumulates the whole batch.
        flat_index = (sample * num_patches + patch_y * grid + patch_x)[valid]

        birth_sum = torch.zeros(batch * num_patches, device=device)
        pers_sum = torch.zeros_like(birth_sum)
        count = torch.zeros_like(birth_sum)
        birth = persistence_coords[..., 0][valid].to(birth_sum.dtype)
        pers = persistence_coords[..., 2][valid].to(birth_sum.dtype)
        birth_sum.index_add_(0, flat_index, birth)
        pers_sum.index_add_(0, flat_index, pers)
        count.index_add_(0, flat_index, torch.ones_like(birth))

        # Empty patches have zero sums, so dividing by max(count, 1) leaves them at 0.
        denom = count.clamp(min=1)
        shape = (batch, num_patches, 1)
        return (birth_sum / denom).view(shape), (pers_sum / denom).view(shape)

    def encode_patch_persistence(
        self,
        persistence_coords: torch.Tensor,
        pixel_coords: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign persistence values to patches based on which features fall in each patch.

        Args:
            persistence_coords: (N_features, >=3) tensor for ONE sample; column 0
                is birth and column 2 is persistence. All-zero rows whose pixel
                coordinates are also all zero are treated as padding.
            pixel_coords: (N_features, 2) tensor of (x, y) pixel locations.

        Returns:
            patch_birth: (num_patches, 1) - average birth for each patch
            patch_persistence: (num_patches, 1) - average persistence for each patch
        """
        birth, pers = self._patch_statistics(
            persistence_coords.unsqueeze(0), pixel_coords.unsqueeze(0)
        )
        return birth[0], pers[0]

    def forward(
        self,
        batch_size: int,
        persistence_coords: Optional[torch.Tensor] = None,
        pixel_coords: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Generate positional encoding combining spatial and persistence information.

        Args:
            batch_size: Batch size
            persistence_coords: (B, N_features, 3) - optional persistence coordinates
                (a per-sample sequence of (N_i, 3) tensors is also accepted)
            pixel_coords: (B, N_features, 2) - optional pixel coordinates
                (a per-sample sequence of (N_i, 2) tensors is also accepted)

        Returns:
            pos_encoding: (B, num_patches, dim)
        """
        spatial_enc = self.spatial_pos.expand(batch_size, -1, -1)

        if persistence_coords is None or pixel_coords is None:
            # Spatial-only fallback: zero-pad the persistence half (fusion is not applied).
            zero_enc = torch.zeros(batch_size, self.num_patches, self.dim // 2,
                                   device=spatial_enc.device)
            return torch.cat([spatial_enc, zero_enc], dim=-1)

        if isinstance(persistence_coords, torch.Tensor) and isinstance(pixel_coords, torch.Tensor):
            patch_birth, patch_pers = self._patch_statistics(
                persistence_coords[:batch_size], pixel_coords[:batch_size]
            )
        else:
            # Ragged per-sample inputs (e.g. lists of tensors): handle one sample at a time.
            per_sample = [
                self.encode_patch_persistence(persistence_coords[b], pixel_coords[b])
                for b in range(batch_size)
            ]
            patch_birth = torch.stack([birth for birth, _ in per_sample], dim=0)
            patch_pers = torch.stack([pers for _, pers in per_sample], dim=0)

        persistence_enc = torch.cat(
            [self.birth_encoder(patch_birth), self.persistence_encoder(patch_pers)], dim=-1
        )
        combined = torch.cat([spatial_enc, persistence_enc], dim=-1)
        return self.fusion(combined)
|
|
|
|
class PersistenceViT(nn.Module):
    """
    Vision Transformer for Persistence Images with topological inductive biases.

    Features:
    - Topological-aware patch embedding with importance weighting
    - Birth-death attention mechanism
    - Multi-scale persistence encoding via homology tokens
    - Topological feature aggregation
    - Persistence-weighted positional encoding

    Args:
        image_size: Side length of the square input image.
        patch_size: Patch side length.
        output_dim: Size of the returned feature vector.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        channels: Number of input image channels.
        dropout: Dropout used for embeddings, blocks and the output head.
        homology_dims: Number of learnable homology tokens placed after the CLS token.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=16,
        output_dim=128,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256,
        channels=3,
        dropout=0.2,
        homology_dims=3,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size

        self.patch_embedding = TopologicalPatchEmbedding(channels, dim, patch_size, image_size)

        # Learnable special tokens: homology tokens first (creation order is kept for RNG parity), then CLS.
        self.homology_tokens = nn.Parameter(torch.randn(1, homology_dims, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.pos_encoder = PersistenceWeightedPositionalEncoding(
            dim=dim, image_size=image_size, patch_size=patch_size
        )
        self.dropout = nn.Dropout(dropout)

        self.transformer_blocks = nn.ModuleList(
            BirthDeathTransformerBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)
        )

        self.topological_aggregation = TopologicalFeatureAggregation(dim, homology_dims)

        self.norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Sequential(nn.Dropout(dropout), nn.Linear(dim, output_dim))

    def forward(self, persistence_img, persistence_coords=None, pixel_coords=None):
        """
        Forward pass with optional persistence coordinate information.

        Args:
            persistence_img: (B, C, H, W) - persistence image
            persistence_coords: (B, N_features, 3) - optional normalized (birth, death, persistence)
            pixel_coords: (B, N_features, 2) - optional (x, y) pixel locations

        Returns:
            output: (B, output_dim) - encoded representation
        """
        batch = persistence_img.shape[0]

        patches, importance = self.patch_embedding(persistence_img)
        patches = patches + self.pos_encoder(
            batch_size=batch,
            persistence_coords=persistence_coords,
            pixel_coords=pixel_coords,
        )

        # Sequence layout: [CLS, homology_1..K, patch_1..P].
        special = torch.cat(
            [self.cls_token.expand(batch, -1, -1), self.homology_tokens.expand(batch, -1, -1)],
            dim=1,
        )
        n_special = 1 + self.homology_tokens.shape[1]
        tokens = torch.cat([special, patches], dim=1)

        # Zero importance for the CLS/homology tokens keeps the weights aligned with the sequence.
        zeros = torch.zeros(batch, n_special, device=importance.device)
        importance = torch.cat([zeros, importance], dim=1)

        tokens = self.dropout(tokens)
        for block in self.transformer_blocks:
            tokens = block(tokens, importance)
        tokens = self.norm(tokens)

        # Only the CLS + homology tokens feed the aggregation head.
        pooled = self.topological_aggregation(tokens[:, :n_special])
        return self.mlp_head(pooled)
|
|
|
|
| |
| |
| |
|
|
class ImageEncoder(nn.Module):
    """
    Unified image encoder interface supporting multiple architectures:
    - SimpleCNN: Lightweight CNN
    - SmallViT: Compact Vision Transformer
    - PersistenceViT: Topologically-aware ViT
    - ResNet18/ResNet50/ResNet101: Standard CNNs with ImageNet pretraining
    - DINOv2: Self-supervised vision encoders (ViT-S/B/L/g)
    - ConvNeXt-Small: Modern CNN architecture
    - HybridResNetViT: ResNet conv layers + PersistenceViT attention

    Args:
        output_dim: Size of the features returned by ``forward``. The DINOv2
            encoders ignore it and return their raw backbone width instead
            (384 / 768 / 1024 / 1536).
        model_type: One of 'simplecnn', 'smallvit', 'persistencevit', 'resnet18',
            'resnet50', 'resnet101', 'dinov2_vits14', 'dinov2_vitb14',
            'dinov2_vitl14', 'dinov2_vitg14', 'convnext_small',
            'hybrid_resnet18_vit' or 'hybrid_resnet50_vit'.
        image_size: Input resolution used by the ViT-style and hybrid encoders.
        freeze_backbone: If True, disable gradients of the pretrained backbone
            (ResNet, ConvNeXt, DINOv2) and forward the flag to the hybrid
            factory. The projection head stays trainable. The models trained
            from scratch (SimpleCNN, SmallViT, PersistenceViT) ignore it.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """

    # Hybrid model_type -> ResNet variant passed to the hybrid factory.
    _HYBRID_RESNETS = {
        'hybrid_resnet18_vit': 'resnet18',
        'hybrid_resnet50_vit': 'resnet50',
    }

    def __init__(self, output_dim=128, model_type='resnet18', image_size=224, freeze_backbone=False):
        super(ImageEncoder, self).__init__()

        if model_type in self._HYBRID_RESNETS:
            # Local import: the hybrid module is only loaded when a hybrid model is requested.
            from .hybrid_resnet_persistencevit import create_hybrid_encoder
            self.encoder = create_hybrid_encoder(
                output_dim=output_dim,
                image_size=image_size,
                freeze_backbone=freeze_backbone,
                resnet_type=self._HYBRID_RESNETS[model_type],
                resnet_layers=2,
                dim=256,
                depth=4,
                heads=8,
                mlp_dim=512,
                dropout=0.2,
            )
            self.use_simple_model = True
        elif model_type == 'simplecnn':
            self.encoder = SimpleCNN(output_dim=output_dim)
            self.use_simple_model = True
        elif model_type == 'smallvit':
            self.encoder = SmallViT(image_size=image_size, patch_size=16, output_dim=output_dim,
                                    dim=128, depth=6, heads=4, mlp_dim=256, channels=3,
                                    dropout=0.2, emb_dropout=0.2)
            self.use_simple_model = True
        elif model_type == 'persistencevit':
            self.encoder = PersistenceViT(image_size=image_size, patch_size=16, output_dim=output_dim,
                                          dim=128, depth=6, heads=4, mlp_dim=256, channels=3,
                                          dropout=0.2, homology_dims=3)
            self.use_simple_model = True
        elif model_type in ('resnet18', 'resnet50', 'resnet101'):
            self._build_resnet(model_type, output_dim, freeze_backbone)
            self.use_simple_model = False
        elif model_type.startswith('dinov2'):
            # Honour the caller's freeze flag (it was previously hard-wired to False).
            self.encoder = DINOv2ImageEncoder(output_dim=output_dim, freeze_backbone=freeze_backbone,
                                              model_variant=model_type)
            self.use_simple_model = True
        elif model_type == 'convnext_small':
            from torchvision.models import convnext_small, ConvNeXt_Small_Weights
            convnext = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
            feature_dim = 768
            # Drop the final child (the classifier) of the torchvision model.
            self.encoder = nn.Sequential(*list(convnext.children())[:-1])
            if freeze_backbone:
                self._freeze(self.encoder)
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(feature_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, output_dim),
            )
            self.use_simple_model = False
        else:
            raise ValueError(f"model_type must be 'simplecnn', 'smallvit', 'persistencevit', 'resnet18', 'resnet50', 'resnet101', 'dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14', 'convnext_small', 'hybrid_resnet18_vit', or 'hybrid_resnet50_vit', got {model_type}")

        self.model_type = model_type

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``.

        NOTE: BatchNorm running statistics still update in train mode; call
        ``.eval()`` on the module as well if they must stay fixed.
        """
        for param in module.parameters():
            param.requires_grad = False

    def _build_resnet(self, model_type, output_dim, freeze_backbone):
        """Build an ImageNet-pretrained ResNet trunk (``self.encoder``) and its projection head (``self.fc``)."""
        constructor, weights, feature_dim = {
            'resnet18': (models.resnet18, ResNet18_Weights.IMAGENET1K_V1, 512),
            'resnet50': (models.resnet50, ResNet50_Weights.IMAGENET1K_V1, 2048),
            'resnet101': (models.resnet101, ResNet101_Weights.IMAGENET1K_V1, 2048),
        }[model_type]
        resnet = constructor(weights=weights)
        # Drop the final child (the classification layer) of the torchvision model.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        if freeze_backbone:
            self._freeze(self.encoder)

        if model_type == 'resnet101':
            # A single Linear keeps the existing ``fc.0.*`` checkpoint layout. It
            # projects to ``output_dim`` like the other heads; before this fix it
            # always emitted 256 features.
            self.fc = nn.Sequential(nn.Linear(feature_dim, output_dim))
        else:
            self.fc = nn.Sequential(
                nn.Linear(feature_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, output_dim),
            )

    def forward(self, images, persistence_coords=None, pixel_coords=None):
        """
        Args:
            images: (B, 3, H, W) or (B, 1, H, W) tensor of persistence images
            persistence_coords: (B, N_features, 3) - optional for PersistenceViT
            pixel_coords: (B, N_features, 2) - optional for PersistenceViT
        Returns:
            features: (B, output_dim) tensor of image features
                (raw backbone width for the DINOv2 encoders)
        """
        if not self.use_simple_model:
            # Backbone + separate projection head.
            features = self.encoder(images)
            return self.fc(torch.flatten(features, 1))

        if self.model_type == 'persistencevit':
            return self.encoder(images, persistence_coords, pixel_coords)
        return self.encoder(images)
|
|
|
|
class DINOv2ImageEncoder(nn.Module):
    """
    DINOv2 backbone wrapper for persistence images.

    Supports the ViT-S/14, ViT-B/14, ViT-L/14 and ViT-g/14 variants loaded via
    ``torch.hub`` from ``facebookresearch/dinov2``. No projection head is added:
    ``forward`` returns the backbone output directly, and ``feat_dim`` records
    its expected width (384 for vits14).

    Args:
        output_dim: Accepted for interface compatibility; not used.
        freeze_backbone: If True, all backbone parameters have gradients disabled.
        model_variant: Name of the DINOv2 hub entry point.

    Raises:
        ValueError: If ``model_variant`` is not one of the four supported variants.
    """

    def __init__(self, output_dim=128, freeze_backbone=True, model_variant='dinov2_vits14'):
        super().__init__()

        embed_width = {
            'dinov2_vits14': 384,
            'dinov2_vitb14': 768,
            'dinov2_vitl14': 1024,
            'dinov2_vitg14': 1536,
        }.get(model_variant)
        if embed_width is None:
            raise ValueError(f"Unknown DINOv2 variant: {model_variant}")
        self.feat_dim = embed_width

        # Downloads / loads the model through torch.hub (requires network or a warm cache).
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_variant)

        if freeze_backbone:
            self.backbone.requires_grad_(False)

    def forward(self, images, persistence_coords=None, pixel_coords=None):
        """
        Args:
            images: (B, 3, H, W) tensor
            persistence_coords: ignored (for interface compatibility)
            pixel_coords: ignored (for interface compatibility)
        Returns:
            features: raw DINOv2 backbone output (expected width ``feat_dim``)
        """
        return self.backbone(images)
|
|