# GraPHFormer / graphformer/models/image_encoder.py
# uzshah's picture
# Initial commit: GraPHFormer codebase
# cf84204
"""
Image Encoder Components for Persistence Images

This module contains various image encoders optimized for persistence images:
- SimpleCNN: Lightweight CNN
- SmallViT: Compact Vision Transformer
- PersistenceViT: Topologically-aware Vision Transformer with persistence-weighted positional encoding
- ImageEncoder: Unified interface that also wraps the standard ResNet18/ResNet50/ResNet101 and ConvNeXt-Small encoders
- DINOv2ImageEncoder: Self-supervised DINOv2 visual encoders
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights, ResNet101_Weights
from typing import Tuple, Optional
# ============================================================================
# Simple CNN Encoder
# ============================================================================
class SimpleCNN(nn.Module):
    """Small three-stage convolutional encoder for persistence images.

    The trunk applies three 3x3 convolutions (3 -> 32 -> 64 -> 128 channels),
    each followed by ReLU, with 2x2 max-pooling after the first two stages,
    spatial dropout before the last stage and global average pooling at the
    end. A single linear layer maps the pooled 128 features to ``output_dim``.

    Args:
        output_dim: Number of features returned per image.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        # The layers are listed stage by stage but end up in ONE flat
        # nn.Sequential so the parameter names (conv.0, conv.3, conv.7)
        # stay exactly as they were.
        stage_one = [nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        stage_two = [
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3),
        ]
        stage_three = [
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.conv = nn.Sequential(*stage_one, *stage_two, *stage_three)
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        """Encode a batch of 3-channel images.

        Args:
            x: Tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, output_dim).
        """
        pooled = self.conv(x)  # (B, 128, 1, 1) after global average pooling
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
# ============================================================================
# Vision Transformer Components
# ============================================================================
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention with several heads.

    Queries, keys and values come from one fused, bias-free linear layer;
    the heads are merged again and passed through an output projection.

    Args:
        dim: Embedding size of the tokens (split evenly across the heads).
        heads: Number of attention heads.
        dropout: Dropout probability used on the attention map and on the
            projected output.
    """

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        per_head = dim // heads
        self.heads = heads
        self.scale = per_head ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (B, N, C).

        Returns:
            Tensor of shape (B, N, C).
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.heads

        # (B, N, 3C) -> (3, B, heads, N, per_head), then split along axis 0.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel axis.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Layer-normalised self-attention and a layer-normalised two-layer GELU
    MLP, each wrapped in a residual connection.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability for attention and the MLP.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(dim, heads, dropout)
        feed_forward = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run one encoder layer over ``x`` of shape (B, N, dim)."""
        attended = self.attention(self.norm1(x))
        x = x + attended
        return x + self.mlp(self.norm2(x))
class SmallViT(nn.Module):
    """Compact Vision Transformer for persistence images.

    The image is cut into non-overlapping patches by a strided convolution,
    a learnable class token is prepended, learnable positional embeddings are
    added, and the class token of the last layer is projected to
    ``output_dim``.

    Args:
        image_size: Side length of the (square) input image; must be a
            multiple of ``patch_size``.
        patch_size: Side length of each square patch.
        output_dim: Size of the returned feature vector.
        dim: Token embedding size.
        depth: Number of Transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        channels: Number of input image channels.
        dropout: Dropout probability inside the blocks and the output head.
        emb_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=14,
        output_dim=128,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256,
        channels=3,
        dropout=0.2,
        emb_dropout=0.2
    ):
        super().__init__()
        assert image_size % patch_size == 0
        grid_side = image_size // patch_size
        num_patches = grid_side ** 2

        patchify = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size)
        self.patch_embedding = nn.Sequential(patchify, nn.Flatten(2))

        # One positional slot per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        blocks = [TransformerBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
        self.transformer = nn.ModuleList(blocks)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Dropout(dropout),
            nn.Linear(dim, output_dim),
        )

    def forward(self, img):
        """Encode images of shape (B, channels, H, W) into (B, output_dim)."""
        tokens = self.patch_embedding(img).transpose(1, 2)  # (B, N, dim)
        batch, patch_count, _ = tokens.shape

        class_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([class_tokens, tokens], dim=1)
        tokens = tokens + self.pos_embedding[:, :patch_count + 1]
        tokens = self.dropout(tokens)

        for layer in self.transformer:
            tokens = layer(tokens)

        class_out = tokens[:, 0]
        return self.mlp_head(class_out)
# ============================================================================
# Topologically-Aware Vision Transformer (PersistenceViT)
# ============================================================================
class TopologicalPatchEmbedding(nn.Module):
    """Patch embedding whose tokens are scaled by a learned importance score.

    Two patch-strided convolutions run on the same image: ``proj`` produces
    the token embeddings and ``importance_estimator`` produces one sigmoid
    score per patch. Each token is multiplied by ``1 + score``.

    Args:
        in_channels: Number of image channels.
        dim: Token embedding size.
        patch_size: Side length of each square patch.
        image_size: Side length of the (square) input image.
    """

    def __init__(self, in_channels, dim, patch_size, image_size):
        super().__init__()
        self.patch_size = patch_size
        grid_side = image_size // patch_size
        self.num_patches = grid_side ** 2
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

        # Scores each patch in (0, 1); intended to highlight high-persistence regions.
        score_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=patch_size, stride=patch_size),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.importance_estimator = nn.Sequential(*score_layers)

    def forward(self, x):
        """Embed an image batch.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tuple ``(tokens, importance)`` with shapes (B, N, dim) and (B, N).
        """
        # Unpacking keeps the original requirement that the input is 4-D.
        _batch, _channels, _height, _width = x.shape

        tokens = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, dim)
        scores = self.importance_estimator(x).flatten(2).transpose(1, 2)  # (B, N, 1)
        boosted = tokens * (1 + scores)
        return boosted, scores.squeeze(-1)
class BirthDeathAttention(nn.Module):
    """Multi-head self-attention whose logits are biased by patch importance.

    When ``importance_weights`` are supplied, ``0.1 * importance`` of each
    token is added to every attention logit that points AT that token (the
    key axis), so tokens with a higher importance score attract more
    attention from all queries.

    Note on the fix: the previous implementation broadcast the importance
    along the *query* axis, i.e. it added the same constant to every logit of
    a softmax row. Softmax is invariant to such a shift, so the modulation
    had no effect on the output. Checkpoints trained with the old code still
    load (parameter names and shapes are unchanged) but will produce slightly
    different activations now that the bias is effective.

    Args:
        dim: Token embedding size (split evenly across the heads).
        heads: Number of attention heads.
        dropout: Dropout probability used on the attention map and on the
            projected output.
    """

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        # Per-head scalar added to every logit of that head. A constant shift
        # is cancelled by softmax, so this parameter does not change the
        # output; it is retained so existing state_dicts keep loading.
        self.persistence_bias = nn.Parameter(torch.zeros(1, heads, 1, 1))

    def forward(self, x, importance_weights=None):
        """Apply importance-biased self-attention.

        Args:
            x: Tensor of shape (B, N, C).
            importance_weights: Optional tensor of shape (B, N) with one
                importance score per token. ``None`` gives plain
                self-attention.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N_query, N_key)

        if importance_weights is not None:
            # (B, N) -> (B, 1, 1, N): broadcast over heads and queries so the
            # bias varies along the key axis, which is the axis softmax
            # normalises over.
            key_importance = importance_weights.unsqueeze(1).unsqueeze(2).to(attn.dtype)
            attn = attn + self.persistence_bias + key_importance * 0.1

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class BirthDeathTransformerBlock(nn.Module):
    """Pre-norm Transformer layer built around ``BirthDeathAttention``.

    Identical in layout to a standard encoder layer (attention and a
    two-layer GELU MLP, each with a residual connection), except that the
    attention sub-layer also receives per-token importance weights.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability for attention and the MLP.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = BirthDeathAttention(dim, heads, dropout)
        feed_forward = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, importance_weights=None):
        """Run one layer.

        Args:
            x: Tensor of shape (B, N, dim).
            importance_weights: Optional (B, N) tensor forwarded to the
                attention sub-layer.

        Returns:
            Tensor of shape (B, N, dim).
        """
        attended = self.attention(self.norm1(x), importance_weights)
        x = x + attended
        return x + self.mlp(self.norm2(x))
class TopologicalFeatureAggregation(nn.Module):
    """Fuse the class token with the homology tokens into one vector.

    The result is the sum of two terms: the class token after cross-attending
    to the homology tokens (layer-normalised), and a softmax-weighted sum of
    all input tokens using learnable per-token weights.

    Args:
        dim: Token embedding size (the cross-attention uses 4 heads).
        num_homology_dims: Number of homology tokens that follow the class
            token in the input.
    """

    def __init__(self, dim, num_homology_dims):
        super().__init__()
        self.num_homology = num_homology_dims
        self.cross_attn = nn.MultiheadAttention(dim, num_heads=4, dropout=0.1, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        # One mixing weight per token: class token + every homology token.
        self.aggregation_weights = nn.Parameter(torch.ones(1, num_homology_dims + 1, 1))

    def forward(self, tokens):
        """Aggregate the special tokens.

        Args:
            tokens: Tensor of shape (B, 1 + num_homology_dims, dim), with the
                class token first.

        Returns:
            Tensor of shape (B, dim).
        """
        query = tokens[:, :1]
        context = tokens[:, 1:]

        attended = self.cross_attn(query, context, context)[0]
        attended = self.norm(attended)

        mix = F.softmax(self.aggregation_weights, dim=1)
        pooled = (tokens * mix).sum(dim=1, keepdim=True)

        return (attended + pooled).squeeze(1)
class PersistenceWeightedPositionalEncoding(nn.Module):
    """
    NOVEL CONTRIBUTION: Positional encoding based on topological persistence
    rather than just spatial grid location.

    Key insight: Patches corresponding to high-persistence features should have
    similar positional encodings regardless of their spatial location, because
    they represent similar topological significance.

    The encoding of a patch is the concatenation of a learnable spatial
    embedding (``dim // 2`` channels) and two small MLP embeddings
    (``dim // 4`` channels each) of the patch's mean birth and mean
    persistence, passed through a Linear + LayerNorm + Tanh fusion layer.

    Args:
        dim: Size of the positional encoding. The three parts only add up to
            ``dim`` when ``dim`` is a multiple of 4.
        image_size: Side length of the (square) image in pixels.
        patch_size: Side length of each square patch in pixels.
    """

    def __init__(self, dim, image_size=224, patch_size=14):
        super().__init__()
        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Spatial positional encoding (standard grid-based)
        self.spatial_pos = nn.Parameter(torch.randn(1, self.num_patches, dim // 2))

        # Persistence-based encoding
        self.birth_encoder = nn.Sequential(
            nn.Linear(1, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, dim // 4)
        )
        self.persistence_encoder = nn.Sequential(
            nn.Linear(1, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, dim // 4)
        )

        # Fusion of spatial and topological encodings
        self.fusion = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.Tanh()
        )

    def encode_patch_persistence(
        self,
        persistence_coords: torch.Tensor,  # (N_features, 3)
        pixel_coords: torch.Tensor  # (N_features, 2)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign persistence values to patches based on which features fall in each patch.

        Rows in which both the pixel coordinates and the persistence
        coordinates are all zero are treated as padding and ignored. Pixel
        coordinates outside the image are clamped to the nearest border patch.
        The scatter is done with ``index_add_`` instead of a Python loop, so
        no per-feature host synchronisation is needed.

        Args:
            persistence_coords: (N_features, 3) tensor; column 0 is read as
                birth and column 2 as persistence.
            pixel_coords: (N_features, 2) tensor of (x, y) pixel positions.

        Returns:
            patch_birth: (num_patches, 1) - average birth for each patch
            patch_persistence: (num_patches, 1) - average persistence for each patch
            Patches that contain no feature get 0 in both outputs. Patches are
            ordered row-major (index = patch_y * patches_per_row + patch_x).
        """
        device = persistence_coords.device
        grid = self.image_size // self.patch_size
        num_cells = grid * grid

        birth_sum = torch.zeros(num_cells, device=device)
        pers_sum = torch.zeros(num_cells, device=device)
        count = torch.zeros(num_cells, device=device)

        # Padded entries are rows that are zero in BOTH inputs.
        is_padding = (pixel_coords == 0).all(dim=-1) & (persistence_coords == 0).all(dim=-1)
        keep = ~is_padding

        # Clamp first, then truncate: values are non-negative after the clamp,
        # so truncation equals floor.
        patch_x = (pixel_coords[:, 0] / self.patch_size).clamp(0, grid - 1).long()
        patch_y = (pixel_coords[:, 1] / self.patch_size).clamp(0, grid - 1).long()
        cell = (patch_y * grid + patch_x)[keep].to(device)

        acc_dtype = birth_sum.dtype
        birth_sum.index_add_(0, cell, persistence_coords[keep, 0].to(acc_dtype))
        pers_sum.index_add_(0, cell, persistence_coords[keep, 2].to(acc_dtype))
        count.index_add_(0, cell, torch.ones(cell.shape[0], dtype=acc_dtype, device=device))

        # Empty patches have sum 0; dividing by max(count, 1) leaves them at 0.
        denom = count.clamp(min=1)
        patch_birth_flat = (birth_sum / denom).unsqueeze(-1)  # (num_patches, 1)
        patch_pers_flat = (pers_sum / denom).unsqueeze(-1)  # (num_patches, 1)
        return patch_birth_flat, patch_pers_flat

    def forward(
        self,
        batch_size: int,
        persistence_coords: Optional[torch.Tensor] = None,
        pixel_coords: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Generate positional encoding combining spatial and persistence information.

        Args:
            batch_size: Batch size
            persistence_coords: (B, N_features, 3) - optional persistence coordinates
            pixel_coords: (B, N_features, 2) - optional pixel coordinates

        Returns:
            pos_encoding: (B, num_patches, dim). When either coordinate tensor
            is missing, the result is the raw spatial embedding padded with
            zeros and the fusion layer is NOT applied.
        """
        # Start with spatial encoding
        spatial_enc = self.spatial_pos.expand(batch_size, -1, -1)  # (B, num_patches, dim//2)

        if persistence_coords is None or pixel_coords is None:
            # Fall back to spatial-only encoding; new_zeros keeps the dtype
            # and device of the module's parameters.
            zero_enc = spatial_enc.new_zeros(batch_size, self.num_patches, self.dim // 2)
            return torch.cat([spatial_enc, zero_enc], dim=-1)

        # Per-sample patch statistics (feature counts may differ in padding).
        births = []
        persistences = []
        for b in range(batch_size):
            patch_birth, patch_pers = self.encode_patch_persistence(
                persistence_coords[b], pixel_coords[b]
            )
            births.append(patch_birth)
            persistences.append(patch_pers)

        # Stack once and run each MLP a single time for the whole batch.
        target = dict(dtype=spatial_enc.dtype, device=spatial_enc.device)
        birth_batch = torch.stack(births, dim=0).to(**target)  # (B, num_patches, 1)
        pers_batch = torch.stack(persistences, dim=0).to(**target)  # (B, num_patches, 1)

        birth_enc = self.birth_encoder(birth_batch)  # (B, num_patches, dim//4)
        pers_enc = self.persistence_encoder(pers_batch)  # (B, num_patches, dim//4)
        persistence_enc = torch.cat([birth_enc, pers_enc], dim=-1)  # (B, num_patches, dim//2)

        # Combine spatial and persistence encodings
        combined = torch.cat([spatial_enc, persistence_enc], dim=-1)  # (B, num_patches, dim)

        # Fuse through learned transformation
        return self.fusion(combined)
class PersistenceViT(nn.Module):
    """
    Vision Transformer for Persistence Images with topological inductive biases.

    Features:
    - Topological-aware patch embedding with importance weighting
    - Birth-death attention mechanism
    - Multi-scale persistence encoding via homology tokens
    - Topological feature aggregation
    - Persistence-weighted positional encoding

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        output_dim: Size of the returned feature vector.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        channels: Number of input image channels.
        dropout: Dropout probability for tokens, blocks and the output head.
        homology_dims: Number of learnable homology tokens.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=16,
        output_dim=128,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=256,
        channels=3,
        dropout=0.2,
        homology_dims=3,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size

        # Topological-aware patch embedding
        self.patch_embedding = TopologicalPatchEmbedding(
            channels, dim, patch_size, image_size
        )

        # Multi-scale persistence tokens
        self.homology_tokens = nn.Parameter(torch.randn(1, homology_dims, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Persistence-weighted positional encoding
        self.pos_encoder = PersistenceWeightedPositionalEncoding(
            dim=dim,
            image_size=image_size,
            patch_size=patch_size
        )
        self.dropout = nn.Dropout(dropout)

        # Birth-death transformer blocks
        self.transformer_blocks = nn.ModuleList([
            BirthDeathTransformerBlock(dim, heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

        # Topological aggregation
        self.topological_aggregation = TopologicalFeatureAggregation(dim, homology_dims)

        # Output head
        self.norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim, output_dim)
        )

    def forward(self, persistence_img, persistence_coords=None, pixel_coords=None):
        """
        Forward pass with optional persistence coordinate information.

        Args:
            persistence_img: (B, C, H, W) - persistence image
            persistence_coords: (B, N_features, 3) - optional normalized (birth, death, persistence)
            pixel_coords: (B, N_features, 2) - optional (x, y) pixel locations

        Returns:
            output: (B, output_dim) - encoded representation
        """
        B = persistence_img.shape[0]

        # Topological-aware patch embedding
        x, importance_weights = self.patch_embedding(persistence_img)  # (B, N, dim)

        # Generate persistence-weighted positional encoding
        pos_encoding = self.pos_encoder(
            batch_size=B,
            persistence_coords=persistence_coords,
            pixel_coords=pixel_coords
        )  # (B, num_patches, dim)

        # Add positional encoding
        x = x + pos_encoding

        # Add special tokens (class token first, then homology tokens)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        homology_tokens = self.homology_tokens.expand(B, -1, -1)
        x = torch.cat([cls_tokens, homology_tokens, x], dim=1)

        # Pad importance_weights for special tokens (cls + homology).
        # new_zeros inherits dtype and device from importance_weights, so the
        # concatenation cannot promote a half-precision tensor to float32.
        num_special_tokens = 1 + self.homology_tokens.shape[1]
        importance_padding = importance_weights.new_zeros(B, num_special_tokens)
        importance_weights_padded = torch.cat([importance_padding, importance_weights], dim=1)

        x = self.dropout(x)

        # Birth-death transformer blocks
        for block in self.transformer_blocks:
            x = block(x, importance_weights_padded)

        # Only the special tokens are consumed below and LayerNorm works on
        # each token independently, so normalising just that slice gives the
        # same values without touching the patch tokens.
        special_tokens = self.norm(x[:, :num_special_tokens])

        # Topological aggregation
        x = self.topological_aggregation(special_tokens)
        return self.mlp_head(x)
# ============================================================================
# Unified Image Encoder Interface
# ============================================================================
class ImageEncoder(nn.Module):
    """
    Unified image encoder interface supporting multiple architectures:
    - SimpleCNN: Lightweight CNN
    - SmallViT: Compact Vision Transformer
    - PersistenceViT: Topologically-aware ViT
    - ResNet18/ResNet50/ResNet101: Standard CNNs with ImageNet pretraining
    - DINOv2: Self-supervised vision encoders (ViT-S/B/L/g)
    - ConvNeXt-Small: Modern CNN architecture
    - HybridResNetViT: ResNet conv layers + PersistenceViT attention

    Args:
        output_dim: Size of the returned feature vector. The DINOv2 variants
            are the exception: ``DINOv2ImageEncoder`` returns the backbone's
            raw features and has no projection to ``output_dim``.
        model_type: One of 'simplecnn', 'smallvit', 'persistencevit',
            'resnet18', 'resnet50', 'resnet101', 'dinov2_vits14',
            'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14',
            'convnext_small', 'hybrid_resnet18_vit', 'hybrid_resnet50_vit'.
        image_size: Input side length; used by the ViT-style and hybrid
            encoders only.
        freeze_backbone: If True, the pretrained backbone's parameters get
            ``requires_grad = False``. Honoured for the ResNet, ConvNeXt and
            hybrid encoders; it has no effect on the other model types.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    # model_type -> ResNet trunk handed to the hybrid encoder factory.
    _HYBRID_TRUNKS = {
        'hybrid_resnet18_vit': 'resnet18',
        'hybrid_resnet50_vit': 'resnet50',
    }
    _RESNET_TYPES = ('resnet18', 'resnet50', 'resnet101')

    def __init__(self, output_dim=128, model_type='resnet18', image_size=224, freeze_backbone=False):
        super().__init__()

        if model_type in self._HYBRID_TRUNKS:
            from .hybrid_resnet_persistencevit import create_hybrid_encoder
            self.encoder = create_hybrid_encoder(
                output_dim=output_dim,
                image_size=image_size,
                freeze_backbone=freeze_backbone,
                resnet_type=self._HYBRID_TRUNKS[model_type],
                resnet_layers=2,
                dim=256,
                depth=4,
                heads=8,
                mlp_dim=512,
                dropout=0.2,
            )
            self.use_simple_model = True
        elif model_type == 'simplecnn':
            self.encoder = SimpleCNN(output_dim=output_dim)
            self.use_simple_model = True
        elif model_type == 'smallvit':
            self.encoder = SmallViT(image_size=image_size, patch_size=16, output_dim=output_dim,
                                    dim=128, depth=6, heads=4, mlp_dim=256, channels=3,
                                    dropout=0.2, emb_dropout=0.2)
            self.use_simple_model = True
        elif model_type == 'persistencevit':
            self.encoder = PersistenceViT(image_size=image_size, patch_size=16, output_dim=output_dim,
                                          dim=128, depth=6, heads=4, mlp_dim=256, channels=3,
                                          dropout=0.2, homology_dims=3)
            self.use_simple_model = True
        elif model_type in self._RESNET_TYPES:
            self._build_resnet(model_type, output_dim, freeze_backbone)
            self.use_simple_model = False
        elif model_type.startswith('dinov2'):
            # Deliberately left trainable, as before: DINOv2ImageEncoder has no
            # projection head, so freezing its backbone would leave this
            # encoder without any trainable parameter.
            self.encoder = DINOv2ImageEncoder(output_dim=output_dim, freeze_backbone=False, model_variant=model_type)
            self.use_simple_model = True
        elif model_type == 'convnext_small':
            from torchvision.models import convnext_small, ConvNeXt_Small_Weights
            convnext = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
            feature_dim = 768
            # Drop the classifier; keep the feature stages and the pooling.
            self.encoder = nn.Sequential(*list(convnext.children())[:-1])
            if freeze_backbone:
                for param in self.encoder.parameters():
                    param.requires_grad = False
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(feature_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, output_dim),
            )
            self.use_simple_model = False
        else:
            raise ValueError(
                "model_type must be 'simplecnn', 'smallvit', 'persistencevit', "
                "'resnet18', 'resnet50', 'resnet101', 'dinov2_vits14', "
                "'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14', "
                "'convnext_small', 'hybrid_resnet18_vit', or 'hybrid_resnet50_vit', "
                f"got {model_type}"
            )

        self.model_type = model_type

    def _build_resnet(self, model_type, output_dim, freeze_backbone):
        """Create the ImageNet-pretrained ResNet trunk and its projection head."""
        if model_type == 'resnet18':
            resnet = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            feature_dim = 512
        elif model_type == 'resnet50':
            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            feature_dim = 2048
        else:
            resnet = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)
            feature_dim = 2048

        # Everything except the final classification layer.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])

        # Previously this flag was silently ignored for ResNet backbones.
        if freeze_backbone:
            for param in self.encoder.parameters():
                param.requires_grad = False

        if model_type == 'resnet101':
            # The old head was a fixed Linear(feature_dim, 256) and therefore
            # ignored output_dim. A single Linear is kept (same parameter
            # names, identical shapes when output_dim == 256) but it now
            # projects to the requested size.
            self.fc = nn.Sequential(
                nn.Linear(feature_dim, output_dim)
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(feature_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, output_dim),
            )

    def forward(self, images, persistence_coords=None, pixel_coords=None):
        """
        Args:
            images: (B, 3, H, W) tensor of persistence images
            persistence_coords: (B, N_features, 3) - optional for PersistenceViT
            pixel_coords: (B, N_features, 2) - optional for PersistenceViT

        Returns:
            features: (B, output_dim) tensor of image features (for DINOv2
            model types: the raw backbone features instead)
        """
        if self.use_simple_model:
            # Only PersistenceViT consumes the coordinates; every other
            # self-contained encoder is called with the images alone.
            if self.model_type == 'persistencevit':
                return self.encoder(images, persistence_coords, pixel_coords)
            return self.encoder(images)

        features = self.encoder(images)  # (B, feature_dim, 1, 1)
        features = features.view(features.size(0), -1)  # (B, feature_dim)
        return self.fc(features)  # (B, output_dim)
class DINOv2ImageEncoder(nn.Module):
    """
    Image encoder backed by a DINOv2 model loaded through ``torch.hub``.

    Available variants: ViT-S/14, ViT-B/14, ViT-L/14 and ViT-g/14.
    The backbone output is returned as-is (384-dim for vits14); there is no
    projection head, so ``output_dim`` is accepted but not used.

    Args:
        output_dim: Unused; accepted so the constructor matches the other
            encoders.
        freeze_backbone: If True, gradients are disabled for every backbone
            parameter.
        model_variant: Name of the hub entry point, e.g. 'dinov2_vits14'.

    Raises:
        ValueError: If ``model_variant`` is not a known DINOv2 variant.
    """

    def __init__(self, output_dim=128, freeze_backbone=True, model_variant='dinov2_vits14'):
        super().__init__()

        # Feature width produced by each supported variant.
        feature_widths = {
            'dinov2_vits14': 384,
            'dinov2_vitb14': 768,
            'dinov2_vitl14': 1024,
            'dinov2_vitg14': 1536,
        }
        width = feature_widths.get(model_variant) if model_variant in feature_widths else None
        if width is None:
            raise ValueError(f"Unknown DINOv2 variant: {model_variant}")
        self.feat_dim = width

        # Fetches the model definition (and weights) via torch.hub.
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_variant)

        if freeze_backbone:
            for weight in self.backbone.parameters():
                weight.requires_grad = False

    def forward(self, images, persistence_coords=None, pixel_coords=None):
        """
        Args:
            images: (B, 3, H, W) tensor
            persistence_coords: ignored (for interface compatibility)
            pixel_coords: ignored (for interface compatibility)

        Returns:
            features: (B, feat_dim) raw DINOv2 features (384 for vits14)
        """
        return self.backbone(images)  # (B, feat_dim)