from abc import ABC, abstractmethod
import einops
import torch
import torch.nn as nn
from typing import Optional


class SpatialEncoderMeta:
    def __init__(self, subject_session_spatial_groups=None):
        """Metadata object with subject session information for spatial encoding."""
        self.subject_session_spatial_groups = subject_session_spatial_groups
        
    @property
    def num_region_info(self):
        n_effective_components_across_sessions = set(
            [a.n_effective_components for a in self.subject_session_spatial_groups.values()]
        )

        assert len(n_effective_components_across_sessions) == 1, (
            "Variable numbers of effective components across subject_sessions are not supported."
        )

        self._num_region_info = n_effective_components_across_sessions.pop()            
        return self._num_region_info

    @property
    def embedding_table_configs(self):
        configs = {}
        for i in range(self.num_region_info):
            n_embeddings_for_components_set = set(
                [a.max_elements_for_component[i] for a in self.subject_session_spatial_groups.values()]
            )
            padding_indices_set = set(
                [a.padding_indices[i] for a in self.subject_session_spatial_groups.values()]
            )
                                                
            assert len(n_embeddings_for_components_set) == 1, (
                "Variable embedding-table sizes across subject_sessions are not supported; "
                "change this to take the max across subject_sessions if the distinction is not important."
            )
            assert len(padding_indices_set) == 1, (
                "Variable padding indices across subject_sessions are not supported; "
                "change this to take the max across subject_sessions if the distinction is not important."
            )
            
            configs[i] = {
                'num_embeddings': n_embeddings_for_components_set.pop(),
                'padding_idx': padding_indices_set.pop()
            }

        return configs


class BaseSpatialEncoder(ABC, nn.Module):
    """Abstract class definition for spatial encoding modules.

    Implement this interface to try new spatial encoding approaches in the tokenizer.
    """
    _SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"

    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
    ):
        super().__init__()
        self.dim_h = dim_h
        self.spatial_encoder_meta = spatial_encoder_meta

        self._construct_region_encoding_meta()

    def _construct_region_encoding_meta(self):
        """Constructs a hashmap of channel region information -> query vector for spatial encoding."""
        for (
            subject_session,
            spatial_groups,
        ) in self.spatial_encoder_meta.subject_session_spatial_groups.items():
            query_vector = torch.tensor(
                [tuple(map(int, e[:spatial_groups.n_effective_components])) for e in spatial_groups.group_components]
            )
            
            query_vector = self._transform_query_vector(query_vector)
            
            self.register_buffer(
                BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session),
                query_vector, persistent=False
            )

    def _transform_query_vector(self, query_vector: torch.Tensor):
        return query_vector

    def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
        return self._buffers[BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)].to(torch.long)

    def update_for_new_sessions(self, new_subject_session_spatial_groups):
        """Re-register query vectors for a new set of subject_sessions.

        Returns the names of any newly created parameters (none for the base class).
        """
        self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
        self._construct_region_encoding_meta()
        return []
    
    @abstractmethod
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        pass

    @abstractmethod
    def _get_position_encoding(
        self, x: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
        pass

    def forward(
        self,
        x: torch.Tensor,
        subject_session: str,
        timepoints: int = 1,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: torch.Tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
            subject_session: Key selecting the cached per-session spatial query vector.
            timepoints: Number of timepoints T over which the R per-region encodings are tiled.
            mask: Optional mask over the T*R token dimension; if given, only the selected
                positions are returned (x must be masked consistently).

        Returns:
            A torch.Tensor of shape (B, T*R, D) that is the encoding corresponding to
                the input token x.
        """
        session_PE = self._get_position_encoding(x, subject_session)
        assert (
            x.shape[-1] == session_PE.shape[-1]
        ), f"Region dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."

        position_encoding = einops.repeat(
            session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
        )

        if mask is not None:
            position_encoding = position_encoding[:, mask, :]

        assert (
            x.shape == position_encoding.shape
        ), "Output position encoding does not match in shape"
        return position_encoding


class EmbeddingTable(BaseSpatialEncoder):
    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
        embedding_max_dim: Optional[float] = None,
        embedding_init_scale: float = 1.0
    ):
        """A lookup table of different embeddings for different spatial fields."""
        super().__init__(dim_h, spatial_encoder_meta)

        # Create the embeddings.
        self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        subcomponent_dims = self._get_subcomponent_dims()

        self.subcomponent_embeddings = nn.ModuleDict()
        for (
            subcomponent_ind,
            subcomponent_config,
        ) in self.subcomponent_embedding_info.items():
            subcomponent_dim = subcomponent_dims[subcomponent_ind]

            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
                max_norm=embedding_max_dim,
            )

            self.init_weights_for_embeddings(
                self.subcomponent_embeddings[str(subcomponent_ind)],
                embedding_init_scale
            )

    @abstractmethod
    def _get_subcomponent_dims(self):
        raise NotImplementedError

    def update_for_new_sessions(self, new_subject_session_spatial_groups):
        """Add any newly required embedding table rows based on new subject session information."""
        new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)
        
        subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
            prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
            n_rows, subcomponent_dim = prev_embeddings.weight.shape
            
            if subcomponent_config['num_embeddings'] == n_rows:
                # no need to add any new embedding
                continue
            
            new_embeddings = torch.empty(
                subcomponent_config['num_embeddings'] - n_rows,
                subcomponent_dim,
                device=prev_embeddings.weight.device
            )
            nn.init.normal_(new_embeddings)

            new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))

            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
            )
            self.subcomponent_embeddings[str(subcomponent_ind)].weight.data = new_data

            new_params.extend([n for n, _ in self.named_parameters()])

        return new_params
    
    def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
        nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
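        # Re-zero the padding row, which the normal init above overwrote.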
        embedding_table._fill_padding_idx_with_zero()

    def _transform_query_vector(self, query_vector: torch.Tensor):
        return query_vector.to(torch.float).T

    def _get_position_encoding(
        self, _: torch.Tensor, subject_session: str
    ) -> torch.Tensor:
        """Returns the encoding vector based on a subject session query."""
        session_region_query = self.get_embedding_table_query_vector(
            subject_session
        )
        single_session_PE = self._encode(session_region_query)
        return single_session_PE


class EmbeddingTablePool(EmbeddingTable):
    """Embedding-table encoder that sums (pools) the embeddings of all spatial sub-components."""

    def _get_subcomponent_dims(self):
        # Every sub-component uses the full model dimension because the per-component
        # embeddings are summed in ``_encode``.
        return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}
    
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: torch.Tensor of shape (C, R). One row of embedding-table indices per
                spatial sub-component C, for each of the R regions.

        Returns:
            A torch.Tensor of shape (R, dim_h) holding the pooled encoding for each
                region. If a token has multiple spatial fields (e.g., x, y, z LPI
                coordinates), the encodings for those fields are summed together before
                being returned.
        """
        PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.device)
        for subcomponent_ind in range(x.shape[0]):
            subcomponent_x = x[subcomponent_ind, ...]
            PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
                str(subcomponent_ind)
            ](subcomponent_x)
        return torch.sum(PE, dim=0)


def create_spatial_encoder(
    dim_h: int,
    subject_session_spatial_groups=None,
    embedding_max_dim: Optional[float] = None,
    embedding_init_scale: float = 1.0,
) -> BaseSpatialEncoder:
    """Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
    spatial_encoder_meta = SpatialEncoderMeta(
        subject_session_spatial_groups
    )

    spatial_encoder = EmbeddingTablePool(
        dim_h,
        spatial_encoder_meta,
        embedding_max_dim,
        embedding_init_scale
    )

    return spatial_encoder
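

# ---------------------------------------------------------------------------
# Minimal illustrative usage (a sketch, not part of the library API). The real
# spatial-groups objects come from elsewhere in the codebase; ``_ToySpatialGroups``
# below is a hypothetical stand-in that only provides the attributes this module
# reads: ``n_effective_components``, ``max_elements_for_component``,
# ``padding_indices`` and ``group_components``.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass, field

    @dataclass
    class _ToySpatialGroups:
        # Two spatial sub-components per region (e.g. a coarse and a fine index).
        n_effective_components: int = 2
        # Embedding-table size per sub-component (last row reserved for padding).
        max_elements_for_component: tuple = (8, 16)
        # Padding row per sub-component.
        padding_indices: tuple = (7, 15)
        # One entry per region: the sub-component indices describing that region.
        group_components: list = field(default_factory=lambda: [(0, 3), (1, 5), (2, 9)])

    groups = {"subject01_session01": _ToySpatialGroups()}
    encoder = create_spatial_encoder(dim_h=32, subject_session_spatial_groups=groups)

    # Batch of 2, four timepoints, three regions, token dimension 32.
    tokens = torch.randn(2, 4 * 3, 32)
    position_encoding = encoder(tokens, "subject01_session01", timepoints=4)
    print(position_encoding.shape)  # torch.Size([2, 12, 32])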