import einops
from omegaconf import DictConfig
import torch
import torch.nn as nn
from typing import Dict, List, Union
import barista.models.spatial_encoder as spe
from barista.data.metadata import Metadata
from barista.models.mlp import MLP
from barista.models.tokenized_batched_item import TokenizedBatchedItem
from barista.models.TSEncoder2D import TSEncoder2D
class Tokenizer(nn.Module):
    """Tokenizes multichannel time-series data into time-space interleaved tokens.

    The input signal is split into (possibly overlapping) temporal subsegments,
    each subsegment is passed through a temporal encoder (TSEncoder2D) and
    pooled to the hidden dimension with an MLP, and an optional per-channel
    spatial encoding is added. The result is packaged as TokenizedBatchedItem.
    """

    def __init__(
        self,
        config: DictConfig,
        metadata: Metadata,
    ):
        """
        Args:
            config: Model configuration. Reads samp_frequency, num_seconds,
                temporal_subsegment_len, temporal_subsegment_step, d_hidden,
                spatial_grouping, add_spatial_encoding, temporal_encoder, and
                optionally embedding_max_dim / embedding_init_scale.
            metadata: Dataset metadata providing subjects, per-session input
                dims, and spatial groupings.
        """
        super().__init__()
        self.metadata = metadata
        self.config = config
        self.subjects = metadata.get_subjects()
        # Number of sliding windows of length temporal_subsegment_len taken
        # with stride temporal_subsegment_step over the full recording
        # (samp_frequency * num_seconds timepoints).
        self.num_subsegments = int(
            (
                self.config.samp_frequency * self.config.num_seconds
                - self.config.temporal_subsegment_len
            )
            // (self.config.temporal_subsegment_step)
            + 1
        )
        self.dim_h = self.config.d_hidden
        self._build_temporal_encoder()
        self._build_temporal_pooler()
        self._build_spatial_encoder()

    def _build_temporal_encoder(self):
        """Builds the per-subsegment temporal encoder.

        NOTE: mutates ``config.temporal_encoder`` in place, forcing
        single-channel input/output dims before constructing TSEncoder2D.
        """
        self.config.temporal_encoder.input_dims = 1
        self.config.temporal_encoder.output_dims = 1
        self.temporal_encoder = TSEncoder2D(**self.config.temporal_encoder)

    def _build_temporal_pooler(self):
        """Builds the MLP that pools a subsegment's time axis to the hidden dim."""
        self.temporal_pooler = MLP(
            d_input=self.config.temporal_subsegment_len,
            d_out=self.dim_h,
            dropout=0.0,
            bias=False,
        )

    def _build_spatial_encoder(self):
        """Builds the spatial encoder from each session's spatial grouping."""
        self.subject_session_spatial_groups = {}
        for sub_sesh in self.metadata.get_subject_session_d_input():
            spatial_grouping = self.metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping
        self.spatial_encoder = spe.create_spatial_encoder(
            dim_h=self.dim_h,
            subject_session_spatial_groups=self.subject_session_spatial_groups,
            embedding_max_dim=self.config.get('embedding_max_dim', None),
            embedding_init_scale=self.config.get('embedding_init_scale', 1.0),
        )

    def update_for_new_sessions(
        self,
        new_session_d_input_dict: Dict[str, int],
        new_metadata: Metadata,
    ) -> List[str]:
        """Re-registers spatial groupings for a new set of sessions.

        Args:
            new_session_d_input_dict: Maps subject_session identifier to its
                input (channel) dimension; only the keys are used here.
            new_metadata: Metadata object covering the new sessions; replaces
                ``self.metadata``.

        Returns:
            List of fully-qualified parameter names newly created by the
            spatial encoder (empty when spatial encoding is disabled).
        """
        self.subject_session_spatial_groups = {}
        for sub_sesh in new_session_d_input_dict:
            spatial_grouping = new_metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping
        self.metadata = new_metadata
        new_params = []
        if self.config.add_spatial_encoding:
            new_se_params = self.spatial_encoder.update_for_new_sessions(
                new_subject_session_spatial_groups=self.subject_session_spatial_groups
            )
            # Qualify names so they match this module's state_dict keys.
            new_params.extend([f"spatial_encoder.{n}" for n in new_se_params])
        return new_params

    def _tokenize_for_batch_tensor(
        self,
        x: torch.Tensor,
        subject_session: str,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> TokenizedBatchedItem:
        """Tokenizes one batch of same-session data.

        Args:
            x: Input tensor of shape (B, N, D).
                B: Batch size
                N: Time points
                D: Channel dim
            subject_session: subject_session identifier for this batch.
            add_spatial_encoding_to_tokens: If True (and spatial encoding is
                enabled in config), the spatial encoding is added to the
                tokens; otherwise it is only returned alongside them.

        Returns:
            Tokenized version of the same data as a TokenizedBatchedItem object.
        """
        batch_size, num_timepoints, num_channels = x.shape
        x = einops.rearrange(x, "b n d -> b d n")
        # NOTE that unfold doesn't copy the memory, so if step is less than
        # size (sliding window) and any of the shared elements are changed,
        # all occurrences of that element in the patches will change.
        x = x.unfold(
            dimension=-1,
            size=self.config.temporal_subsegment_len,
            step=self.config.temporal_subsegment_step,
        )  # (B, D, num_subsegments, subseg_len)
        collapsed_x = einops.rearrange(
            x, "b d t n -> (b t d) n"
        )  # (B * T * D, N)
        transposed_tokens = einops.rearrange(
            collapsed_x, "btd n -> 1 1 btd n"
        )  # (1, 1, B * T * D, N)
        collapsed_tokens = self.temporal_encoder(transposed_tokens)
        # Drop only the two leading singleton dims. A bare .squeeze() would
        # also (incorrectly) drop the B*T*D or N axis whenever either of
        # them happens to equal 1.
        collapsed_tokens = collapsed_tokens.squeeze(0).squeeze(0)  # (B * T * D, N)
        # "Time" dimension to hidden dimension. Using a fully connected layer here.
        collapsed_tokens = self.temporal_pooler(
            collapsed_tokens
        )  # (B * T * D, N) -> (B * T * D, HID_D)
        collapsed_tokens_full = collapsed_tokens
        # Create the time-space interleaved tokens: for each timestep, all
        # channels appear consecutively.
        tokens = einops.rearrange(
            collapsed_tokens_full,
            "(b t d) dh -> b (t d) dh",
            b=batch_size,
            t=self.num_subsegments,
        )
        seqlen_timepoints = self.num_subsegments
        if self.config.add_spatial_encoding:
            spatial_encoding = self.spatial_encoder(
                tokens,
                subject_session=subject_session,
                timepoints=seqlen_timepoints,
            )
            # Make sure regions at different timestamps have same spatial encoding
            assert (
                seqlen_timepoints == 1
                or spatial_encoding[0, 0, 0] == spatial_encoding[0, num_channels, 0]
            )
            if add_spatial_encoding_to_tokens:
                tokens = tokens + spatial_encoding
        else:  # not self.config.add_spatial_encoding
            spatial_encoding = None
        temporal_group_ids = torch.arange(seqlen_timepoints, device=x.device)
        temporal_group_ids = einops.repeat(
            temporal_group_ids,
            "t -> b (t d)",
            b=batch_size,
            d=num_channels
        )
        # Make sure different regions at same timestamps have same positional encoding
        assert seqlen_timepoints == 1 or (
            temporal_group_ids[0, 0] == temporal_group_ids[0, 1]
            and temporal_group_ids[0, 0]
            != temporal_group_ids[
                0, num_channels
            ]
        )
        position_ids = temporal_group_ids.clone()
        return TokenizedBatchedItem(
            tokens=tokens,
            position_ids=position_ids,
            spatial_group_ids=None,
            temporal_group_ids=temporal_group_ids,
            seq_lens=[tokens.shape[1]],
            spatial_embeddings=spatial_encoding,
            subject_sessions=[subject_session]
        )

    def forward(
        self,
        x: List,
        subject_sessions: List,
        output_as_list: bool = False,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> Union[TokenizedBatchedItem, List[TokenizedBatchedItem]]:
        """
        Args:
            x: A list of tensors each of shape (B_i, N_i, D_i)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_sessions: list of strings corresponding to subject_session
                identifiers. Indexed by cumulative datapoint count, so it is
                presumably one entry per datapoint (all datapoints within one
                x_item sharing a session) — verify against callers.
            output_as_list: if True, will output a list of TokenizedBatchedItem,
                each corresponding to one subject; if False, will merge all as
                one long sequence
            add_spatial_encoding_to_tokens: bool. Adds spatial encoding to tokens
        Returns:
            TokenizedBatchedItem if output_as_list is False, else list of
            TokenizedBatchedItem objects.
        """
        passed_datapoints = 0
        tokenized_items_list = []
        for x_item in x:
            tokenized_item = self._tokenize_for_batch_tensor(
                x_item,
                subject_sessions[passed_datapoints],
                add_spatial_encoding_to_tokens=add_spatial_encoding_to_tokens,
            )
            tokenized_items_list.append(tokenized_item)
            passed_datapoints += x_item.shape[0]
        if output_as_list:
            return tokenized_items_list
        return TokenizedBatchedItem.get_as_one_sequence(tokenized_items_list)