import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel
from typing import Optional, Tuple, Union, Literal


# Resolve the config import whether this file is used as a package module
# or loaded as a standalone file (e.g. via transformers' trust_remote_code).
try:
    from .configuration_spatial_embeddings import SpatialEmbeddingsConfig
except ImportError:
    # Not installed as a package: fall back to a plain module import.
    try:
        from configuration_spatial_embeddings import SpatialEmbeddingsConfig
    except ImportError:
        # Last resort: add this file's directory to sys.path and retry.
        import sys
        from pathlib import Path

        current_dir = Path(__file__).parent
        if str(current_dir) not in sys.path:
            sys.path.insert(0, str(current_dir))
        from configuration_spatial_embeddings import SpatialEmbeddingsConfig


class EmbeddingProjector(nn.Module):
    """
    Configurable MLP projection head for embedding transformation.
    (Copied from train_specialized_embeddings/model.py for self-contained publishing)
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 512,
        output_dim: int = 256,
        dropout: float = 0.1,
        num_hidden_layers: int = 1,
        hidden_dim_multiplier: float = 1.0,
        activation: Literal["gelu", "relu", "silu"] = "gelu",
        use_residual: bool = True,
        use_layer_norm: bool = True,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.use_residual = use_residual
        self.use_layer_norm = use_layer_norm
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim_multiplier = hidden_dim_multiplier
        self.activation_name = activation

        # Per-layer hidden widths derived from the base width and the multiplier.
        self.hidden_dims = self._compute_hidden_dims(
            hidden_dim, num_hidden_layers, hidden_dim_multiplier
        )

        self.activation = self._resolve_activation(activation)

        # Input projection: input_dim -> first hidden width.
        self.input_layer = nn.Linear(input_dim, self.hidden_dims[0])
        if use_layer_norm:
            self.input_norm = nn.LayerNorm(self.hidden_dims[0])
        self.input_dropout = nn.Dropout(dropout)

        # Additional hidden blocks (empty when num_hidden_layers == 1).
        self.hidden_layers = nn.ModuleList()
        if use_layer_norm:
            self.hidden_norms = nn.ModuleList()
        else:
            self.hidden_norms = None
        self.hidden_dropouts = nn.ModuleList()

        for idx in range(1, len(self.hidden_dims)):
            layer = nn.Linear(self.hidden_dims[idx - 1], self.hidden_dims[idx])
            self.hidden_layers.append(layer)
            if use_layer_norm:
                self.hidden_norms.append(nn.LayerNorm(self.hidden_dims[idx]))
            self.hidden_dropouts.append(nn.Dropout(dropout))

        # Output projection: last hidden width -> output_dim.
        self.output_layer = nn.Linear(self.hidden_dims[-1], output_dim)
        if use_layer_norm:
            self.output_norm = nn.LayerNorm(output_dim)
        self.output_dropout = nn.Dropout(dropout)

        # Optional residual path mapping the raw input directly to output_dim.
        if use_residual:
            self.residual_proj = nn.Linear(input_dim, output_dim)

    @staticmethod
    def _compute_hidden_dims(
        base_hidden_dim: int, num_layers: int, multiplier: float
    ) -> list[int]:
        dims: list[int] = []
        current_dim = base_hidden_dim
        for layer_idx in range(num_layers):
            if layer_idx == 0:
                dims.append(base_hidden_dim)
            else:
                current_dim = max(16, int(round(current_dim * multiplier)))
                dims.append(current_dim)
        return dims
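
    # Illustrative examples of the resulting width schedule (values chosen for
    # demonstration only, not taken from any particular config):
    #   _compute_hidden_dims(512, 3, 0.5)  -> [512, 256, 128]
    #   _compute_hidden_dims(768, 2, 2.0)  -> [768, 1536]
    #   _compute_hidden_dims(32, 2, 0.25)  -> [32, 16]   # clamped at the floor of 16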
    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        if name == "gelu":
            return nn.GELU()
        if name == "relu":
            return nn.ReLU()
        if name == "silu":
            return nn.SiLU()
        raise ValueError(f"Unsupported activation: {name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input projection block.
        out = self.input_layer(x)
        if self.use_layer_norm:
            out = self.input_norm(out)
        out = self.activation(out)
        out = self.input_dropout(out)

        # Additional hidden blocks.
        for idx, layer in enumerate(self.hidden_layers):
            out = layer(out)
            if self.use_layer_norm and self.hidden_norms is not None:
                out = self.hidden_norms[idx](out)
            out = self.activation(out)
            out = self.hidden_dropouts[idx](out)

        # Output projection block.
        out = self.output_layer(out)
        if self.use_layer_norm:
            out = self.output_norm(out)
        out = self.output_dropout(out)

        # Residual connection from the raw input.
        if self.use_residual:
            residual = self.residual_proj(x)
            out = out + residual

        # L2-normalize so embeddings are unit-length.
        out = F.normalize(out, p=2, dim=1)

        return out
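
    # Minimal usage sketch (constructor defaults; the random batch below stands
    # in for real pooled backbone features):
    #
    #     projector = EmbeddingProjector(input_dim=768, hidden_dim=512, output_dim=256)
    #     pooled = torch.randn(4, 768)   # e.g. pooled backbone features
    #     z = projector(pooled)          # shape (4, 256), rows L2-normalized
    #     sims = z @ z.T                 # cosine similarities, since rows are unit-length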


class SpatialEmbeddingsModel(PreTrainedModel):
    config_class = SpatialEmbeddingsConfig

    def __init__(self, config: SpatialEmbeddingsConfig):
        super().__init__(config)
        self.config = config

        # Backbone encoder loaded from the Hub by name.
        self.backbone = AutoModel.from_pretrained(
            config.backbone_model_name, trust_remote_code=True
        )

        # Projection head that maps backbone features to the specialized embedding space.
        self.projector = EmbeddingProjector(
            input_dim=config.input_dim,
            hidden_dim=config.hidden_dim,
            output_dim=config.output_dim,
            dropout=config.dropout,
            num_hidden_layers=config.num_hidden_layers,
            hidden_dim_multiplier=config.hidden_dim_multiplier,
            activation=config.activation,
            use_residual=config.use_residual,
            use_layer_norm=config.use_layer_norm,
        )

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, dict]:
        """
        Args:
            pixel_values: Tensor of shape (batch_size, channels, height, width).
            return_dict: Whether to return a dict or a tuple.

        Returns:
            If return_dict is truthy, a dict with 'embeddings' and 'backbone_outputs'.
            Otherwise the tuple (embeddings,).
        """
        # Encode the images with the backbone.
        outputs = self.backbone(pixel_values=pixel_values, return_dict=True, **kwargs)

        # Prefer the pooled output when the backbone provides one; otherwise fall
        # back to the first token ([CLS]) of the last hidden state.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            backbone_emb = outputs.pooler_output
        else:
            backbone_emb = outputs.last_hidden_state[:, 0]

        # Project to the specialized, L2-normalized embedding space.
        specialized_emb = self.projector(backbone_emb)

        if return_dict:
            return {"embeddings": specialized_emb, "backbone_outputs": outputs}
        return (specialized_emb,)
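

# Minimal end-to-end sketch. The repository id, the `image` variable, and the use
# of AutoImageProcessor below are illustrative assumptions, not part of this module:
#
#     from transformers import AutoImageProcessor
#
#     model = SpatialEmbeddingsModel.from_pretrained("your-org/spatial-embeddings")
#     processor = AutoImageProcessor.from_pretrained(model.config.backbone_model_name)
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         out = model(pixel_values=inputs["pixel_values"], return_dict=True)
#     embeddings = out["embeddings"]   # (batch_size, config.output_dim)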