| |
|
| | from abc import ABC, abstractmethod |
| | import einops |
| | import torch |
| | import torch.nn as nn |
| | from typing import Optional |
| |
|
| |
|
class SpatialEncoderMeta:
    """Metadata object with subject session information for spatial encoding.

    Wraps a mapping of subject_session -> spatial-group object and derives the
    configuration the encoder needs. Each spatial-group object is expected to
    expose `n_effective_components`, `max_elements_for_component`, and
    `padding_indices` (per-component) — TODO confirm against the group class.
    """

    def __init__(self, subject_session_spatial_groups=None):
        """Metadata object with subject session information for spatial encoding."""
        self.subject_session_spatial_groups = subject_session_spatial_groups

    @property
    def num_region_info(self):
        """Number of effective spatial components, shared by all subject sessions.

        Raises:
            AssertionError: If sessions disagree on the component count.
        """
        # Set comprehension collapses the per-session counts; exactly one
        # distinct value is required.
        n_effective_components_across_sessions = {
            groups.n_effective_components
            for groups in self.subject_session_spatial_groups.values()
        }

        assert len(n_effective_components_across_sessions) == 1, (
            "Doesn't support variable number of effective components for different subject_sessions"
        )

        # Cached for convenience, though the property recomputes on each access.
        self._num_region_info = n_effective_components_across_sessions.pop()
        return self._num_region_info

    @property
    def embedding_table_configs(self):
        """Per-component embedding-table config: {index: {'num_embeddings', 'padding_idx'}}.

        Raises:
            AssertionError: If sessions disagree on table size or padding index
            for any component.
        """
        configs = {}
        for i in range(self.num_region_info):
            n_embeddings_for_components_set = {
                groups.max_elements_for_component[i]
                for groups in self.subject_session_spatial_groups.values()
            }
            padding_indices_set = {
                groups.padding_indices[i]
                for groups in self.subject_session_spatial_groups.values()
            }

            assert len(n_embeddings_for_components_set) == 1, (
                "Doesn't support variable number of max components for different subject_sessions, "
                "change to use max of values across the subject if it is not important."
            )
            assert len(padding_indices_set) == 1, (
                "Doesn't support variable number of padding indices for different subject_sessions, "
                "change to use max of values across the subject if it is not important."
            )

            configs[i] = {
                'num_embeddings': n_embeddings_for_components_set.pop(),
                'padding_idx': padding_indices_set.pop()
            }

        return configs
| |
|
| |
|
class BaseSpatialEncoder(ABC, nn.Module):
    """Abstract class definition for spatial encoding modules.

    Implement this interface to try new spatial encoding approaches in the tokenizer.
    """
    # Format string deriving the per-subject-session buffer name that stores
    # the query vector of spatial-group component indices.
    _SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"

    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
    ):
        """
        Args:
            dim_h: Model embedding dimension of the produced encodings.
            spatial_encoder_meta: Metadata describing per-session spatial groups.
        """
        super().__init__()
        self.dim_h = dim_h
        self.spatial_encoder_meta = spatial_encoder_meta

        self._construct_region_encoding_meta()

    def _construct_region_encoding_meta(self):
        """Constructs a hashmap of channel region information -> query vector for spatial encoding."""
        for (
            subject_session,
            spatial_groups,
        ) in self.spatial_encoder_meta.subject_session_spatial_groups.items():
            # Keep only the first n_effective_components of each group-component
            # entry, cast elementwise to int.
            query_vector = torch.tensor(
                [tuple(map(int, e[:spatial_groups.n_effective_components])) for e in spatial_groups.group_components]
            )

            # Subclass hook (e.g. transpose / dtype change).
            query_vector = self._transform_query_vector(query_vector)

            # Non-persistent: query vectors are rebuilt from metadata, so they
            # need not appear in state_dict. register_buffer overwrites an
            # existing buffer of the same name, so re-running this for known
            # sessions (via update_for_new_sessions) is safe.
            self.register_buffer(
                BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session),
                query_vector, persistent=False
            )

    def _transform_query_vector(self, query_vector: torch.Tensor) -> torch.Tensor:
        """Hook for subclasses to post-process the raw query vector; identity by default."""
        return query_vector

    def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
        """Returns the registered query vector for `subject_session`, as int64."""
        # getattr resolves registered buffers; avoids reaching into the
        # private nn.Module._buffers dict.
        buffer_name = BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)
        return getattr(self, buffer_name).to(torch.long)

    def update_for_new_sessions(self,
                                new_subject_session_spatial_groups):
        """Re-registers query vectors for an updated subject-session mapping.

        Returns:
            A list of newly created parameter names — empty at this level;
            subclasses that allocate parameters extend it.
        """
        self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
        self._construct_region_encoding_meta()
        return []

    @abstractmethod
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Maps a query vector to its spatial encoding."""
        pass

    @abstractmethod
    def _get_position_encoding(
        self, x: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
        """Returns the per-region position encoding for a subject session."""
        pass

    def forward(
        self,
        x: torch.Tensor,
        subject_session: str,
        timepoints: int = 1,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: torch.tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
            subject_session: Key selecting which session's encoding to use.
            timepoints: Number of timepoints T the (R, D) encoding is tiled over.
            mask: Optional index/boolean mask selecting a subset of the T*R tokens.

        Returns:
            A torch.tensor of shape (B, T*R, D) that is the encoding corresponding to
            the input token x.
        """
        session_PE = self._get_position_encoding(x, subject_session)
        assert (
            x.shape[-1] == session_PE.shape[-1]
        ), f"Region dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."

        # Tile the per-region encoding over batch and timepoints:
        # (R, D) -> (B, T*R, D).
        position_encoding = einops.repeat(
            session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
        )

        if mask is not None:
            position_encoding = position_encoding[:, mask, :]

        assert (
            x.shape == position_encoding.shape
        ), "Output position encoding does not match in shape"
        return position_encoding
| |
|
| |
|
class EmbeddingTable(BaseSpatialEncoder):
    """Spatial encoder backed by per-subcomponent nn.Embedding lookup tables."""

    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
        embedding_max_dim: Optional[float] = None,
        embedding_init_scale: float = 1.0
    ):
        """A lookup table of different embeddings for different spatial fields.

        Args:
            dim_h: Model embedding dimension.
            spatial_encoder_meta: Metadata with per-session spatial groups.
            embedding_max_dim: If given, passed to nn.Embedding as `max_norm`.
            embedding_init_scale: Std of the normal init of embedding weights.
        """
        super().__init__(dim_h, spatial_encoder_meta)

        # Stored so update_for_new_sessions can rebuild tables with the same
        # max_norm and initialize new rows at the same scale.
        self.embedding_max_dim = embedding_max_dim
        self.embedding_init_scale = embedding_init_scale

        self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        subcomponent_dims = self._get_subcomponent_dims()

        self.subcomponent_embeddings = nn.ModuleDict()
        for (
            subcomponent_ind,
            subcomponent_config,
        ) in self.subcomponent_embedding_info.items():
            subcomponent_dim = subcomponent_dims[subcomponent_ind]

            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
                max_norm=embedding_max_dim,
            )

            self.init_weights_for_embeddings(
                self.subcomponent_embeddings[str(subcomponent_ind)],
                embedding_init_scale
            )

    @abstractmethod
    def _get_subcomponent_dims(self):
        """Returns {subcomponent index: embedding dim} for each spatial subcomponent."""
        raise NotImplementedError

    def update_for_new_sessions(self, new_subject_session_spatial_groups):
        """Add new embedding table elements based on new subject session information."""
        new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)

        subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
            prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
            n_rows, subcomponent_dim = prev_embeddings.weight.shape

            if subcomponent_config['num_embeddings'] == n_rows:
                # Table already large enough; nothing to grow.
                continue

            # Fix: initialize the added rows with the construction-time scale
            # (previously nn.init.normal_ defaulted to std=1.0 regardless of
            # embedding_init_scale).
            new_embeddings = torch.empty(
                subcomponent_config['num_embeddings'] - n_rows,
                subcomponent_dim,
                device=prev_embeddings.weight.device
            )
            nn.init.normal_(new_embeddings, std=self.embedding_init_scale)

            # Existing (trained) rows are kept; new rows are appended.
            new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))

            # Fix: preserve max_norm when rebuilding (previously dropped,
            # silently changing lookup behavior after a resize).
            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
                max_norm=self.embedding_max_dim,
            )
            self.subcomponent_embeddings[str(subcomponent_ind)].weight.data = new_data

        # NOTE(review): returns every parameter name, not only newly added
        # ones — callers appear to treat this as "parameters to (re)register";
        # kept as-is to preserve the existing contract.
        new_params.extend([n for n, _ in self.named_parameters()])

        return new_params

    def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
        """Normal-initializes an embedding table and zeroes its padding row."""
        nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
        # Equivalent to nn.Embedding._fill_padding_idx_with_zero, without
        # depending on a private torch API.
        if embedding_table.padding_idx is not None:
            with torch.no_grad():
                embedding_table.weight[embedding_table.padding_idx].fill_(0)

    def _transform_query_vector(self, query_vector: torch.Tensor) -> torch.Tensor:
        # Transpose to (n_effective_components, R) so each row indexes one
        # subcomponent's embedding table.
        return query_vector.to(torch.float).T

    def _get_position_encoding(
        self, _: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
        """Returns the encoding vector based on a subject session query."""
        session_region_query = self.get_embedding_table_query_vector(
            subject_session
        )
        single_session_PE = self._encode(session_region_query)
        return single_session_PE
| |
|
| |
|
class EmbeddingTablePool(EmbeddingTable):
    """Embedding-table spatial encoder that sums the per-subcomponent lookups."""

    def _get_subcomponent_dims(self):
        # Every subcomponent table emits full-width (dim_h) vectors so the
        # lookups can be summed elementwise.
        return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: torch.Tensor of integer indices of shape (S, R): one row of
               embedding-table indices per spatial subcomponent (e.g., the
               x, y, z components of LPI coordinates).

        Returns:
            A torch.Tensor of shape (R, dim_h): the per-subcomponent embedding
            lookups summed together over the S subcomponents.
        """
        # Fix: use x.device, not x.get_device() — the latter returns -1 for
        # CPU tensors, which is not a valid `device` argument.
        PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.device)
        for subcomponent_ind in range(x.shape[0]):
            subcomponent_x = x[subcomponent_ind, ...]
            PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
                str(subcomponent_ind)
            ](subcomponent_x)
        # Sum over the subcomponent axis -> (R, dim_h).
        return torch.sum(PE, dim=0)
| |
|
| |
|
def create_spatial_encoder(
    dim_h: int,
    subject_session_spatial_groups=None,
    embedding_max_dim=None,
    embedding_init_scale=1.0,
) -> BaseSpatialEncoder:
    """Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
    # Wrap the raw per-session group mapping in the metadata object the
    # encoder consumes.
    meta = SpatialEncoderMeta(
        subject_session_spatial_groups=subject_session_spatial_groups,
    )

    # Currently the only concrete encoder: per-subcomponent embedding tables
    # whose lookups are pooled by summation.
    return EmbeddingTablePool(
        dim_h,
        meta,
        embedding_max_dim,
        embedding_init_scale,
    )