from abc import ABC, abstractmethod
from typing import Optional

import einops
import torch
import torch.nn as nn
class SpatialEncoderMeta:
def __init__(self, subject_session_spatial_groups=None):
"""Metadata object with subject session information for spatial encoding."""
self.subject_session_spatial_groups = subject_session_spatial_groups
@property
def num_region_info(self):
n_effective_components_across_sessions = set(
[a.n_effective_components for a in self.subject_session_spatial_groups.values()]
)
assert len(n_effective_components_across_sessions) == 1, (
"Doesn't support variable number of effective components for different subject_sessions"
)
self._num_region_info = n_effective_components_across_sessions.pop()
return self._num_region_info
@property
def embedding_table_configs(self):
configs = {}
for i in range(self.num_region_info):
n_embeddings_for_components_set = set(
[a.max_elements_for_component[i] for a in self.subject_session_spatial_groups.values()]
)
padding_indices_set = set(
[a.padding_indices[i] for a in self.subject_session_spatial_groups.values()]
)
            assert len(n_embeddings_for_components_set) == 1, (
                "Doesn't support a variable number of max components across subject_sessions; "
                "change this to take the max across subjects if the exact value is not important."
            )
            assert len(padding_indices_set) == 1, (
                "Doesn't support different padding indices across subject_sessions; "
                "change this to take the max across subjects if the exact value is not important."
            )
configs[i] = {
'num_embeddings': n_embeddings_for_components_set.pop(),
'padding_idx': padding_indices_set.pop()
}
return configs
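# NOTE: Each value in `subject_session_spatial_groups` is expected to expose the attributes
# read above and in `BaseSpatialEncoder._construct_region_encoding_meta`:
#   n_effective_components     - number of spatial subcomponents per region (e.g., 3 for x/y/z)
#   max_elements_for_component - per-subcomponent embedding-table size, indexable by component
#   padding_indices            - per-subcomponent padding index passed to nn.Embedding
#   group_components           - per-region tuples of component indices used to build the query vector
# A minimal stand-in exposing only these attributes is sketched in the __main__ block at the
# bottom of this file; the concrete spatial-group class is defined elsewhere.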
class BaseSpatialEncoder(ABC, nn.Module):
"""Abstract class definition for spatial encoding modules.
Implement this interface to try new spatial encoding approaches in the tokenizer.
"""
_SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"
def __init__(
self,
dim_h: int,
spatial_encoder_meta: SpatialEncoderMeta,
):
super().__init__()
self.dim_h = dim_h
self.spatial_encoder_meta = spatial_encoder_meta
self._construct_region_encoding_meta()
def _construct_region_encoding_meta(self):
"""Constructs a hashmap of channel region information -> query vector for spatial encoding."""
for (
subject_session,
spatial_groups,
) in self.spatial_encoder_meta.subject_session_spatial_groups.items():
query_vector = torch.tensor(
[tuple(map(int, e[:spatial_groups.n_effective_components])) for e in spatial_groups.group_components]
)
query_vector = self._transform_query_vector(query_vector)
self.register_buffer(
BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session),
query_vector, persistent=False
)
def _transform_query_vector(self, query_vector: torch.Tensor):
return query_vector
def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
        return self.get_buffer(BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)).to(torch.long)
    def update_for_new_sessions(self, new_subject_session_spatial_groups):
        """Updates the cached session metadata, rebuilds the query-vector buffers, and returns the names of any newly created parameters (none for the base class)."""
        self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
        self._construct_region_encoding_meta()
        return []
@abstractmethod
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
pass
@abstractmethod
    def _get_position_encoding(
        self, x: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
pass
    def forward(
        self,
        x: torch.Tensor,
        subject_session: str,
        timepoints: int = 1,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: torch.Tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
            subject_session: Key selecting the query vector cached for that session.
            timepoints: Number of timepoints T to repeat the per-region encoding over.
            mask: Optional index or boolean mask over the T*R token dimension; when given,
                the position encoding is subselected to match the masked input tokens.
        Returns:
            A torch.Tensor of shape (B, T*R, D) that is the encoding corresponding to
            the input token x.
        """
session_PE = self._get_position_encoding(x, subject_session)
        assert (
            x.shape[-1] == session_PE.shape[-1]
        ), f"Embedding dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."
position_encoding = einops.repeat(
session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
)
if mask is not None:
position_encoding = position_encoding[:, mask, :]
        assert (
            x.shape == position_encoding.shape
        ), "Output position encoding does not match the input shape."
return position_encoding
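# Shape example for BaseSpatialEncoder.forward (illustrative numbers): with R = 4 regions and
# D = 8, _get_position_encoding returns session_PE of shape (4, 8); for a batch of B = 2 and
# timepoints T = 3, the einops.repeat pattern "r d -> b (t r) d" tiles it to (2, 12, 8),
# matching time-space interleaved input tokens x of shape (2, 12, 8).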
class EmbeddingTable(BaseSpatialEncoder):
def __init__(
self,
dim_h: int,
spatial_encoder_meta: SpatialEncoderMeta,
embedding_max_dim: Optional[float] = None,
embedding_init_scale: float = 1.0
):
"""A lookup table of different embeddings for different spatial fields."""
super().__init__(dim_h, spatial_encoder_meta)
# Create the embeddings.
self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
subcomponent_dims = self._get_subcomponent_dims()
self.subcomponent_embeddings = nn.ModuleDict()
for (
subcomponent_ind,
subcomponent_config,
) in self.subcomponent_embedding_info.items():
subcomponent_dim = subcomponent_dims[subcomponent_ind]
self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
subcomponent_config["num_embeddings"],
subcomponent_dim,
padding_idx=subcomponent_config["padding_idx"],
max_norm=embedding_max_dim,
)
self.init_weights_for_embeddings(
self.subcomponent_embeddings[str(subcomponent_ind)],
embedding_init_scale
)
@abstractmethod
def _get_subcomponent_dims(self):
raise NotImplementedError
def update_for_new_sessions(self, new_subject_session_spatial_groups):
"""Add need embedding table elements based on new subject session information."""
new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)
subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
n_rows, subcomponent_dim = prev_embeddings.weight.shape
if subcomponent_config['num_embeddings'] == n_rows:
# no need to add any new embedding
continue
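            # Grow the table: append freshly initialized rows for the newly added embedding
            # indices while keeping the previously learned rows unchanged.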
new_embeddings = torch.empty(
subcomponent_config['num_embeddings'] - n_rows,
subcomponent_dim,
device=prev_embeddings.weight.device
)
nn.init.normal_(new_embeddings)
new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))
self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
subcomponent_config["num_embeddings"],
subcomponent_dim,
padding_idx=subcomponent_config["padding_idx"],
)
self.subcomponent_embeddings[str(subcomponent_ind)].weight.data = new_data
new_params.extend([n for n, _ in self.named_parameters()])
return new_params
def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
embedding_table._fill_padding_idx_with_zero()
def _transform_query_vector(self, query_vector: torch.Tensor):
return query_vector.to(torch.float).T
    def _get_position_encoding(
        self, _: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
"""Returns the encoding vector based on a subject session query."""
session_region_query = self.get_embedding_table_query_vector(
subject_session
)
single_session_PE = self._encode(session_region_query)
return single_session_PE
class EmbeddingTablePool(EmbeddingTable):
def _get_subcomponent_dims(self):
return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: torch.Tensor of shape (C, R). One row of embedding-table indices per
                spatial subcomponent C, for each of the R regions in the session.
        Returns:
            A torch.Tensor of shape (R, D) holding the encoding for each region. If a
            token has multiple spatial fields, the encodings for these fields are
            summed together before being returned (e.g., x, y, z LPI coordinates).
        """
        PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.device)
        for subcomponent_ind in range(x.shape[0]):
            subcomponent_x = x[subcomponent_ind, ...]
            PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
                str(subcomponent_ind)
            ](subcomponent_x)
        return torch.sum(PE, dim=0)
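# Shape example for EmbeddingTablePool._encode (illustrative numbers): a query of shape
# (C = 3, R = 5) is looked up in three embedding tables of width dim_h and summed over the
# subcomponent axis, giving a per-region encoding of shape (5, dim_h).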
def create_spatial_encoder(
dim_h: int,
subject_session_spatial_groups=None,
embedding_max_dim=None,
embedding_init_scale=1.0,
) -> BaseSpatialEncoder:
"""Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
spatial_encoder_meta = SpatialEncoderMeta(
subject_session_spatial_groups
)
spatial_encoder = EmbeddingTablePool(
dim_h,
spatial_encoder_meta,
embedding_max_dim,
embedding_init_scale
)
    return spatial_encoder
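# Minimal usage sketch. Assumptions: `_DemoSpatialGroup` is a hypothetical stand-in exposing
# only the attributes this module reads; the real spatial-group class is defined elsewhere.
if __name__ == "__main__":
    import dataclasses

    @dataclasses.dataclass
    class _DemoSpatialGroup:
        n_effective_components: int
        max_elements_for_component: list
        padding_indices: list
        group_components: list

    # One session with R = 2 regions, each described by C = 2 spatial subcomponents.
    demo_groups = {
        "subj01_sess01": _DemoSpatialGroup(
            n_effective_components=2,
            max_elements_for_component=[10, 10],
            padding_indices=[0, 0],
            group_components=[(1, 2), (3, 4)],
        )
    }
    encoder = create_spatial_encoder(dim_h=8, subject_session_spatial_groups=demo_groups)
    tokens = torch.randn(4, 3 * 2, 8)  # (B, T*R, D) with B = 4, T = 3 timepoints, R = 2 regions
    pe = encoder(tokens, subject_session="subj01_sess01", timepoints=3)
    print(pe.shape)  # torch.Size([4, 6, 8])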