from typing import Dict, List, Union

import einops
import torch
import torch.nn as nn
from omegaconf import DictConfig

import barista.models.spatial_encoder as spe
from barista.data.metadata import Metadata
from barista.models.mlp import MLP
from barista.models.tokenized_batched_item import TokenizedBatchedItem
from barista.models.TSEncoder2D import TSEncoder2D


class Tokenizer(nn.Module):
    """Tokenizes batched multichannel time series into per-channel temporal tokens."""

    def __init__(
        self,
        config: DictConfig,
        metadata: Metadata,
    ):
        super().__init__()

        self.metadata = metadata
        self.config = config

        self.subjects = metadata.get_subjects()

        # Number of sliding windows of length `temporal_subsegment_len`,
        # taken at stride `temporal_subsegment_step`, that fit into one
        # recording of `samp_frequency * num_seconds` samples.
        self.num_subsegments = int(
            (
                self.config.samp_frequency * self.config.num_seconds
                - self.config.temporal_subsegment_len
            )
            // self.config.temporal_subsegment_step
            + 1
        )
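        # Worked example (hypothetical values, not from any shipped config):
        # samp_frequency=250, num_seconds=2, temporal_subsegment_len=100,
        # temporal_subsegment_step=50 -> (250 * 2 - 100) // 50 + 1 = 9
        # windows, matching what torch.Tensor.unfold produces below.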

        self.dim_h = self.config.d_hidden

        self._build_temporal_encoder()
        self._build_temporal_pooler()
        self._build_spatial_encoder()

    def _build_temporal_encoder(self):
        # The encoder consumes each flattened token stream one feature at a
        # time, so its input and output feature dimensions are fixed to 1.
        self.config.temporal_encoder.input_dims = 1
        self.config.temporal_encoder.output_dims = 1
        self.temporal_encoder = TSEncoder2D(**self.config.temporal_encoder)

    def _build_temporal_pooler(self):
        # Projects each encoded subsegment (length `temporal_subsegment_len`)
        # down to the model's hidden dimension.
        self.temporal_pooler = MLP(
            d_input=self.config.temporal_subsegment_len,
            d_out=self.dim_h,
            dropout=0.0,
            bias=False,
        )

    def _build_spatial_encoder(self):
        # Look up the configured spatial grouping for every known
        # subject_session; the spatial encoder is built from these groupings.
        self.subject_session_spatial_groups = {}
        for sub_sesh in self.metadata.get_subject_session_d_input().keys():
            spatial_grouping = self.metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.spatial_encoder = spe.create_spatial_encoder(
            dim_h=self.dim_h,
            subject_session_spatial_groups=self.subject_session_spatial_groups,
            embedding_max_dim=self.config.get("embedding_max_dim", None),
            embedding_init_scale=self.config.get("embedding_init_scale", 1.0),
        )

    def update_for_new_sessions(
        self,
        new_session_d_input_dict: Dict[str, int],
        new_metadata: Metadata,
    ) -> List[str]:
        """Rebuild the spatial groupings for a new set of subject_sessions.

        Returns the fully qualified names of any newly created spatial
        encoder parameters (empty if spatial encoding is disabled).
        """
        self.subject_session_spatial_groups = {}
        for sub_sesh in new_session_d_input_dict.keys():
            spatial_grouping = new_metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.metadata = new_metadata

        new_params = []
        if self.config.add_spatial_encoding:
            new_se_params = self.spatial_encoder.update_for_new_sessions(
                new_subject_session_spatial_groups=self.subject_session_spatial_groups
            )
            new_params.extend([f"spatial_encoder.{n}" for n in new_se_params])

        return new_params

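    # Hedged usage sketch (`tokenizer`, `d_input_dict`, `meta`, and
    # `optimizer` are assumed names, not part of this module): the returned
    # names can be used to register only the freshly created session
    # parameters with an optimizer, e.g.
    #
    #   new_names = set(tokenizer.update_for_new_sessions(d_input_dict, meta))
    #   fresh = [p for n, p in tokenizer.named_parameters() if n in new_names]
    #   optimizer.add_param_group({"params": fresh})
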
    def _tokenize_for_batch_tensor(
        self,
        x: torch.Tensor,
        subject_session: str,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> TokenizedBatchedItem:
        """
        Args:
            x: Input tensor of shape (B, N, D)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_session: subject_session identifier for this tensor.
            add_spatial_encoding_to_tokens: if True, add the computed spatial
                encoding to the tokens.

        Returns:
            Tokenized version of the same data as a TokenizedBatchedItem object.
        """
        batch_size, num_timepoints, num_channels = x.shape

        # (B, N, D) -> (B, D, N): move the channel axis before time so the
        # sliding windows below are taken along the last (time) dimension.
        x = einops.rearrange(x, "b n d -> b d n")

        # Slide a window of length `temporal_subsegment_len` at stride
        # `temporal_subsegment_step` over time: (B, D, N) -> (B, D, T, L).
        x = x.unfold(
            dimension=-1,
            size=self.config.temporal_subsegment_len,
            step=self.config.temporal_subsegment_step,
        )

        # Flatten batch, window, and channel axes into one token axis.
        collapsed_x = einops.rearrange(x, "b d t n -> (b t d) n")

        # Add the two singleton leading dims the 2D encoder expects.
        transposed_tokens = einops.rearrange(collapsed_x, "btd n -> 1 1 btd n")

        collapsed_tokens = self.temporal_encoder(transposed_tokens)
        # Drop only the two singleton dims added above; a bare `.squeeze()`
        # would also collapse the token or time axis if either happened to be 1.
        collapsed_tokens = collapsed_tokens.squeeze(0).squeeze(0)

        # Pool each encoded window (length L) down to the hidden dim H.
        collapsed_tokens = self.temporal_pooler(collapsed_tokens)

        # Un-flatten back to (B, T*D, H): one token per (window, channel) pair,
        # ordered window-major (all channels of window 0, then window 1, ...).
        tokens = einops.rearrange(
            collapsed_tokens,
            "(b t d) dh -> b (t d) dh",
            b=batch_size,
            t=self.num_subsegments,
        )

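        # Shape walkthrough with hypothetical sizes (B=4, D=8 channels,
        # N=500 timepoints, window L=100, stride 50 -> T=9 windows):
        #   input x:          (4, 500, 8)
        #   after rearrange:  (4, 8, 500)
        #   after unfold:     (4, 8, 9, 100)
        #   collapsed_x:      (288, 100)   # 4 * 9 * 8
        #   after pooler:     (288, H)
        #   tokens:           (4, 72, H)   # 9 windows * 8 channels
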
        seqlen_timepoints = self.num_subsegments

        if self.config.add_spatial_encoding:
            spatial_encoding = self.spatial_encoder(
                tokens,
                subject_session=subject_session,
                timepoints=seqlen_timepoints,
            )

            # Sanity check: the spatial encoding must repeat across windows,
            # i.e. token 0 (window 0, channel 0) and token `num_channels`
            # (window 1, channel 0) carry the same encoding.
            assert (
                seqlen_timepoints == 1
                or spatial_encoding[0, 0, 0] == spatial_encoding[0, num_channels, 0]
            )

            if add_spatial_encoding_to_tokens:
                tokens = tokens + spatial_encoding
        else:
            spatial_encoding = None

        # Temporal group id = window index, repeated for every channel, so all
        # tokens from the same window share one id.
        temporal_group_ids = torch.arange(seqlen_timepoints, device=x.device)
        temporal_group_ids = einops.repeat(
            temporal_group_ids,
            "t -> b (t d)",
            b=batch_size,
            d=num_channels,
        )

        # Sanity check: adjacent tokens within a window share an id, and the
        # id changes at the window boundary (every `num_channels` tokens).
        assert seqlen_timepoints == 1 or (
            temporal_group_ids[0, 0] == temporal_group_ids[0, 1]
            and temporal_group_ids[0, 0] != temporal_group_ids[0, num_channels]
        )

        position_ids = temporal_group_ids.clone()

        return TokenizedBatchedItem(
            tokens=tokens,
            position_ids=position_ids,
            spatial_group_ids=None,
            temporal_group_ids=temporal_group_ids,
            seq_lens=[tokens.shape[1]],
            spatial_embeddings=spatial_encoding,
            subject_sessions=[subject_session],
        )

    def forward(
        self,
        x: List[torch.Tensor],
        subject_sessions: List[str],
        output_as_list: bool = False,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> Union[TokenizedBatchedItem, List[TokenizedBatchedItem]]:
        """
        Args:
            x: A list of tensors, each of shape (B_i, N_i, D_i)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_sessions: list of subject_session identifier strings, one
                per datapoint across all tensors in x.
            output_as_list: if True, outputs a list of TokenizedBatchedItem,
                one per input tensor; if False, merges them all into a single
                long sequence.
            add_spatial_encoding_to_tokens: bool. Adds spatial encoding to tokens.

        Returns:
            TokenizedBatchedItem if output_as_list is False, else a list of
            TokenizedBatchedItem objects.
        """
        passed_datapoints = 0
        tokenized_items_list = []

        for x_item in x:
            tokenized_item = self._tokenize_for_batch_tensor(
                x_item,
                subject_sessions[passed_datapoints],
                add_spatial_encoding_to_tokens=add_spatial_encoding_to_tokens,
            )

            tokenized_items_list.append(tokenized_item)
            passed_datapoints += x_item.shape[0]

        if output_as_list:
            return tokenized_items_list

        return TokenizedBatchedItem.get_as_one_sequence(tokenized_items_list)
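

# Hedged usage sketch (comments only; the config values, the `metadata`
# object, and the session name are assumptions for illustration). Every key
# below is one this class actually reads:
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.create({
#       "samp_frequency": 250, "num_seconds": 2,
#       "temporal_subsegment_len": 100, "temporal_subsegment_step": 50,
#       "d_hidden": 128, "add_spatial_encoding": True,
#       "spatial_grouping": "default",
#       "temporal_encoder": {...},  # remaining TSEncoder2D kwargs
#   })
#   tokenizer = Tokenizer(config=cfg, metadata=metadata)
#   item = tokenizer(
#       [torch.randn(4, 500, 8)],
#       subject_sessions=["sub01_sess01"] * 4,  # one per datapoint
#   )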