from typing import Dict, List, Union

import einops
import torch
import torch.nn as nn
from omegaconf import DictConfig

import barista.models.spatial_encoder as spe
from barista.data.metadata import Metadata
from barista.models.mlp import MLP
from barista.models.tokenized_batched_item import TokenizedBatchedItem
from barista.models.TSEncoder2D import TSEncoder2D


class Tokenizer(nn.Module):
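    """Tokenizes batches of multichannel time-series recordings.

    Each (B, N, D) tensor is cut into overlapping temporal subsegments, encoded
    per channel by a TSEncoder2D, pooled to the hidden width with an MLP, and
    optionally tagged with a per-subject_session spatial encoding.
    """
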
    def __init__(
        self,
        config: DictConfig,
        metadata: Metadata,
    ):
        super().__init__()

        self.metadata = metadata
        self.config = config

        self.subjects = metadata.get_subjects()

        self.num_subsegments = int(
            (
                self.config.samp_frequency * self.config.num_seconds
                - self.config.temporal_subsegment_len
            )
            // (self.config.temporal_subsegment_step)
            + 1
        )
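        # The line above counts sliding windows: floor((total - len) / step) + 1.
        # With illustrative values (samp_frequency=256, num_seconds=4,
        # temporal_subsegment_len=128, temporal_subsegment_step=64) this gives
        # (1024 - 128) // 64 + 1 = 15 subsegments.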

        self.dim_h = self.config.d_hidden

        self._build_temporal_encoder()

        self._build_temporal_pooler()

        self._build_spatial_encoder()

    def _build_temporal_encoder(self):
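        # The encoder operates on one channel at a time: _tokenize_for_batch_tensor
        # flattens tokens to shape (1, 1, B * T * D, N) before calling it, so the
        # input and output dims are pinned to 1.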
        self.config.temporal_encoder.input_dims = 1
        self.config.temporal_encoder.output_dims = 1
        self.temporal_encoder = TSEncoder2D(**self.config.temporal_encoder)

    def _build_temporal_pooler(self):
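        # Maps each subsegment's time axis (temporal_subsegment_len samples) to
        # the model width d_hidden with a fully connected block.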
        self.temporal_pooler = MLP(
            d_input=self.config.temporal_subsegment_len,
            d_out=self.dim_h,
            dropout=0.0,
            bias=False,
        )

    def _build_spatial_encoder(self):
        self.subject_session_spatial_groups = {}
        for sub_sesh in self.metadata.get_subject_session_d_input().keys():
            spatial_grouping = self.metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.spatial_encoder = spe.create_spatial_encoder(
            dim_h=self.dim_h,
            subject_session_spatial_groups=self.subject_session_spatial_groups,
            embedding_max_dim=self.config.get('embedding_max_dim', None),
            embedding_init_scale=self.config.get('embedding_init_scale', 1.0),
        )

    def update_for_new_sessions(
        self,
        new_session_d_input_dict: Dict[str, int],
        new_metadata: Metadata,
    ) -> List[str]:
        """Rebuild the spatial groupings for a new set of subject_sessions and
        return the names of any newly created spatial-encoder parameters.

        The returned names are prefixed with "spatial_encoder." so that they
        line up with self.named_parameters().
        """
        self.subject_session_spatial_groups = {}
        for sub_sesh in new_session_d_input_dict.keys():
            spatial_grouping = new_metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.metadata = new_metadata

        new_params = []
        if self.config.add_spatial_encoding:
            new_se_params = self.spatial_encoder.update_for_new_sessions(
                new_subject_session_spatial_groups=self.subject_session_spatial_groups
            )
            new_params.extend([f"spatial_encoder.{n}" for n in new_se_params])

        return new_params

    def _tokenize_for_batch_tensor(
        self,
        x: torch.Tensor,
        subject_session: str,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> TokenizedBatchedItem:
        """
        Args:
            x: Input tensor of shape (B, N, D)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_session: subject_session identifier for this batch.
            add_spatial_encoding_to_tokens: if True, adds the spatial encoding to the tokens.

        Returns:
            Tokenized version of the same data as a TokenizedBatchedItem object.
        """
        batch_size, num_timepoints, num_channels = x.shape

        x = einops.rearrange(x, "b n d -> b d n")

        # NOTE: unfold does not copy memory, so when step < size (sliding window)
        # the overlapping elements are shared; mutating one changes every
        # occurrence of that element in the patches.
        x = x.unfold(
            dimension=-1,
            size=self.config.temporal_subsegment_len,
            step=self.config.temporal_subsegment_step,
        )  # (B D num_subsegments subseg_len)
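        # e.g. D=64 channels and N=1024 samples with len=128, step=64:
        # (B, 64, 1024) -> (B, 64, 15, 128).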

        collapsed_x = einops.rearrange(
            x, "b d t n -> (b t d) n"
        )  # (B * T * D, N)

        transposed_tokens = einops.rearrange(
            collapsed_x, "btd n -> 1 1 btd n"
        )  # (1, 1, B * T * D, N)

        collapsed_tokens = self.temporal_encoder(transposed_tokens)
        # Squeeze only the two leading unit axes; a bare .squeeze() could also
        # drop the token or time axis whenever one of them happens to be 1.
        collapsed_tokens = collapsed_tokens.squeeze(0).squeeze(0)  # (B * T * D, N)

        # "Time" dimension to hidden dimension. Using a fully connected layer here.
        collapsed_tokens = self.temporal_pooler(
            collapsed_tokens
        )  # (B * T * D, N) -> (B * T * D, HID_D)

        # Create the time-space interleaved tokens: within each timestep the
        # channel index varies fastest.
        tokens = einops.rearrange(
            collapsed_tokens,
            "(b t d) dh -> b (t d) dh",
            b=batch_size,
            t=self.num_subsegments,
        )  # (B, T * D, HID_D)

        seqlen_timepoints = self.num_subsegments

        if self.config.add_spatial_encoding:
            spatial_encoding = self.spatial_encoder(
                tokens,
                subject_session=subject_session,
                timepoints=seqlen_timepoints,
            )

            # Make sure regions at different timestamps have the same spatial encoding
            assert (
                seqlen_timepoints == 1
                or spatial_encoding[0, 0, 0] == spatial_encoding[0, num_channels, 0]
            )

            if add_spatial_encoding_to_tokens:
                tokens = tokens + spatial_encoding

        else:  # not self.config.add_spatial_encoding
            spatial_encoding = None

        temporal_group_ids = torch.arange(seqlen_timepoints, device=x.device)
        temporal_group_ids = einops.repeat(
            temporal_group_ids,
            "t -> b (t d)",
            b=batch_size,
            d=num_channels
        )
        # Make sure different regions at the same timestamp share the same positional encoding
        assert seqlen_timepoints == 1 or (
            temporal_group_ids[0, 0] == temporal_group_ids[0, 1]
            and temporal_group_ids[0, 0] != temporal_group_ids[0, num_channels]
        )

        position_ids = temporal_group_ids.clone()

        return TokenizedBatchedItem(
            tokens=tokens,
            position_ids=position_ids,
            spatial_group_ids=None,
            temporal_group_ids=temporal_group_ids,
            seq_lens=[tokens.shape[1]],
            spatial_embeddings=spatial_encoding,
            subject_sessions=[subject_session]
        )

    def forward(
        self,
        x: List,
        subject_sessions: List,
        output_as_list: bool = False,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> Union[TokenizedBatchedItem, List[TokenizedBatchedItem]]:
        """
        Args:
            x: A list of tensors each of shape (B_i, N_i, D_i)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_sessions: list of subject_session identifier strings, one per datapoint
            output_as_list: if True, outputs a list of TokenizedBatchedItem objects, one per entry in x;
                            if False, merges them all into a single long sequence
            add_spatial_encoding_to_tokens: bool. Adds the spatial encoding to the tokens.

        Returns:
            TokenizedBatchedItem if output_as_list is False, else a list of TokenizedBatchedItem objects.
        """
        passed_datapoints = 0
        tokenized_items_list = []

        for x_item in x:
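            # subject_sessions holds one entry per datapoint, so the cumulative
            # datapoint count picks the session id for the current tensor.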
            tokenized_item = self._tokenize_for_batch_tensor(
                x_item,
                subject_sessions[passed_datapoints],
                add_spatial_encoding_to_tokens=add_spatial_encoding_to_tokens,
            )

            tokenized_items_list.append(tokenized_item)
            passed_datapoints += x_item.shape[0]

        if output_as_list:
            return tokenized_items_list

        return TokenizedBatchedItem.get_as_one_sequence(tokenized_items_list)
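

# Example usage (a minimal sketch, not part of the module: config keys mirror the
# ones read above, `metadata` stands in for a real barista Metadata object, and all
# numeric values are illustrative):
#
#     config = DictConfig({
#         "samp_frequency": 256,
#         "num_seconds": 4,
#         "temporal_subsegment_len": 128,
#         "temporal_subsegment_step": 64,
#         "d_hidden": 512,
#         "spatial_grouping": "regions",
#         "add_spatial_encoding": True,
#         "temporal_encoder": {...},  # remaining TSEncoder2D kwargs
#     })
#     tokenizer = Tokenizer(config=config, metadata=metadata)
#
#     # One (B, N, D) tensor per subject_session; subject_sessions has one id per datapoint.
#     x = [torch.randn(2, 1024, 64)]
#     item = tokenizer(x, subject_sessions=["sub01_sess01"] * 2)
#     # item.tokens: (2, 15 * 64, 512), since (1024 - 128) // 64 + 1 = 15 subsegments.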