BaRISTA / barista /models /spatial_encoder.py
savaw's picture
Upload folder using huggingface_hub
a35137b verified
from abc import ABC, abstractmethod
import einops
import torch
import torch.nn as nn
from typing import Optional
class SpatialEncoderMeta:
def __init__(self, subject_session_spatial_groups=None):
"""Metadata object with subject session information for spatial encoding."""
self.subject_session_spatial_groups = subject_session_spatial_groups
@property
def num_region_info(self):
n_effective_components_across_sessions = set(
[a.n_effective_components for a in self.subject_session_spatial_groups.values()]
)
assert len(n_effective_components_across_sessions) == 1, (
"Doesn't support variable number of effective components for different subject_sessions"
)
self._num_region_info = n_effective_components_across_sessions.pop()
return self._num_region_info
@property
def embedding_table_configs(self):
configs = {}
for i in range(self.num_region_info):
n_embeddings_for_components_set = set(
[a.max_elements_for_component[i] for a in self.subject_session_spatial_groups.values()]
)
padding_indices_set = set(
[a.padding_indices[i] for a in self.subject_session_spatial_groups.values()]
)
assert len(n_embeddings_for_components_set) == 1, (
"Doesn't support variable number of max components for different subject_sessions, "
"change to use max of values across the subject if it is not important."
)
assert len(padding_indices_set) == 1, (
"Doesn't support variable number of padding indices for different subject_sessions, "
"change to use max of values across the subject if it is not important."
)
configs[i] = {
'num_embeddings': n_embeddings_for_components_set.pop(),
'padding_idx': padding_indices_set.pop()
}
return configs
class BaseSpatialEncoder(ABC, nn.Module):
"""Abstract class definition for spatial encoding modules.
Implement this interface to try new spatial encoding approaches in the tokenizer.
"""
_SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"
def __init__(
self,
dim_h: int,
spatial_encoder_meta: SpatialEncoderMeta,
):
super().__init__()
self.dim_h = dim_h
self.spatial_encoder_meta = spatial_encoder_meta
self._construct_region_encoding_meta()
def _construct_region_encoding_meta(self):
"""Constructs a hashmap of channel region information -> query vector for spatial encoding."""
for (
subject_session,
spatial_groups,
) in self.spatial_encoder_meta.subject_session_spatial_groups.items():
query_vector = torch.tensor(
[tuple(map(int, e[:spatial_groups.n_effective_components])) for e in spatial_groups.group_components]
)
query_vector = self._transform_query_vector(query_vector)
self.register_buffer(
BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session),
query_vector, persistent=False
)
def _transform_query_vector(self, query_vector: torch.Tensor):
return query_vector
def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
return self._buffers[BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)].to(torch.long)
def update_for_new_sessions(self,
new_subject_session_spatial_groups):
self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
self._construct_region_encoding_meta()
return []
@abstractmethod
def _encode(self, x: torch.tensor) -> torch.tensor:
pass
@abstractmethod
def _get_position_encoding(
self, x: torch.tensor, subject_session: str
) -> torch.tensor:
pass
def forward(
self,
x: torch.tensor,
subject_session: str,
timepoints: int = 1,
mask: torch.tensor = None,
) -> torch.tensor:
"""
Args:
x: torch.tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
Returns:
A torch.tensor of shape (B, T*R, D) that is the encoding corresponding to
the input token x.
"""
session_PE = self._get_position_encoding(x, subject_session)
assert (
x.shape[-1] == session_PE.shape[-1]
), f"Region dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."
position_encoding = einops.repeat(
session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
)
if mask is not None:
position_encoding = position_encoding[:, mask, :]
assert (
x.shape == position_encoding.shape
), "Output position encoding does not match in shape"
return position_encoding
class EmbeddingTable(BaseSpatialEncoder):
def __init__(
self,
dim_h: int,
spatial_encoder_meta: SpatialEncoderMeta,
embedding_max_dim: Optional[float] = None,
embedding_init_scale: float = 1.0
):
"""A lookup table of different embeddings for different spatial fields."""
super().__init__(dim_h, spatial_encoder_meta)
# Create the embeddings.
self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
subcomponent_dims = self._get_subcomponent_dims()
self.subcomponent_embeddings = nn.ModuleDict()
for (
subcomponent_ind,
subcomponent_config,
) in self.subcomponent_embedding_info.items():
subcomponent_dim = subcomponent_dims[subcomponent_ind]
self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
subcomponent_config["num_embeddings"],
subcomponent_dim,
padding_idx=subcomponent_config["padding_idx"],
max_norm=embedding_max_dim,
)
self.init_weights_for_embeddings(
self.subcomponent_embeddings[str(subcomponent_ind)],
embedding_init_scale
)
@abstractmethod
def _get_subcomponent_dims(self):
raise NotImplementedError
def update_for_new_sessions(self, new_subject_session_spatial_groups):
"""Add need embedding table elements based on new subject session information."""
new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)
subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
n_rows, subcomponent_dim = prev_embeddings.weight.shape
if subcomponent_config['num_embeddings'] == n_rows:
# no need to add any new embedding
continue
new_embeddings = torch.empty(
subcomponent_config['num_embeddings'] - n_rows,
subcomponent_dim,
device=prev_embeddings.weight.device
)
nn.init.normal_(new_embeddings)
new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))
self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
subcomponent_config["num_embeddings"],
subcomponent_dim,
padding_idx=subcomponent_config["padding_idx"],
)
self.subcomponent_embeddings[str(subcomponent_ind)].weight.data = new_data
new_params.extend([n for n, _ in self.named_parameters()])
return new_params
def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
embedding_table._fill_padding_idx_with_zero()
def _transform_query_vector(self, query_vector: torch.Tensor):
return query_vector.to(torch.float).T
def _get_position_encoding(
self, _: torch.tensor, subject_session: str
) -> torch.tensor:
"""Returns the encoding vector based on a subject session query."""
session_region_query = self.get_embedding_table_query_vector(
subject_session
)
single_session_PE = self._encode(session_region_query)
return single_session_PE
class EmbeddingTablePool(EmbeddingTable):
def _get_subcomponent_dims(self):
return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}
def _encode(self, x: torch.tensor) -> torch.tensor:
"""
Args:
x: torch.tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
Returns:
A torch.tensor of shape (B, T*R, D) that is the encoding corresponding to
the input token. If token has multiple spatial fields, the encoding for
each of these fields will be summed together before being return (e.g.,
x,y,z LPI coordinates).
"""
PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.get_device())
for subcomponent_ind in range(x.shape[0]):
subcomponent_x = x[subcomponent_ind, ...]
PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
str(subcomponent_ind)
](subcomponent_x)
return torch.sum(PE, axis=0)
def create_spatial_encoder(
dim_h: int,
subject_session_spatial_groups=None,
embedding_max_dim=None,
embedding_init_scale=1.0,
) -> BaseSpatialEncoder:
"""Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
spatial_encoder_meta = SpatialEncoderMeta(
subject_session_spatial_groups
)
spatial_encoder = EmbeddingTablePool(
dim_h,
spatial_encoder_meta,
embedding_max_dim,
embedding_init_scale
)
return spatial_encoder