# spatial-building-embeddings / modeling_spatial_embeddings.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel
from typing import Optional, Tuple, Union, Literal

# Handle import for both local development and HuggingFace Hub loading
try:
    from .configuration_spatial_embeddings import SpatialEmbeddingsConfig
except ImportError:
    # When loaded from HuggingFace Hub, relative imports may not work.
    # Try an absolute import instead.
    try:
        from configuration_spatial_embeddings import SpatialEmbeddingsConfig
    except ImportError:
        # Last resort: import from the module directly
        import sys
        from pathlib import Path

        # Get the directory where this file is located
        current_dir = Path(__file__).parent
        if str(current_dir) not in sys.path:
            sys.path.insert(0, str(current_dir))
        from configuration_spatial_embeddings import SpatialEmbeddingsConfig


class EmbeddingProjector(nn.Module):
    """
    Configurable MLP projection head for embedding transformation.
    (Copied from train_specialized_embeddings/model.py for self-contained publishing)
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 512,
        output_dim: int = 256,
        dropout: float = 0.1,
        num_hidden_layers: int = 1,
        hidden_dim_multiplier: float = 1.0,
        activation: Literal["gelu", "relu", "silu"] = "gelu",
        use_residual: bool = True,
        use_layer_norm: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.use_residual = use_residual
        self.use_layer_norm = use_layer_norm
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim_multiplier = hidden_dim_multiplier
        self.activation_name = activation

        self.hidden_dims = self._compute_hidden_dims(
            hidden_dim, num_hidden_layers, hidden_dim_multiplier
        )
        self.activation = self._resolve_activation(activation)

        # First hidden block
        self.input_layer = nn.Linear(input_dim, self.hidden_dims[0])
        if use_layer_norm:
            self.input_norm = nn.LayerNorm(self.hidden_dims[0])
        self.input_dropout = nn.Dropout(dropout)

        # Additional hidden blocks (if any)
        self.hidden_layers = nn.ModuleList()
        if use_layer_norm:
            self.hidden_norms = nn.ModuleList()
        else:
            self.hidden_norms = None
        self.hidden_dropouts = nn.ModuleList()
        for idx in range(1, len(self.hidden_dims)):
            layer = nn.Linear(self.hidden_dims[idx - 1], self.hidden_dims[idx])
            self.hidden_layers.append(layer)
            if use_layer_norm:
                self.hidden_norms.append(nn.LayerNorm(self.hidden_dims[idx]))
            self.hidden_dropouts.append(nn.Dropout(dropout))

        # Output block
        self.output_layer = nn.Linear(self.hidden_dims[-1], output_dim)
        if use_layer_norm:
            self.output_norm = nn.LayerNorm(output_dim)
        self.output_dropout = nn.Dropout(dropout)

        # Residual shortcut (projects input directly to output)
        if use_residual:
            self.residual_proj = nn.Linear(input_dim, output_dim)

    @staticmethod
    def _compute_hidden_dims(
        base_hidden_dim: int, num_layers: int, multiplier: float
    ) -> list[int]:
        dims: list[int] = []
        current_dim = base_hidden_dim
        for layer_idx in range(num_layers):
            if layer_idx == 0:
                dims.append(base_hidden_dim)
            else:
                current_dim = max(16, int(round(current_dim * multiplier)))
                dims.append(current_dim)
        return dims
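
    # Example of the schedule this produces (illustration only, not the
    # published configuration):
    #   _compute_hidden_dims(512, num_layers=3, multiplier=0.5)  -> [512, 256, 128]
    #   _compute_hidden_dims(512, num_layers=3, multiplier=0.01) -> [512, 16, 16]
    # The max(16, ...) floor prevents aggressive multipliers from collapsing
    # a layer below 16 units.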

    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        if name == "gelu":
            return nn.GELU()
        if name == "relu":
            return nn.ReLU()
        if name == "silu":
            return nn.SiLU()
        raise ValueError(f"Unsupported activation: {name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First hidden block
        out = self.input_layer(x)
        if self.use_layer_norm:
            out = self.input_norm(out)
        out = self.activation(out)
        out = self.input_dropout(out)

        # Additional hidden blocks
        for idx, layer in enumerate(self.hidden_layers):
            out = layer(out)
            if self.use_layer_norm and self.hidden_norms is not None:
                out = self.hidden_norms[idx](out)
            out = self.activation(out)
            out = self.hidden_dropouts[idx](out)

        # Output block
        out = self.output_layer(out)
        if self.use_layer_norm:
            out = self.output_norm(out)
        out = self.output_dropout(out)

        # Residual connection
        if self.use_residual:
            residual = self.residual_proj(x)
            out = out + residual

        # L2 normalization
        out = F.normalize(out, p=2, dim=1)
        return out
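

# Minimal usage sketch for the projector in isolation (illustrative only; the
# shapes below assume the default 768 -> 512 -> 256 configuration):
#
#   projector = EmbeddingProjector(input_dim=768, hidden_dim=512, output_dim=256)
#   features = torch.randn(4, 768)        # e.g. a batch of backbone CLS embeddings
#   embeddings = projector(features)      # shape (4, 256), each row L2-normalized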


class SpatialEmbeddingsModel(PreTrainedModel):
    config_class = SpatialEmbeddingsConfig

    def __init__(self, config: SpatialEmbeddingsConfig):
        super().__init__(config)
        self.config = config

        # Initialize backbone
        self.backbone = AutoModel.from_pretrained(
            config.backbone_model_name, trust_remote_code=True
        )

        # Initialize projector
        self.projector = EmbeddingProjector(
            input_dim=config.input_dim,
            hidden_dim=config.hidden_dim,
            output_dim=config.output_dim,
            dropout=config.dropout,
            num_hidden_layers=config.num_hidden_layers,
            hidden_dim_multiplier=config.hidden_dim_multiplier,
            activation=config.activation,
            use_residual=config.use_residual,
            use_layer_norm=config.use_layer_norm,
        )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, dict]:
        """
        Args:
            pixel_values: Tensor of shape (batch_size, channels, height, width)
            return_dict: Whether to return a dictionary or a tuple
        Returns:
            If return_dict is True (the default, matching HF conventions), returns a
            dict with 'embeddings' and 'backbone_outputs'. Otherwise returns
            (embeddings,).
        """
        if return_dict is None:
            # Match the documented HF-style default of returning a dict
            return_dict = True

        # Pass through backbone
        outputs = self.backbone(pixel_values=pixel_values, return_dict=True, **kwargs)

        # Extract pooled output (CLS token or similar).
        # DINOv2 exposes pooler_output in some versions; otherwise fall back to
        # the CLS token from last_hidden_state.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            backbone_emb = outputs.pooler_output
        else:
            # Fallback: use the CLS token from the last hidden state
            backbone_emb = outputs.last_hidden_state[:, 0]

        # Project to specialized embedding
        specialized_emb = self.projector(backbone_emb)

        if return_dict:
            return {"embeddings": specialized_emb, "backbone_outputs": outputs}
        return (specialized_emb,)
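

# End-to-end usage sketch (illustrative; the repository id and image processor
# below are assumptions, not confirmed by this file):
#
#   from transformers import AutoModel, AutoImageProcessor
#   from PIL import Image
#
#   repo_id = "alexwaolson/spatial-building-embeddings"   # hypothetical id
#   model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
#   model.eval()
#   processor = AutoImageProcessor.from_pretrained(model.config.backbone_model_name)
#
#   inputs = processor(images=Image.open("building.jpg"), return_tensors="pt")
#   with torch.no_grad():
#       out = model(pixel_values=inputs["pixel_values"])
#   embedding = out["embeddings"]   # (1, output_dim), L2-normalized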