LeoChen085 committed on
Commit
0c7f3e3
·
verified ·
1 Parent(s): 17ea1ee

Upload SLIP model, checkpoints, and source code

Browse files
caption.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff90d912d788314a8d449b4a764c7ac52ca044c0702db303bfce094869d33623
3
+ size 1386043740
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "slip",
3
+ "architectures": [
4
+ "SLIP"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModel": "modeling_slip.SLIPModel"
8
+ },
9
+ "llm_model_name": "google/gemma-3-270m",
10
+ "max_llm_len": 768,
11
+ "num_img_queries": 64,
12
+ "num_heads": 5,
13
+ "caption_loss_weight": 1.0,
14
+ "contrastive_loss_weight": 1.0,
15
+ "use_lora": false,
16
+ "unlocked_layers": 4,
17
+ "split_layer": 12,
18
+ "common_dim": 640,
19
+ "post_train": true,
20
+ "sensor_encoder": {
21
+ "embed_dim": 768,
22
+ "num_heads": 12,
23
+ "mlp_ratio": 4,
24
+ "depth": 12,
25
+ "dropout_rate": 0.1,
26
+ "learnable_pos_emb": false,
27
+ "max_position_embeddings": 4880,
28
+ "patch_size": null,
29
+ "channel_attn_type": "all_attn"
30
+ }
31
+ }
ecg.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03d627bf5b7a4d0ce61803baa1726abe4dbed3bf6b9bf2c3f48d8f9eed060c37
3
+ size 1499488484
har.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bed649ee52aaa13efd27c922a544181181ff27362147c743c9d59d5e39974c7d
3
+ size 1386043740
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ceb7c446945dd61ddab80c82a0688e798e71f8840f1bba6b79c47dba0ae2ec5
3
+ size 1386043740
model_factory/SLIP.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reference: https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py
2
+
3
+ import math
4
+
5
+ from sympy import shape
6
+ from omegaconf import DictConfig
7
+ import torch
8
+ torch._dynamo.config.capture_scalar_outputs = True
9
+ from torch import Tensor, einsum, nn
10
+ import torch.nn.functional as F
11
+ from torch.autograd import Function
12
+ import torch.distributed as dist
13
+ from einops import rearrange, repeat,reduce
14
+ from model_factory.multimodal_gemma import Gemma3MultimodalModel
15
+ import hydra
16
+ # for generation
17
+ from typing import Optional, List, Union
18
+ import contextlib
19
+ from transformers.generation.utils import GenerationMixin
20
+ from model_factory.ts_transformer import AttentionPooling
21
+
22
+ # helper functions
23
+
24
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
26
+
27
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
29
+
30
def masked_mean(t, mask, dim = 1, eps = 1e-6):
    '''
    Mean of t over `dim`, counting only positions where mask is True.
    t: B, L, D
    mask: B, L, 1
    eps guards against division by zero for all-False rows.
    '''
    zeroed = t.masked_fill(~mask, 0.)
    total = zeroed.sum(dim=dim)
    count = mask.sum(dim=dim).clamp(min=eps)
    return total / count
39
+
40
+ # helper metric: https://arxiv.org/pdf/2005.10242
41
def lalign(x, y, alpha=2):
    # Alignment metric: mean of ||x - y||^alpha over positive pairs
    # (lower = positives closer together).
    pair_dist = torch.linalg.vector_norm(x - y, dim=1)
    return pair_dist.pow(alpha).mean()
44
+
45
def lunif(x, t=2):
    # Uniformity metric: log of the mean Gaussian potential over all pairs
    # within one side of the batch (lower = more uniform on the sphere).
    pairwise_sq = torch.pdist(x, p=2).pow(2)
    return torch.log(torch.exp(pairwise_sq * (-t)).mean())
49
+
50
+ # distributed
51
def pad_dim_to(t, length, dim = 0):
    """Zero-pad t at the end of dimension `dim` so t.shape[dim] == length."""
    amount = length - t.shape[dim]
    # F.pad's spec runs from the LAST dim backwards, so we need one (0, 0)
    # pair for every dimension after `dim`.
    trailing = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    return F.pad(t, (0, 0) * trailing + (0, amount))
55
+
56
+ # https://huggingface.co/Qwen/Qwen3-Embedding-8B
57
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pool a sequence to each sample's final *attended* token embedding.

    last_hidden_states: (B, L, D)
    attention_mask: (B, L) with 1 for real tokens, 0 for padding.
    """
    batch = last_hidden_states.shape[0]
    # If every row attends at the final position, padding must be on the left,
    # so the last column already holds the final real token for all samples.
    if attention_mask[:, -1].sum() == batch:
        return last_hidden_states[:, -1]
    # Right padding: index each row at its own last attended position.
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(batch, device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
66
+
67
def all_gather_variable_batch(x):
    """
    All-gather variable sized tensors across DDP ranks.
    x: [B_local, D]
    Returns:
        out: [sum(B_local across ranks), D]
        sizes: python list of sizes per rank
    """
    world = dist.get_world_size()
    device = x.device

    # 1. Gather per-rank batch sizes so every rank knows how much to trim later.
    local_size = torch.tensor([x.shape[0]], device=device, dtype=torch.long)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world)]
    dist.all_gather(all_sizes, local_size)
    sizes = [int(s.item()) for s in all_sizes]

    # 2. Pad the local tensor to the max size (all_gather needs equal shapes).
    #    BUGFIX: the original compared the shape-[1] size *tensor* against an
    #    int and passed the tensor difference to torch.zeros(), which only
    #    accepts plain ints as sizes. Work with python ints instead.
    local_count = x.shape[0]
    max_size = max(sizes)
    if local_count < max_size:
        padding = torch.zeros(max_size - local_count, *x.shape[1:],
                              device=device, dtype=x.dtype)
        x_padded = torch.cat([x, padding], dim=0)
    else:
        x_padded = x

    # 3. All-gather the equal-shaped padded tensors.
    gathered = [torch.zeros_like(x_padded) for _ in range(world)]
    dist.all_gather(gathered, x_padded)

    # 4. Trim each rank's slice back to its true batch size.
    trimmed = [g[:sizes[i]] for i, g in enumerate(gathered)]

    # 5. Concatenate into the true global batch.
    out = torch.cat(trimmed, dim=0)
    return out, sizes
104
+
105
class AllGather(Function):
    """Autograd-aware all-gather: forward concatenates per-rank batches,
    backward routes each rank only the gradient slice for the batch it
    contributed."""

    @staticmethod
    def forward(ctx, x):
        # Only meaningful in an initialized multi-process group.
        assert dist.is_initialized() and dist.get_world_size() > 1
        x, batch_sizes = all_gather_variable_batch(x)
        # Remember per-rank sizes so backward can split the gradient.
        ctx.batch_sizes = batch_sizes
        return x

    @staticmethod
    def backward(ctx, grads):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        # Split the global gradient by the same per-rank sizes used in forward
        # and keep only this rank's slice.
        grads_by_rank = grads.split(batch_sizes, dim = 0)
        return grads_by_rank[rank]

# Functional alias used by the gather-with-grad path.
all_gather = AllGather.apply
120
+
121
+
122
+ # to latents
123
class EmbedToLatents(nn.Module):
    """Linear projection into the shared latent space, L2-normalized."""

    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        projected = self.to_latents(x)
        return F.normalize(projected, dim=-1)
131
+
132
+
133
+
134
class SLIP(nn.Module,GenerationMixin):
    """Sensor-Language pretraining model: a sensor encoder bridged into a
    Gemma-3 multimodal decoder, trained with a contrastive (CLIP or SigLIP)
    objective plus an optional captioning loss."""
    _is_stateful = False

    def __init__(
        self,
        tokenizer=None, #legacy argument.
        **kwargs
    ):
        super().__init__()

        self.tokenizer = tokenizer
        # Sensor tower: either an already-built module or a hydra config node.
        enc_cfg = kwargs['sensor_encoder_cfg']
        if isinstance(enc_cfg, (DictConfig, dict)):
            self.sensor_encoder = hydra.utils.instantiate(enc_cfg)
        else:
            self.sensor_encoder = enc_cfg

        ############################################################
        dim = self.sensor_encoder.embed_dim  # 384
        text_encoder = kwargs['llm_model_name']
        self.embed_dim = dim
        self.use_lora = kwargs.get('use_lora', True)
        self.post_train = kwargs.get('post_train', True)
        # BUGFIX: use_sig_loss was read below (and in forward()) but never
        # assigned, raising AttributeError. Default to the CLIP-loss path.
        self.use_sig_loss = kwargs.get('use_sig_loss', False)
        ##########################################

        ## Text encoder ####
        self.caption_loss_weight = kwargs['caption_loss_weight']
        self.max_llm_len = kwargs['max_llm_len']
        self.multimodalModel = Gemma3MultimodalModel(text_encoder,self.post_train)

        if self.caption_loss_weight <= 0:
            # No captioning objective: drop the multimodal upper layers.
            self.multimodalModel._truncate_to_unimodal()

        unlocked_layers = kwargs.get('unlocked_layers', 0)
        if unlocked_layers < 12:  # 12 is the split layer
            self.multimodalModel._lock_text(
                unlocked_layers=unlocked_layers,
                freeze_layer_norm=kwargs.get('freeze_layer_norm', True)
            )

        lm_dim = self.multimodalModel.hidden_size  # 640
        self.lm_dim = lm_dim
        common_dim = lm_dim  # hardcoded for now
        #########################################

        # Optional learned-query attention pooling over sensor tokens;
        # the extra "+ 1" query acts as a CLS token for the contrastive loss.
        num_img_queries = kwargs.get('num_img_queries', 0)
        if num_img_queries > 0:
            self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, common_dim))
            self.img_attn_pool = AttentionPooling(
                dim=common_dim,
                context_dim=dim,
                num_heads=kwargs['num_heads'])  # pre-norm+post_norm
            dim = common_dim

        # Projection heads into the shared, normalized latent space.
        self.img_to_latents = EmbedToLatents(dim, common_dim)
        self.text_to_latents = EmbedToLatents(common_dim, common_dim)

        # learnable temperature (CLIP init: log(1/0.07))
        self.temperature = nn.Parameter(torch.tensor(math.log(1/0.07)))
        self.temperature_max = math.log(1/0.07)
        if self.use_sig_loss:
            # SigLIP defaults: t' = log(10), bias = -10, no temperature cap.
            self.temperature = nn.Parameter(torch.tensor(math.log(10)))
            self.temperature_max = 999  # trivially large, so no upper bound.
            self.logit_bias = nn.Parameter(torch.ones([]) * -10)

        # multimodal decoder #############
        pad_token_id = self.tokenizer.pad_token_id
        self.pad_token_id = pad_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        # Captioning CE ignores padding positions.
        self.ce = nn.CrossEntropyLoss(ignore_index=pad_token_id)
        self.contrastive_loss_weight = kwargs['contrastive_loss_weight']
        ##################################

        self._init_weights()
        # whether in data parallel setting
        self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
221
+
222
+
223
+ def _init_weights(self):
224
+ def _init(m):
225
+ if isinstance(m, nn.Linear):
226
+ nn.init.xavier_uniform_(m.weight)
227
+ if m.bias is not None:
228
+ nn.init.constant_(m.bias, 0)
229
+ elif isinstance(m, nn.LayerNorm):
230
+ nn.init.constant_(m.bias, 0)
231
+ nn.init.constant_(m.weight, 1.0)
232
+
233
+ # apply only to modules we added
234
+ self.img_to_latents.apply(_init)
235
+ self.text_to_latents.apply(_init)
236
+
237
+ if hasattr(self, 'img_attn_pool'):
238
+ self.img_attn_pool.apply(_init)
239
+ nn.init.xavier_uniform_(self.img_queries)
240
+
241
+ def get_lora_parameters(self): # make training script happy
242
+ """
243
+ Gathers:
244
+ 1. LoRA weights (A and B matrices) inside Gemma.
245
+ 2. Full-parameter updated 'modules_to_save' (Embeddings/Head).
246
+ 3. Full-parameter updated Cross-Attention blocks.
247
+ 4. Bridge layers (img_to_latents, text_to_latents, etc.).
248
+ """
249
+ if not self.use_lora:
250
+ return []
251
+
252
+ trainable_params = []
253
+
254
+ # 1. Check the multimodal LLM (Gemma + LoRA + Cross-Attn)
255
+ for name, param in self.multimodalModel.named_parameters():
256
+ if param.requires_grad:
257
+ trainable_params.append(param)
258
+
259
+ # 2. Check the Bridge modules
260
+ bridge_modules = [self.img_to_latents, self.text_to_latents]
261
+ if hasattr(self, 'img_attn_pool'):
262
+ bridge_modules.append(self.img_attn_pool)
263
+
264
+ for module in bridge_modules:
265
+ for param in module.parameters():
266
+ if param.requires_grad:
267
+ trainable_params.append(param)
268
+
269
+ # 3. Check the Queries and Sensor Encoder
270
+ if hasattr(self, 'img_queries') and self.img_queries.requires_grad:
271
+ trainable_params.append(self.img_queries)
272
+
273
+ # Optionally add sensor_encoder if you haven't locked it
274
+ for param in self.sensor_encoder.parameters():
275
+ if param.requires_grad:
276
+ trainable_params.append(param)
277
+
278
+ return trainable_params
279
+
280
+ def _pad_to_len(self, x, max_len):
281
+ # pad along dim 1 to max_len with zeros
282
+ if x.dim() == 3:
283
+ # [B, L, D]
284
+ pad_len = max_len - x.size(1)
285
+ if pad_len > 0:
286
+ pad = x.new_zeros(x.size(0), pad_len, x.size(2))
287
+ x = torch.cat([pad, x], dim=1)
288
+
289
+ elif x.dim() == 2:
290
+ # [B, L] case such as masks
291
+ pad_len = max_len - x.size(1)
292
+ if pad_len > 0:
293
+ pad = x.new_zeros(x.size(0), pad_len)
294
+ x = torch.cat([pad, x], dim=1)
295
+ return x
296
+
297
    def _gather_features(self, img, txt, gather_with_grad=False):
        """Return all features if DDP, else inputs. Same batch size per rank assumed.

        Both tensors are first zero-padded along dim 1 to the max length across
        ranks, because torch.distributed.all_gather requires identical shapes.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return img, txt

        ### prepare for gathering ###
        # get max length across ranks for padding.
        img_len = torch.tensor([img.size(1)], device=img.device, dtype=torch.long)
        txt_len = torch.tensor([txt.size(1)], device=txt.device, dtype=torch.long)
        dist.all_reduce(img_len, op=dist.ReduceOp.MAX)
        dist.all_reduce(txt_len, op=dist.ReduceOp.MAX)
        max_img_len = int(img_len.item())
        max_txt_len = int(txt_len.item())

        img = self._pad_to_len(img, max_img_len)
        txt = self._pad_to_len(txt, max_txt_len)
        #################################

        if gather_with_grad:
            # keep grad across ranks (autograd-aware AllGather)
            all_img = all_gather(img)
            all_txt = all_gather(txt)
        else:
            # no grad path, saves memory
            ws = dist.get_world_size()
            outs_i = [torch.empty_like(img) for _ in range(ws)]
            outs_t = [torch.empty_like(txt) for _ in range(ws)]

            try:
                dist.all_gather(outs_i, img.contiguous())
                dist.all_gather(outs_t, txt.contiguous())
            except Exception as e:
                print("Error occurred while gathering features:", e)

            # Re-insert the local tensors so gradients still flow through this
            # rank's own slice of the global batch.
            outs_i[dist.get_rank()] = img
            outs_t[dist.get_rank()] = txt
            all_img = torch.cat(outs_i, dim=0)
            all_txt = torch.cat(outs_t, dim=0)

        return all_img, all_txt
339
+
340
    def embed_text(self,
                   input_ids,
                   attention_mask,
                   text_embed=None):
        '''
        Embed tokenized text, or pass precomputed embeddings straight through.

        need to make this casual to avoid representation leak.

        text: (BS, llm_seq_len) token_ids
        attn_mask: (Bs, llm_seq_len)
        '''
        # NOTE(review): self.llm is never assigned in __init__ (the LM lives in
        # self.multimodalModel), so the input_ids branch below would raise
        # AttributeError -- confirm whether this path is still used.
        if text_embed is not None:
            hidden_states = text_embed  # (BS, max_seq_len, lm_dim)
        else:
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=False,  # Set to False or remove
                # use_cache=False # Ensure cache is off for training/gradient ckpt
            )
            hidden_states = outputs.last_hidden_state

        return hidden_states
366
+
367
+
368
+
369
    def embed_sensor(self, sensors, sensor_attn_mask=None, time_index=None):
        '''
        Encode raw sensor streams into token embeddings.

        sensors: (BS, num_channels, L)
        Returns:
            sensor_tokens: with attention pooling enabled, pooled onto the
                learned queries (CLS query first); otherwise the raw encoder
                tokens.
            attn_mask: (BS, nvar, num_p) boolean validity mask.
        '''
        sensor_tokens, attn_mask = self.sensor_encoder(sensors, sensor_attn_mask, time_index=time_index)
        # sensor_tokens: Bs,(nvar, num_p), img_dim
        # attn_mask: BS, nvar, num_p

        if hasattr(self, 'img_attn_pool'):
            # Cross-attend a fixed set of learned queries to the sensor tokens.
            img_queries = repeat(self.img_queries, 'n d -> b n d', b=sensor_tokens.shape[0])
            sensor_tokens = self.img_attn_pool(img_queries, sensor_tokens,attn_mask)

        return sensor_tokens, attn_mask.bool()
383
+
384
    # use an openCLIP implementation
    def forward_loss(self,
                     text_hidden,
                     sensor_hidden,
                     sensor_mask,
                     gather_with_grad=False):
        '''
        Contrastive (CLIP) loss plus alignment/uniformity diagnostics.

        text_hidden: (BS, dim) pooled text latents
        sensor_hidden: (BS, num_tokens, dim) sensor latents; index 0 is the
            CLS query when attention pooling is enabled
        sensor_mask: (BS, nvar, num_p)
        gather_with_grad: keep gradients through the cross-rank all-gather
        '''
        # global features: CLS token or masked mean over sensor tokens
        if hasattr(self, 'img_attn_pool'):
            # use cls token
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, 'b n p -> b (n p) 1'), dim=1)  # BS, img_dim

        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        world = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1

        # Gather the global batch so negatives span all ranks.
        if world > 1:
            all_img, all_txt = self._gather_features(sensor_hidden, text_hidden, gather_with_grad=gather_with_grad)
        else:
            all_img, all_txt = sensor_hidden, text_hidden

        contrastive_loss = self.CLIP_loss(all_txt, all_img)*self.contrastive_loss_weight

        # supplementary metrics (Wang & Isola, https://arxiv.org/pdf/2005.10242)
        align_loss = lalign(all_txt, all_img)
        unif_txt = lunif(all_txt)
        unif_img = lunif(all_img)

        outputs = {
            "loss": contrastive_loss,
            'contrastive_loss': contrastive_loss,
            "align_loss": align_loss,
            "unif_txt": unif_txt,
            "unif_img": unif_img,
        }

        return outputs
429
+
430
+
431
+ def CLIP_loss(
432
+ self,
433
+ text_cls,
434
+ sensor_cls,):
435
+
436
+ temperature = (self.temperature.clamp(max=self.temperature_max)).exp()
437
+ logits_t2i = temperature * (text_cls @ sensor_cls.t()) # [B_global, B_global]
438
+ targets = torch.arange(logits_t2i.size(0), device=sensor_cls.device)
439
+ contrastive_loss = 0.5 * (
440
+ F.cross_entropy(logits_t2i, targets) +
441
+ F.cross_entropy(logits_t2i.t(), targets)
442
+ )
443
+
444
+ return contrastive_loss
445
+
446
    def sig_loss(self, text_hidden, sensor_hidden, sensor_mask):
        '''
        SigLip Loss: Decoupling contrastive-loss with batch size
        text_hidden: (BS, dim)
        sensor_hidden: (BS, sensor_len, dim)
        sensor_mask: (BS, nvar, num_p)
        '''
        # NOTE(review): self._sig_loss is not defined anywhere in this file --
        # confirm it is supplied by a subclass/mixin before enabling use_sig_loss.
        if hasattr(self, 'img_attn_pool'):
            # use cls token
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, 'b n p -> b (n p) 1'), dim=1)  # BS, img_dim

        # SigLIP scale is uncapped in practice (temperature_max = 999).
        logit_scale = self.temperature.clamp(max=self.temperature_max).exp()
        loss = self._sig_loss(sensor_hidden, text_hidden, logit_scale, self.logit_bias)

        return {'loss': loss, 'contrastive_loss': loss}
466
+
467
+
468
    def forward(
        self,
        text,
        sensors,
        prompt=None,  # legacy input
        return_embeddings=False,
    ):
        """Joint contrastive + captioning training step.

        text: dict with 'input_ids' / 'attention_mask'
        sensors: dict with 'input_ids', 'attention_mask', 'time_index'
        Returns a dict of losses: 'loss', 'contrastive_loss', optionally
        'caption_loss', plus the diagnostics from the contrastive path.
        """
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=sensors['input_ids'],
                                                       sensor_attn_mask=sensors['attention_mask'],  # this is pixel-level mask
                                                       time_index=sensors['time_index'])

        # sensor_hidden: (BS, num_sensor_token, dim)
        # Cache sensor tokens on the cross-attention layers before the LM pass.
        self.multimodalModel.condition_image(sensor_hidden)
        # Teacher forcing: feed tokens [:-1], predict tokens [1:].
        text_hidden, logits = self.multimodalModel(input_ids=text['input_ids'][:,:-1],
                                                   attention_mask=text['attention_mask'][:,:-1], )
        # text_hidden: (BS, dim) unimodal sentence embedding at the split layer
        # logits: (BS, pred_len, vocab_size)

        labels = text['input_ids'][:,1:]  # bs, pred_len

        # Project both modalities into the shared normalized latent space.
        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)

        if self.use_sig_loss:
            loss_dict = self.sig_loss(text_hidden,
                                      sensor_hidden,
                                      sensor_mask)
        else:
            # This branch will need all-gather.
            loss_dict = self.forward_loss(text_hidden,
                                          sensor_hidden,
                                          sensor_mask,)

        if self.caption_loss_weight > 0:
            # Next-token cross-entropy over the flattened sequence.
            loss_logits = logits.reshape(-1, logits.size(-1))  # Shape: [BS * Seq, Vocab]
            loss_labels = labels.reshape(-1)  # Shape: [BS * Seq]
            caption_loss = self.ce(loss_logits, loss_labels) * self.caption_loss_weight

            loss_dict['caption_loss'] = caption_loss
            loss_dict['loss'] = loss_dict['contrastive_loss'] + caption_loss

        return loss_dict
515
+
516
+ def _lock_sensor(self,):
517
+ # Freeze all sensor-related parameters (cross-attn blocks)
518
+ for name, param in self.sensor_encoder.named_parameters():
519
+ param.requires_grad = False
520
+
521
    def sft_training(self,text,sensors,return_output=False):
        """Supervised fine-tuning step: causal LM loss over label tokens only,
        optionally weighted per-token via text['loss_weights'].

        text: dict with 'input_ids', 'attention_mask', 'labels' (-100 = ignore)
            and, optionally, 'loss_weights' shaped like labels
        sensors: dict with 'input_ids', 'attention_mask', 'time_index'
        return_output: if True, return the raw HF output instead of the loss
        """
        sensor_hidden, _ = self.embed_sensor(sensors=sensors['input_ids'],
                                             sensor_attn_mask=sensors['attention_mask'],
                                             time_index=sensors['time_index'])

        # sensor_hidden: (BS, num_sensor_token, dim)
        self.multimodalModel.condition_image(sensor_hidden)

        outputs = self.multimodalModel.model(input_ids=text['input_ids'],
                                             attention_mask=text['attention_mask'],
                                             return_dict=True,)
        if return_output:
            return outputs

        logits = outputs.logits  # (BS, pred_len, vocab_size)
        labels = text['labels']  # (BS, pred_len)
        # shift for causal lm: predict token t+1 from tokens <= t
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        # flatten logits for efficiency ("logss" is a historical typo)
        logss_logits = shift_logits.view(-1, shift_logits.size(-1))  # Shape: [BS * Seq, Vocab]
        loss_labels = shift_labels.view(-1)  # Shape: [BS * Seq]

        # per-token CE so optional weights can be applied before reduction
        ce = torch.nn.functional.cross_entropy(
            logss_logits,
            loss_labels,
            reduction='none',
            ignore_index=-100,
        )

        if 'loss_weights' in text:
            loss_weights = text['loss_weights']
            # shift weights in lockstep with the labels
            loss_weights = loss_weights[:,1:].contiguous()
            loss_weights = loss_weights.view(-1)  # Shape: [BS * Seq]

            # weighted mean over tokens
            weighted_ce = ce * loss_weights
            loss = weighted_ce.sum() / loss_weights.sum()
        else:
            loss = ce.mean()

        return {'loss': loss}
577
+
578
    def generate(self,
                 text,
                 sensors,
                 **generate_kwargs):
        """
        Generates text conditioned on image embeddings.

        NOTE(review): **generate_kwargs is accepted but not forwarded to
        model.generate -- decoding settings below are fixed (greedy, up to
        300 new tokens).
        """
        sensor_hidden, _ = self.embed_sensor(sensors=sensors['input_ids'],
                                             sensor_attn_mask=sensors['attention_mask'],  # this is pixel-level mask
                                             time_index=sensors['time_index'])

        # Cache sensor tokens on the cross-attention layers.
        self.multimodalModel.condition_image(sensor_hidden)

        generated_text = self.multimodalModel.model.generate(
            input_ids=text['input_ids'],
            attention_mask=text['attention_mask'],
            max_new_tokens=300,
            do_sample=False,
            num_beams=1,
            early_stopping=False,
        )

        return generated_text
602
+
603
+
604
    @ torch.no_grad()
    def get_embedding(self,text,sensors):
        """Inference helper: return L2-normalized (text, sensor) latents for
        retrieval-style evaluation. Both outputs are (BS, dim)."""
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=sensors['input_ids'],
                                                       sensor_attn_mask=sensors['attention_mask'],  # this is pixel-level mask
                                                       time_index=sensors['time_index'])

        self.multimodalModel.condition_image(sensor_hidden)
        # Same teacher-forced slicing as forward() so embeddings match training.
        text_hidden, _ = self.multimodalModel(input_ids=text['input_ids'][:,:-1],
                                              attention_mask=text['attention_mask'][:,:-1], )

        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)

        if hasattr(self, 'img_attn_pool'):
            # use cls token
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, 'b n p -> b (n p) 1'), dim=1)  # BS, img_dim

        return text_hidden, sensor_hidden
624
+
625
    @ torch.no_grad()
    def get_sensor_embedding(self,input_ids,mask,time_index):
        """Inference helper: return the (BS, dim) normalized sensor latent
        (CLS query if attention pooling is enabled, masked mean otherwise)."""
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=input_ids,
                                                       sensor_attn_mask=mask,
                                                       time_index=time_index)
        sensor_hidden = self.img_to_latents(sensor_hidden)

        if hasattr(self, 'img_attn_pool'):
            # use cls token
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, 'b n p -> b (n p) 1'), dim=1)  # BS, img_dim

        return sensor_hidden
639
+
640
    @ torch.no_grad()
    def get_text_embedding(self,text):
        """Text-only embedding path (no sensor conditioning).

        NOTE(review): relies on self.embed_text's input_ids branch and on
        self.llm, which is never assigned in __init__ in this file -- verify
        before using this path.
        """
        text_mask = text['attention_mask']
        text_hidden = self.embed_text(text['input_ids'],
                                      attention_mask=text_mask,)

        text_hidden = self.text_to_latents(text_hidden)

        # Pooling strategy follows the LM config.
        if self.llm.config.pooler == 'mean':
            text_hidden = masked_mean(text_hidden, rearrange(text_mask, 'b l -> b l 1').bool(), dim=1)  # BS, lm_dim
        else:
            text_hidden = last_token_pool(text_hidden, text_mask)  # (BS, dim)

        return text_hidden
654
+
655
    def get_multimodal_feature(self, question, sensors):
        """Return the final-layer hidden state at the last sequence position,
        fused with sensor context via cross-attention. Shape: (BS, dim)."""
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=sensors['input_ids'],
                                                       sensor_attn_mask=sensors['attention_mask'],  # this is pixel-level mask
                                                       time_index=sensors['time_index'])

        # sensor_hidden: (BS, num_sensor_token, dim)
        self.multimodalModel.condition_image(sensor_hidden)
        # return_embeddings=True yields the raw HF output (with hidden_states).
        outputs = self.multimodalModel(input_ids=question['input_ids'],
                                       attention_mask=question['attention_mask'],
                                       return_embeddings=True)
        multimodal_hidden = outputs.hidden_states[-1][:,-1,:]  # (BS, dim)

        return multimodal_hidden
670
+
671
+
672
+
673
+
674
class Config(dict):
    """Dict with attribute-style read access (cfg.key == cfg['key'])."""

    def __getattr__(self, key):
        # BUGFIX: raise AttributeError (not KeyError) for missing keys so that
        # hasattr()/getattr(obj, k, default) behave correctly on Config objects.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None
677
+
678
+
model_factory/__init__.py ADDED
File without changes
model_factory/multimodal_gemma.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+ from model_factory.ts_transformer import CrossAttention
7
+
8
class Residual(nn.Module):
    """Wrap a callable so the module computes x + fn(x) (skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
15
+
16
class Gemma3MultimodalLayer(nn.Module):
    """Wraps a Gemma decoder layer: runs the original layer, then applies a
    (residual) cross-attention block over cached visual/sensor context."""

    def __init__(self, original_layer, cross_attn_block):
        super().__init__()
        self.original_layer = original_layer
        self.cross_attn_block = cross_attn_block
        # Visual context; must be set per batch via condition_vis_x().
        self.vis_x = None

    def condition_vis_x(self, vis_x):
        # Cache the embeddings this layer will cross-attend to.
        self.vis_x = vis_x

    def __getattr__(self, name):
        """Forward all unknown attributes to the original layer."""
        # This is CRITICAL for 'attention_type' and other internal HF flags:
        # nn.Module.__getattr__ only resolves parameters/buffers/submodules,
        # so plain attributes of the wrapped layer must be proxied explicitly.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.original_layer, name)

    def forward(self, hidden_states, **kwargs):
        # 1. Run the original unimodal Gemma Layer (Self-Attn + MLP)
        # have to have self.vis_x
        assert self.vis_x is not None, "vis_x must be set before forward pass."

        outputs = self.original_layer(hidden_states, **kwargs)  # gemma layer output
        hidden_states = outputs[0]
        # 2. Fuse the cached visual context via cross-attention.
        hidden_states = self.cross_attn_block(hidden_states, context=self.vis_x)

        return (hidden_states,) + outputs[1:]  # make hf happy
44
+
45
+
46
class Gemma3MultimodalModel(nn.Module):
    """Gemma-3 causal LM whose upper decoder layers (from split_layer on) are
    wrapped with residual cross-attention over sensor/image embeddings."""

    def __init__(self,
                 model_id="google/gemma-3-270m",
                 post_train = True,
                 split_layer=12):
        super().__init__()
        # BUGFIX: the base checkpoint was previously loaded unconditionally AND
        # then loaded a second time inside the post_train branch, downloading
        # and allocating the model twice. Load exactly once.
        if post_train:
            # Load pre-trained weights
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=True
            )
        else:
            # INITIALIZE FROM SCRATCH
            print(f"Initializing {model_id} from SCRATCH (Random Weights)...")
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=True
            )

        self.split_layer = split_layer
        self.device = self.model.device

        # Insert a cross-attention block after each upper decoder layer.
        hidden_size = self.model.config.hidden_size  # 640
        num_heads = self.model.config.num_attention_heads
        self.hidden_size = hidden_size

        for i in range(split_layer, len(self.model.model.layers)):
            # Create the specific cross-attn block for this layer
            cross_attn = CrossAttention(
                dim=hidden_size,
                context_dim=hidden_size,
                num_heads=num_heads,
                dropout_rate=0.1
            )

            # Wrap the original layer so its output is fused with vision context
            original_layer = self.model.model.layers[i]
            self.model.model.layers[i] = Gemma3MultimodalLayer(
                original_layer,
                Residual(cross_attn)
            )

        self.to(torch.bfloat16)
+
105
+ def condition_image(self, image_embeds):
106
+ """Passes image embeddings (Bs, img_q, 640) to layers 12+"""
107
+ # Ensure we match the model's device and dtype
108
+ self.image_embeds = image_embeds.to(next(self.parameters()).device, dtype=torch.bfloat16)
109
+
110
+ for layer in self.model.model.layers:
111
+ if isinstance(layer, Gemma3MultimodalLayer):
112
+ layer.condition_vis_x(self.image_embeds)
113
+
114
    def forward(self,
                input_ids,
                attention_mask=None,
                return_embeddings=False,
                **kwargs):
        """Run the (conditioned) LM.

        Returns the raw HF output when return_embeddings is True; otherwise a
        tuple (text_sentence_embedding, logits), where the sentence embedding
        is taken at the last position of the hidden state feeding the first
        multimodal layer.
        """
        # HF Forward
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # Extraction for contrastive learning.
        # hidden_states[i] is the input to layer i, so index split_layer gives
        # the output of layer (split_layer - 1), i.e. the last unimodal layer.
        unimodal_hidden_states = outputs.hidden_states[self.split_layer]
        text_sentence_embedding = unimodal_hidden_states[:, -1, :]

        if return_embeddings:
            return outputs
        else:
            return text_sentence_embedding, outputs.logits
137
+
138
    def _lock_text(self,
                   unlocked_layers: int = 0,
                   freeze_layer_norm: bool = True):
        """
        Locks the unimodal encoder.
        unlocked_layers: How many unimodal layers (counting back from split_layer) to keep trainable.
        freeze_layer_norm: Whether to freeze Norm layers (RMSNorm/LayerNorm).
        """
        # 1. Ensure the Multimodal Decoder and Head are ALWAYS trainable
        #    (everything is unfrozen first, then the unimodal part re-frozen).
        for param in self.model.parameters():
            param.requires_grad = True

        # 2. Identify Unimodal components; the embedding table counts as one
        #    module alongside the decoder layers below the split.
        embeddings = self.model.model.embed_tokens
        unimodal_layer_list = self.model.model.layers[:self.split_layer]
        modules = [embeddings, *unimodal_layer_list]

        if unlocked_layers > 0:
            modules_to_freeze = modules[:-unlocked_layers]
        else:
            modules_to_freeze = modules

        first_unlocked_layer_idx = (len(modules) - unlocked_layers) - 1

        print(f"Locking {len(modules_to_freeze)} unimodal modules (Embeddings + Layers 0 to {first_unlocked_layer_idx - 1}).")
        print(f"Unimodal layers {max(0, first_unlocked_layer_idx)} to {self.split_layer - 1} remain trainable.")

        # 3. Perform Freezing; norm parameters may stay trainable depending on
        #    freeze_layer_norm.
        for module in modules_to_freeze:
            for n, p in module.named_parameters():
                is_norm = any(x in n.split(".") for x in ["norm", "LayerNorm", "input_layernorm", "post_attention_layernorm"])

                if is_norm:
                    p.requires_grad = not freeze_layer_norm
                else:
                    p.requires_grad = False
174
+
175
    def _truncate_to_unimodal(self):
        """
        Deletes all layers from split_layer onwards, keeping only the
        unimodal layers (0 to split_layer-1). Used when no captioning loss
        is trained, so the multimodal upper stack is dead weight.
        """
        # 1. Physically remove the layers (indices split_layer to end)
        # This deletes the Gemma3MultimodalLayer wrappers and their weights
        self.model.model.layers = nn.ModuleList(self.model.model.layers[:self.split_layer])

        # 2. Update the config so the model handles the new length correctly
        # (This ensures the final layer-norm and LM-head use the correct hidden state)
        self.model.config.num_hidden_layers = self.split_layer

        # 3. Cleanup image references so no stale context is kept alive
        if hasattr(self, 'image_embeds'):
            del self.image_embeds

        print(f"Multimodal layers deleted. Model truncated to {self.split_layer} layers.")
model_factory/ts_transformer.py ADDED
@@ -0,0 +1,809 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reference: https://huggingface.co/thuml/sundial-base-128m/blob/main/modeling_sundial.py
2
+
3
+ import contextlib
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Optional, Tuple, List, Union
8
+ from util.pos_embed import RotaryEmbedding, apply_rotary_pos_emb,apply_rotary_pos_emb_2d, build_2d_position_ids
9
+ from transformers.activations import ACT2FN
10
+ from einops import rearrange,reduce
11
+
12
class TsRoPEAttention(nn.Module):
    """Multi-head self-attention over the flattened (variable, patch) grid
    with 2-D rotary position embeddings: half of each head's dim rotates
    with the variable index, the other half with the patch index.
    """

    def __init__(self, layer_idx: int, **cfg):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # 2d RoPE: one rotary table shared by both halves of the head dim.
        self.rotary_emb = RotaryEmbedding(
            self.head_dim // 2, max_position_embeddings=cfg.get("max_position_embeddings"))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        hidden_states:  (bs, seq_len, hidden_size), seq_len = nvar * num_p
        attention_mask: (bs, nvar, num_p) key-side padding mask (1 = keep)

        Returns the attended hidden states, (bs, seq_len, hidden_size).
        """
        bsz, q_len, _ = hidden_states.size()

        tmp_attn_mask = rearrange(attention_mask, 'b nvar p -> b (nvar p)')
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)  # bs, L, hidden_size

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Broadcast the flat key padding mask to (bs, 1, q_len, q_len) for SDPA.
        tmp_attn_mask = tmp_attn_mask.unsqueeze(1).unsqueeze(2).expand(-1, 1, q_len, q_len).bool()

        pos_var, pos_patch = build_2d_position_ids(attention_mask, flatten=True)
        q_h = query_states[..., : self.head_dim // 2]
        q_w = query_states[..., self.head_dim // 2 :]
        cos_h, sin_h = self.rotary_emb(q_h, seq_len=int(pos_var.max().item()) + 1)
        cos_w, sin_w = self.rotary_emb(q_w, seq_len=int(pos_patch.max().item()) + 1)

        query_states, key_states = apply_rotary_pos_emb_2d(
            query_states, key_states,
            cos_h, sin_h,
            cos_w, sin_w,
            pos_var, pos_patch
        )

        # Bug fix: F.scaled_dot_product_attention is unaware of module
        # train/eval state, so gate dropout on self.training to keep
        # inference deterministic.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            tmp_attn_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output
81
+
82
# helper function
def flatten_list(input_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Concatenate the sub-lists of ``input_list`` into one flat list.

    Args:
        input_list: Nested list (e.g. per-sample lists of per-variable tensors).

    Returns:
        A single list containing every element, order preserved.
    """
    flat: List[torch.Tensor] = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
92
+
93
class MultiSizePatchEmbed(nn.Module):
    """Patch embedder for ragged inputs whose patch length varies per variable.

    A single set of weights is trained at ``base_patch`` and interpolated
    (FlexiViT-style) to whatever patch size each input actually uses. The
    per-patch input is the concatenation [values | mask | time_idx], hence
    every weight operates on ``patch * 3`` input features.
    """

    def __init__(self, base_patch=32, **cfg):
        super().__init__()

        self.base_patch = base_patch
        hidden_size = cfg['embed_dim']
        intermediate_size = cfg['mlp_ratio'] * hidden_size # 3072
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size

        # [ts, time_idx, mask] concatenated together
        self.shared_linear = nn.Linear(base_patch*3, intermediate_size) # putting mask on hidden.
        self.shared_residual = nn.Linear(base_patch*3, hidden_size)

        # MLP embedder ###
        self.dropout = nn.Dropout(cfg['dropout_rate'])
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(
            intermediate_size, hidden_size)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize nn.Linear and nn.LayerNorm
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)


    def resize_weight(self, patch_size: int):
        """
        Interpolate weights along the patch dimension to target patch size.

        Returns (hidden_w, hidden_b, residual_w, residual_b) where the weight
        matrices have been linearly resampled from ``base_patch*3`` input
        features to ``patch_size`` input features; biases are reused as-is
        (they do not depend on the input width).
        """

        base_w = self.shared_linear.weight  # [out_dim, base_patch]
        base_b = self.shared_linear.bias

        res_w = self.shared_residual.weight
        res_b = self.shared_residual.bias

        # FlexiViT: interpolate kernel linearly along patch axis
        # interpolate (base_patch, d) -> (patch_size,d)
        new_w = F.interpolate(
            base_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False
        ).squeeze(1).to(base_w.dtype)

        new_res_w = F.interpolate(
            res_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False
        ).squeeze(1).to(res_w.dtype)

        return new_w, base_b,new_res_w,res_b


    def forward(self, x_list, attention_mask, time_idx):
        """
        x_list: list of tensors of shape (num_patches, patch_size)
        attention_mask: list of tensors, same shapes as x_list.
        time_idx: list of tensors, same shapes as x_list.

        Inputs are grouped by patch size so each group can be embedded with a
        single resized weight; results are scattered back into batch order.

        NOTE(review): assumes every element of x_list has the same
        num_patches (N is taken from x_list[0]) — confirm with callers.
        NOTE(review): when CUDA is available, outputs are forced onto the
        current CUDA device regardless of where the module's weights live —
        confirm this matches the training setup.

        Returns:
            Tensor (len(x_list), num_patches, hidden_size) in input order.
        """

        amp_dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else torch.float32
        device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")

        # group by patch size
        sizes = torch.tensor([x.shape[-1] for x in x_list])
        unique_sizes = sizes.unique(sorted=True)
        N = x_list[0].shape[0]  # number of patches

        # Pre-allocated and filled group-by-group via index assignment below.
        outputs = torch.empty(len(x_list), N, self.intermediate_size,
                              device=device,dtype=amp_dtype)
        res_outputs = torch.empty(len(x_list), N, self.hidden_size,
                              device=device,dtype=amp_dtype)

        for psize in unique_sizes.tolist():
            idxs = (sizes == psize).nonzero(as_tuple=True)[0]
            xs = torch.stack([x_list[i] for i in idxs])  # B_g, num_p, ps
            mask = torch.stack([attention_mask[i] for i in idxs])  # B_g, num_p, ps
            ti = torch.stack([time_idx[i] for i in idxs])

            xs = xs.to(device=device, non_blocking=True)
            mask = mask.to(device=device, non_blocking=True)
            ti = ti.to(device=device, non_blocking=True)

            xs = torch.cat([xs,mask,ti],dim=-1)  # B_g, num_p, ps*3
            # Resize once per group; *3 because of the [ts|mask|time] concat.
            w, b, r_w, r_b = self.resize_weight(psize*3)

            res_outputs[idxs] = F.linear(xs,r_w,r_b)
            outputs[idxs] = F.linear(xs, w, b)

        hid = self.act(outputs)  # BS, num_p, intermediate_size
        out = self.dropout(self.output_layer(hid))  # BS, num_p, hidden
        out = out + res_outputs

        return out
197
+
198
+
199
class PatchEmbedding(nn.Module):
    """Fixed-size patch embedder.

    Splits each variable's series into non-overlapping patches of
    ``patch_size`` steps, concatenates [values | mask | time_idx] per patch
    (hence the ``patch_size * 3`` input width), and maps each patch to
    ``embed_dim`` with a SiLU MLP plus a linear residual branch.
    """

    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg['patch_size']
        # Fix: original assigned self.patch_size twice; keep a single assignment.
        self.patch_size = patch_size

        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
        hidden_size = cfg['embed_dim']
        self.hidden_layer = nn.Linear(
            patch_size * 3, hidden_size)
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(
            hidden_size, hidden_size)
        self.residual_layer = nn.Linear(
            patch_size * 3, hidden_size)

    def forward(self, x, mask, time_idx):
        """
        x, mask, time_idx: (bs, nvar, L) with L divisible by patch_size.

        Returns:
            Tensor (bs * nvar, num_p, hidden_size).
        """
        x = rearrange(x, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)
        mask = rearrange(mask, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)
        time_idx = rearrange(time_idx, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)

        x = torch.cat([x, mask, time_idx], dim=-1)
        hid = self.act(self.hidden_layer(x))
        out = self.dropout(self.output_layer(hid))
        res = self.residual_layer(x)
        out = out + res

        return out  # bs*nvar, num_p, hidden_size
231
+
232
class Attention(nn.Module):
    """Standard multi-head self-attention with optional 1-D rotary embeddings.

    With ``is_rope=True`` a RotaryEmbedding over the full head dim is applied
    to queries/keys; with ``is_rope=False`` (e.g. cross-variable attention,
    which has no natural ordering) no positional encoding is used.
    """

    def __init__(self, layer_idx: int, is_rope=True, **cfg):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_rope = is_rope
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.sensor_max_len = cfg.get("sensor_max_len", 2880)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if self.is_rope:
            self.rotary_emb = RotaryEmbedding(
                self.head_dim, max_position_embeddings=self.sensor_max_len)
        else:
            self.rotary_emb = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,  # index of positions.
        **kwargs,
    ) -> torch.Tensor:
        """
        hidden_states:  (bs, seq_len, hidden_size)
        attention_mask: (bs, 1, seq_len, seq_len) boolean/additive SDPA mask
        position_ids:   (bs, seq_len), used only when is_rope=True

        Returns the attended hidden states, (bs, seq_len, hidden_size).
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)  # bs, L, hidden_size

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.is_rope:
            kv_seq_len = key_states.shape[-2]
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids)

        # Bug fix: F.scaled_dot_product_attention is unaware of module
        # train/eval state, so gate dropout on self.training to keep
        # inference deterministic.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output
298
+
299
class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention from ``query`` tokens (width ``dim``)
    onto ``context`` tokens (width ``context_dim``), projected into the
    unified ``dim`` space.
    """

    def __init__(self,
                 dim=768,  # unified embed space
                 *,
                 context_dim=384,
                 num_heads=12,
                 dropout_rate=0.1):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = int(dim // num_heads)
        # NOTE: kept for interface compatibility; SDPA applies its own
        # default 1/sqrt(head_dim) scaling, so this attribute is unused.
        self.scale = self.head_dim ** -0.5
        self.attn_dropout = dropout_rate

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)

        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(context_dim, dim, bias=True)
        self.v_proj = nn.Linear(context_dim, dim, bias=True)
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(
        self,
        query,
        context,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        query:          (bs, seq_len, dim)
        context:        (bs, context_len, context_dim)
        attention_mask: (bs, 1, seq_len, context_len) SDPA mask

        Returns:
            Tensor (bs, seq_len, dim).
        """
        bsz, q_len, _ = query.size()
        bsc, k_len, _ = context.size()

        assert bsz == bsc, f"Batch size mismatch: {bsz} vs {bsc}"

        # pre-norm
        query = self.norm(query)
        context = self.context_norm(context)

        query_states = self.q_proj(query)
        key_states = self.k_proj(context)
        value_states = self.v_proj(context)  # bs, L, hidden_size

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(
            bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(
            bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Bug fix: F.scaled_dot_product_attention is unaware of module
        # train/eval state, so gate dropout on self.training to keep
        # inference deterministic.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.dim)
        attn_output = self.o_proj(attn_output)  # bs, q_len, dim

        return attn_output
370
+
371
class MLP(nn.Module):
    """Gated feed-forward block (SwiGLU-style): down(act(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, hidden_state):
        # Elementwise gate on the up-projection, then project back down.
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
386
+
387
+
388
+
389
class AllAttention(nn.Module):
    """Joint self-attention over the flattened (variable x patch) sequence,
    wrapped with pre-norm, dropout, and a residual connection."""

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = TsRoPEAttention(**cfg, layer_idx=layer_idx)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim'))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """
        hidden_states:  (bs, nvar*L, d)
        attention_mask: (bs, nvar, L)

        Returns hidden states with the attention output added residually.
        """
        attn_out = self.self_attention(
            self.layer_norm(hidden_states),  # pre-norm
            attention_mask,
        )
        return hidden_states + self.dropout(attn_out)
418
+
419
class TimeSelfAttention(nn.Module):
    """Per-variable temporal self-attention (pre-norm, dropout, residual)."""

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = Attention(layer_idx=layer_idx, is_rope=True, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
    ):
        """
        hidden_states:  (bs*nvar, L, d)
        attention_mask: (bs, nvar, L)

        Returns hidden states with the attention output added residually.
        """
        seq_len = hidden_states.size(1)

        # Flatten the variable axis, then broadcast the key-side padding mask
        # to the (query, key) grid expected by scaled_dot_product_attention.
        flat_mask = rearrange(attention_mask, 'b nvar p -> (b nvar) p')
        sdpa_mask = (
            flat_mask.unsqueeze(1)
                     .unsqueeze(2)
                     .expand(-1, 1, seq_len, seq_len)
                     .bool()
        )

        attn_out = self.self_attention(
            self.layer_norm(hidden_states),  # pre-norm
            sdpa_mask,
            position_ids,
        )
        return hidden_states + self.dropout(attn_out)
455
+
456
+
457
class GroupSelfAttention(nn.Module):
    """Self-attention applied along the batch axis masked by the group attention mask

    Current implementation attends across the ``nvar`` variables of the same
    sample at each timestep (via the reshape in ``forward``); the
    ``group_ids``-based masking path is legacy and kept only as
    ``_construct_group_mask`` plus commented-out code.
    """

    def __init__(self, layer_idx: int, **cfg):
        super().__init__()
        # we don't use RoPE here because there's no natural ordering along the batch axis
        self.self_attention = Attention(layer_idx, is_rope=False, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def _construct_group_mask(self,
                              group_ids: torch.Tensor,
                              attention_mask: torch.Tensor) -> torch.Tensor:
        # NOTE(review): not called by forward() anymore — legacy path for the
        # group-id based masking; kept for reference.

        # construct group_mask (batch, batch) from group ids
        # a cell is True if both row and col had the same group id
        group_mask = group_ids[:, None] == group_ids[None, :]

        # group_mask: bs*nvar, bs*nvar
        # attention_mask: bs*nvar, L
        group_time_mask = torch.einsum("qb, bt -> qbt", group_mask, attention_mask).float() # bs*nvar, bs*nvar, L
        group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b") # L,1, bs*nvar, bs*nvar
        group_time_mask = group_time_mask.bool() # convert to bool

        return group_time_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        group_ids: torch.Tensor,
    ):

        '''
        hidden_states: bs*nvar, L, d
        attention_mask: bs, nvar, L
        group_ids: bs*nvar (accepted for interface compatibility; unused by
                   the current per-sample reshape implementation)
        '''


        # attention_mask = rearrange(attention_mask, 'b nvar l -> (b nvar) l') # bs*nvar, L
        # hidden_states = rearrange(hidden_states, 'bs l d -> l bs d',) # L, bs*nvar, d
        # group_attn_mask = self._construct_group_mask(group_ids, attention_mask) #L,1, bs*nvar, bs*nvar

        # Fold time into the batch axis so attention mixes the nvar variables
        # of one sample at one timestep.
        BS, nvar, _ = attention_mask.shape
        hidden_states = rearrange(hidden_states, '(bs nvar) l d -> (bs l) nvar d', bs=BS, nvar=nvar)
        attention_mask = rearrange(attention_mask, 'bs nvar l -> (bs l) nvar') # (bs*L), nvar
        group_attn_mask = attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, 1, nvar, nvar).bool() # (bs*L), 1, nvar, nvar

        normed_hidden_states = self.layer_norm(hidden_states)  # pre-norm
        attention_output = self.self_attention(
            normed_hidden_states,
            group_attn_mask,
        )
        hidden_states = hidden_states + self.dropout(attention_output)
        # flip time and batch axes back to their original position
        hidden_states = rearrange(hidden_states, '(bs l) nvar d -> (bs nvar) l d', bs=BS, nvar=nvar)
        # hidden_states = rearrange(hidden_states, "time batch d -> batch time d") # Bs*nvar, L, d


        return hidden_states
518
+
519
class AttentionPooling(nn.Module):
    """Pool a variable-length context into a fixed set of query tokens via
    cross-attention, followed by a gated FFN and a final LayerNorm."""

    def __init__(self,
                 dim=768,
                 mlp_ratio=4,
                 context_dim=384,
                 num_heads=12,
                 dropout_rate=0.1):
        super().__init__()

        self.cross_attn = CrossAttention(dim=dim,
                                         context_dim=context_dim,
                                         num_heads=num_heads,
                                         dropout_rate=dropout_rate)

        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_layer = MLP(
            hidden_size=dim,
            intermediate_size=dim * mlp_ratio,
            hidden_act='silu',
        )
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, context, attn_mask=None):
        """
        x:         (B, num_query, dim) query tokens
        context:   (B, num_kv, context_dim) encoder states
        attn_mask: (B, nvar, num_p) key-side padding mask
        """
        batch, n_query, _ = x.shape
        n_kv = context.shape[1]

        # (B, nvar, num_p) -> (B, 1, n_query, n_kv) boolean SDPA mask.
        key_mask = rearrange(attn_mask, 'b nvar p -> b (nvar p)')
        key_mask = key_mask.view(batch, 1, 1, n_kv).expand(batch, 1, n_query, n_kv).bool()

        pooled = self.cross_attn(x, context, key_mask)
        pooled = pooled + self.ffn_layer(self.ffn_norm(pooled))
        return self.post_norm(pooled)
557
+
558
+
559
class SensorEncoderLayer(nn.Module):
    """One encoder block: temporal/channel attention followed by a shared FFN.

    The attention stage depends on ``channel_attn_type``:
      - 'group_attn':  time self-attention, then cross-variable group attention
      - 'univariate':  time self-attention only
      - otherwise (e.g. 'all_attn'): joint attention over the flattened
        (nvar * L) sequence
    Each attention submodule applies its own residual connection internally.
    """

    def __init__(self, layer_idx: int, **cfg):
        super().__init__()

        hidden_size = cfg['embed_dim']
        intermediate_size = cfg['mlp_ratio'] * hidden_size

        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn')
        if self.channel_attn_type == 'group_attn':
            self.ts_attn = TimeSelfAttention(layer_idx=layer_idx, **cfg)  # pre-norm
            self.group_attn = GroupSelfAttention(layer_idx=layer_idx, **cfg)  # pre-norm
        elif self.channel_attn_type == 'univariate':
            self.ts_attn = TimeSelfAttention(layer_idx=layer_idx, **cfg)
        else:
            self.ts_attn = AllAttention(layer_idx=layer_idx, **cfg)

        self.norm = nn.LayerNorm(hidden_size)  # pre-FFN norm

        self.ffn_layer = MLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act='silu',
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        group_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        hidden_states:  (bs*nvar, L, d) — or (bs, nvar*L, d) for 'all_attn'
        attention_mask: (bs, nvar, L)
        group_ids:      (bs*nvar,) — only used by the 'group_attn' path
        position_ids:   RoPE position ids for temporal attention

        Returns hidden states with the same shape as the input.
        (Fixes the original annotation, which claimed a tuple return.)
        """
        if self.channel_attn_type == 'group_attn':
            hidden_states = self.ts_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )  # residual handled inside

            hidden_states = self.group_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                group_ids=group_ids,
            )  # residual handled inside

        elif self.channel_attn_type == 'univariate':
            hidden_states = self.ts_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )  # residual handled inside

        else:
            # hidden_states: bs, (nvar*L), d
            hidden_states = self.ts_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
            )  # residual handled inside

        # Shared pre-norm feed-forward with residual connection.
        # (Deduplicated: the original repeated this epilogue in every branch.)
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.ffn_layer(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
646
+
647
+
648
+
649
class SensorTransformerModel(nn.Module):
    """Patch-based transformer encoder for multivariate sensor time series.

    Embeds per-variable patches (fixed-size via ``PatchEmbedding`` when
    ``patch_size`` is given, otherwise ragged via ``MultiSizePatchEmbed``),
    then runs ``depth`` SensorEncoderLayer blocks whose channel-mixing
    strategy is selected by ``channel_attn_type``
    ('group_attn' | 'all_attn' | 'univariate').
    """

    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg.get('patch_size', None)
        self.patch_size = patch_size
        if patch_size is not None:
            # fixed patch size embedder
            self.patch_embed = PatchEmbedding(**cfg)
        else:
            self.patch_embed = MultiSizePatchEmbed(**cfg)

        self.blocks = nn.ModuleList(
            [SensorEncoderLayer(layer_idx, **cfg)
             for layer_idx in range(cfg['depth'])]
        )
        self.norm = torch.nn.LayerNorm(cfg['embed_dim'])
        self.embed_dim = cfg['embed_dim']
        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn') # group_attn, all_attn, univariate

    def forward(
        self,
        input_ids,
        attention_mask,
        time_index,):
        """Encode a batch of multivariate series.

        Two input layouts are supported, selected by ``self.patch_size``:
        ragged (nested lists of per-variable tensors) or dense tensors.

        Returns:
            (hidden_states, attention_mask) where attention_mask is the
            patch-level mask (BS, nvar, num_p). hidden_states is
            (BS, nvar*num_p, d) for 'group_attn'/'all_attn' and
            (BS*nvar, num_p, d) for 'univariate' (see the rearranges below).
        """

        if self.patch_size is None:
            '''
            input_ids: list of list of tensor # BS, nvar, num_p, patch_size
            attention_mask: same as input_ids

            self.patch_embed will handle device.
            '''
            BS = len(input_ids)
            flat_input_ids = flatten_list(input_ids)
            flat_attention_mask = flatten_list(attention_mask)
            flat_time_index = flatten_list(time_index)

            # embed each variable separately
            hidden_states = self.patch_embed(flat_input_ids,flat_attention_mask,flat_time_index) # (bs*nvar, seq_len, embed_dim)

            # Collapse pixel-level masks to patch-level and derive RoPE ids.
            attention_mask = self._get_self_attn_mask(attention_mask).to(hidden_states.device) # BS, nvar, num_p
            position_ids = self._build_rope_position_ids(attention_mask) # BS, nvar, num_p
            position_ids = rearrange(position_ids, 'b nvar p -> (b nvar) p') # BS*nvar, num_p

        else:
            '''
            input_ids: tensor # BS, nvar, L
            attention_mask: tensor # BS, nvar, L
            time_index: tensor # BS, nvar, L
            '''

            BS, nvar, L = input_ids.shape
            hidden_states = self.patch_embed(input_ids, attention_mask, time_index) # (bs*nvar, seq_len, embed_dim)
            # transform pixel-level attn mask (BS, nvar, L)to patch-level attn mask (BS, nvar, num_p), element would be 1 if all pixel is 1,if all pixel is 0, then is 0
            attention_mask = reduce(
                attention_mask,
                'b v (p ps) -> b v p',
                'max',
                ps=self.patch_size
            )

            position_ids = self._build_rope_position_ids(attention_mask) # BS, nvar, num_p
            position_ids = rearrange(position_ids, 'b nvar p -> (b nvar) p') # BS*nvar, num_p

        # 'all_attn' attends jointly, so flatten variables into the sequence
        # before the blocks; 'group_attn' flattens after (its blocks expect
        # the per-variable layout).
        if self.channel_attn_type == 'all_attn':
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=BS)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                attention_mask=attention_mask,
                group_ids=None, # legacy argument
                position_ids=position_ids,
            ) # bs*nvar, seq, emb or bs (nvar l) d

        if self.channel_attn_type == 'group_attn':
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=BS)

        hidden_states = self.norm(hidden_states) # (Bs*nvar), seq, emb

        return hidden_states, attention_mask

    def _build_rope_position_ids(self,attention_mask):
        """
        attention_mask: Tensor [BS, nvar, num_p]
        returns: LongTensor [BS, nvar, num_p]

        Valid patches get 0,1,2,... within each variable; padded patches
        are forced to 0 by the trailing multiply.
        """
        assert attention_mask.dim() == 3
        BS, nvar, num_p = attention_mask.shape

        mask = attention_mask.to(torch.long)

        # position index increases inside each variable
        pos = (mask.cumsum(dim=-1) - 1) * mask # [BS, nvar, num_p]

        return pos

    def _get_self_attn_mask(self,attn_mask_list):
        """
        Collapse a nested list of attention masks from shape
        [BS][nvar][num_p, patch_size]
        into tensors of shape [BS, nvar, num_p].

        A patch is marked valid (1) if ANY of its pixels is valid.

        Args:
            attention_mask (list[list[Tensor]]):
                Each tensor has shape [num_p, patch_size], and all have the same shape.

        Returns:
            torch.Tensor (BS, nvar, num_p)
        """
        collapsed_batch = []
        for sample_masks in attn_mask_list: # loop over batch
            # collapse each [num_p, patch_size] → [num_p]
            nvar_collapsed = [
                (var_mask.sum(dim=-1) > 0).to(var_mask.dtype) for var_mask in sample_masks
            ]
            nvar_collapsed = torch.stack(nvar_collapsed, dim=0) # [nvar, num_p]
            collapsed_batch.append(nvar_collapsed)

        collapsed_batch = torch.stack(collapsed_batch, dim=0) # [BS, nvar, num_p]
        return collapsed_batch

    def _get_group_ids(self,attn_mask_list):
        """
        attn_mask_list: list of list of tensor
        BS, nvar
        each tensor is shape (num_p, patch_size)

        Returns:
            group_mask: (BS*nvar, BS*nvar) boolean tensor
            True means same group
            False means different group

        NOTE(review): currently returns per-sample ids of shape (BS*nvar,),
        not the boolean matrix the docstring above describes; forward()
        passes group_ids=None, so this helper is effectively unused.
        """
        BS = len(attn_mask_list)
        nvar = len(attn_mask_list[0])

        # build group ids
        # each sample i repeats nvar times
        group_ids = torch.arange(BS).repeat_interleave(nvar) # (BS*nvar)

        return group_ids
791
+
792
+
793
+
794
+
795
if __name__ == "__main__":
    # Smoke test: build a small ragged-patch encoder and run one batch of
    # 2 samples x 2 variables, each tensor shaped (num_patches=14, patch_len).
    from model_factory.coca import Config
    cfg = Config(embed_dim=384,
                 num_heads=6,
                 mlp_ratio=4,
                 depth=12,
                 dropout_rate=0.1,)
    sensor_model = SensorTransformerModel(**cfg)
    # Second sample's last variable uses a different patch length (30 vs 40)
    # to exercise the multi-size embedding path.
    dummy_input = [[torch.randn(14,40),torch.randn(14,40)],[torch.randn(14,40),torch.randn(14,30)]]
    mask = [[torch.ones(14,40),torch.zeros(14,40)],[torch.zeros(14,40),torch.zeros(14,30)]]
    time_idx = [[torch.ones(14,40),torch.ones(14,40)],[torch.ones(14,40),torch.ones(14,30)]]

    out, attn_mask = sensor_model(dummy_input,mask,time_idx)
    print(out.shape) # expect (2*2, max_num_patches, embed
    # python -m model_factory.ts_transformer
modeling_slip.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SLIP Model - HuggingFace Hub Loading Interface
3
+
4
+ Usage:
5
+ from modeling_slip import SLIPModel
6
+ model = SLIPModel.from_pretrained("LeoChen085/SLIP")
7
+
8
+ # Or load a task-specific checkpoint:
9
+ model = SLIPModel.from_pretrained("LeoChen085/SLIP", checkpoint="har.safetensors")
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import math
16
+ import torch
17
+ import torch.nn as nn
18
+ from pathlib import Path
19
+ from typing import Optional
20
+
21
+ # Ensure model_factory and util are importable from the HF cache directory
22
+ _THIS_DIR = Path(__file__).resolve().parent
23
+ if str(_THIS_DIR) not in sys.path:
24
+ sys.path.insert(0, str(_THIS_DIR))
25
+
26
+ from model_factory.ts_transformer import (
27
+ SensorTransformerModel,
28
+ AttentionPooling,
29
+ CrossAttention,
30
+ )
31
+ from model_factory.multimodal_gemma import (
32
+ Gemma3MultimodalModel,
33
+ Residual,
34
+ )
35
+
36
+
37
+ # ── Lightweight helpers (from SLIP.py, no distributed deps) ──
38
+
39
def masked_mean(t, mask, dim=1, eps=1e-6):
    """Mean of ``t`` along ``dim`` counting only positions where ``mask`` is True.

    ``eps`` keeps the denominator positive when a row contains no valid
    entries, avoiding a division by zero.
    """
    zeroed = t.masked_fill(~mask, 0.)
    total = zeroed.sum(dim=dim)
    count = mask.sum(dim=dim).clamp(min=eps)
    return total / count
44
+
45
+
46
class EmbedToLatents(nn.Module):
    """Project embeddings into a latent space and L2-normalize them."""

    def __init__(self, dim, dim_latents):
        super().__init__()
        # Bias-free projection; normalization below removes scale anyway.
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        projected = self.to_latents(x)
        return torch.nn.functional.normalize(projected, dim=-1)
54
+
55
+
56
class SLIPModel(nn.Module):
    """
    SLIP model for inference. Loads from HuggingFace Hub without Hydra dependency.

    Supports:
        - get_embedding(text, sensors) -> (text_emb, sensor_emb)
        - get_sensor_embedding(input_ids, mask, time_index) -> sensor_emb
        - generate(text, sensors) -> generated_token_ids
        - sft_training(text, sensors) -> loss_dict

    NOTE(review): sft_training is advertised above but not defined in this
    class — confirm whether it lives elsewhere or should be dropped.
    """

    def __init__(self, config: dict):
        super().__init__()

        # Build sensor encoder directly (no Hydra)
        sensor_cfg = config["sensor_encoder"]
        self.sensor_encoder = SensorTransformerModel(**sensor_cfg)

        # Sensor token width (768 per the shipped config.json).
        dim = self.sensor_encoder.embed_dim  # 768
        self.embed_dim = dim

        # Build multimodal LLM
        llm_model_name = config.get("llm_model_name", "google/gemma-3-270m")
        post_train = config.get("post_train", True)
        split_layer = config.get("split_layer", 12)
        self.multimodalModel = Gemma3MultimodalModel(
            llm_model_name, post_train, split_layer
        )

        lm_dim = self.multimodalModel.hidden_size  # 640
        self.lm_dim = lm_dim
        common_dim = config.get("common_dim", lm_dim)

        # Attention pooling: learnable queries compress the variable-length
        # sensor token sequence to a fixed count. The "+ 1" adds a dedicated
        # pooled slot, read back as sensor_hidden[:, 0, :] in get_embedding.
        num_img_queries = config.get("num_img_queries", 0)
        if num_img_queries > 0:
            self.img_queries = nn.Parameter(
                torch.randn(num_img_queries + 1, common_dim)
            )
            self.img_attn_pool = AttentionPooling(
                dim=common_dim,
                context_dim=dim,
                num_heads=config.get("num_heads", 5),
            )
            # Pooled tokens live in common_dim, so the sensor-side bridge
            # projection below must read from common_dim.
            dim = common_dim

        # Bridge projections into the shared contrastive latent space.
        self.img_to_latents = EmbedToLatents(dim, common_dim)
        self.text_to_latents = EmbedToLatents(common_dim, common_dim)

        # Temperature: CLIP-style learnable logit scale, initialized to 1/0.07.
        self.temperature = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
        self.temperature_max = math.log(1 / 0.07)

        # Store config
        self.config_dict = config

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        checkpoint: str = "model.safetensors",
        device: str = "cpu",
        dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ) -> "SLIPModel":
        """
        Load SLIP from a HuggingFace repo or local directory.

        Args:
            repo_id_or_path: HuggingFace repo ID (e.g., "LeoChen085/SLIP")
                or local directory path.
            checkpoint: Which checkpoint file to load.
                Default "model.safetensors" (base pretrained).
                Options: "har.safetensors", "sleep.safetensors",
                "ecg.safetensors", "tsqa.safetensors",
                "caption.safetensors"
            device: Device to load model on.
            dtype: Model dtype (default bfloat16).

        Returns:
            SLIPModel in eval mode on ``device``; weights are loaded with
            strict=False (missing/unexpected keys are printed, not raised).
        """
        local_path = Path(repo_id_or_path)

        if local_path.is_dir():
            # Load from local directory
            config_path = local_path / "config.json"
            weights_path = local_path / checkpoint
        else:
            # Download from HuggingFace Hub
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(repo_id_or_path, "config.json")
            weights_path = hf_hub_download(repo_id_or_path, checkpoint)

            # Also download source files (needed for model classes)
            for src_file in [
                "model_factory/__init__.py",
                "model_factory/SLIP.py",
                "model_factory/multimodal_gemma.py",
                "model_factory/ts_transformer.py",
                "util/__init__.py",
                "util/pos_embed.py",
            ]:
                try:
                    hf_hub_download(repo_id_or_path, src_file)
                except Exception:
                    pass  # File may not exist separately

        # Load config
        with open(config_path) as f:
            config = json.load(f)

        # Build model
        print(f"Building SLIP model...")
        model = cls(config)

        # Load weights
        print(f"Loading weights from {checkpoint}...")
        if str(weights_path).endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(weights_path, device=device)
        else:
            # Legacy torch checkpoints may wrap weights under a
            # "model"/"state_dict" key; unwrap before loading.
            state_dict = torch.load(weights_path, map_location=device, weights_only=False)
            if isinstance(state_dict, dict):
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
        # Remove DDP module. prefix (a no-op for clean safetensors dicts,
        # which contain only tensors and no "module." keys).
        state_dict = {
            k.replace("module.", "", 1) if k.startswith("module.") else k: v
            for k, v in state_dict.items()
            if isinstance(v, torch.Tensor)
        }

        # Load state dict (non-strict: task checkpoints may omit some heads)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
        if unexpected:
            print(f"Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

        model = model.to(dtype=dtype, device=device)
        model.eval()
        print("Model loaded successfully.")
        return model

    # ── Inference methods ─────────────────────────────────────

    def embed_sensor(self, sensors, sensor_attn_mask=None, time_index=None):
        """Encode raw sensor streams into token embeddings.

        Returns (sensor_tokens, boolean attention mask). When attention
        pooling is configured, tokens are pooled down to the fixed query
        count. NOTE(review): the returned mask is the pre-pooling encoder
        mask — confirm downstream consumers expect that.
        """
        from einops import repeat
        sensor_tokens, attn_mask = self.sensor_encoder(
            sensors, sensor_attn_mask, time_index=time_index
        )
        if hasattr(self, "img_attn_pool"):
            # Broadcast the learned queries across the batch before pooling.
            img_queries = repeat(
                self.img_queries, "n d -> b n d", b=sensor_tokens.shape[0]
            )
            sensor_tokens = self.img_attn_pool(img_queries, sensor_tokens, attn_mask)
        return sensor_tokens, attn_mask.bool()

    @torch.no_grad()
    def get_embedding(self, text, sensors):
        """Return (text, sensor) latents for retrieval/contrastive scoring.

        ``text``/``sensors`` are dicts with "input_ids" and "attention_mask"
        ("time_index" as well for sensors) — presumably tokenizer/collator
        output; confirm against the data pipeline.
        """
        from einops import rearrange
        sensor_hidden, sensor_mask = self.embed_sensor(
            sensors=sensors["input_ids"],
            sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"],
        )
        # Presumably registers the sensor tokens for conditioning inside the
        # LLM — confirm against Gemma3MultimodalModel.condition_image.
        self.multimodalModel.condition_image(sensor_hidden)
        # The last token is dropped ([:, :-1]); looks like the train-time
        # teacher-forcing shift — NOTE(review): confirm.
        text_hidden, _ = self.multimodalModel(
            input_ids=text["input_ids"][:, :-1],
            attention_mask=text["attention_mask"][:, :-1],
        )
        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)
        if hasattr(self, "img_attn_pool"):
            # The first pooled query slot is the global sensor embedding.
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            # No pooling: masked mean over all (variable, patch) tokens.
            sensor_hidden = masked_mean(
                sensor_hidden,
                rearrange(sensor_mask, "b n p -> b (n p) 1"),
                dim=1,
            )
        return text_hidden, sensor_hidden

    @torch.no_grad()
    def get_sensor_embedding(self, input_ids, mask, time_index):
        """Sensor-only latent: same sensor branch as get_embedding."""
        from einops import rearrange
        sensor_hidden, sensor_mask = self.embed_sensor(
            sensors=input_ids, sensor_attn_mask=mask, time_index=time_index
        )
        sensor_hidden = self.img_to_latents(sensor_hidden)
        if hasattr(self, "img_attn_pool"):
            # Pooled slot 0 acts as the global embedding.
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(
                sensor_hidden,
                rearrange(sensor_mask, "b n p -> b (n p) 1"),
                dim=1,
            )
        return sensor_hidden

    @torch.no_grad()
    def generate(self, text, sensors, **generate_kwargs):
        """Sensor-conditioned text generation via the underlying LLM.

        Recognized generate_kwargs: max_new_tokens (default 300),
        do_sample (default False), num_beams (default 1). Other keys are
        silently ignored.
        """
        sensor_hidden, _ = self.embed_sensor(
            sensors=sensors["input_ids"],
            sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"],
        )
        self.multimodalModel.condition_image(sensor_hidden)
        return self.multimodalModel.model.generate(
            input_ids=text["input_ids"],
            attention_mask=text["attention_mask"],
            max_new_tokens=generate_kwargs.get("max_new_tokens", 300),
            do_sample=generate_kwargs.get("do_sample", False),
            num_beams=generate_kwargs.get("num_beams", 1),
        )
sleep.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9685e181a9b4038d03744647621f864ff3a3e866520ec7038c061e8ce0e88b13
3
+ size 1386043740
tsqa.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8f6b02d497fa409d65c18b776da0132046143e88514336cdd83255dbbf76833
3
+ size 1386043740
util/__init__.py ADDED
File without changes
util/pos_embed.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """
    Build a 1D sine-cosine positional embedding table.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        length (int): Number of positions to encode.
        cls_token (bool): If True, prepend one zero row for the [CLS] token.

    Returns:
        np.ndarray of shape (length, embed_dim), or (1 + length, embed_dim)
        when cls_token is True.
    """
    positions = np.arange(length, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)  # (L, D)
    if cls_token:
        # The [CLS] slot carries no positional signal.
        table = np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
    return table
27
+
28
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a 2D sine-cosine positional embedding table for an (H, W) grid.

    References:
        Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
        MoCo v3: https://github.com/facebookresearch/moco-v3

    Returns (H*W, embed_dim); one extra leading zero row when cls_token.
    """
    rows = np.arange(grid_size[0], dtype=np.float32)
    cols = np.arange(grid_size[1], dtype=np.float32)
    # meshgrid with the width axis first, matching the MoCo v3 convention.
    mesh = np.stack(np.meshgrid(cols, rows), axis=0)
    mesh = mesh.reshape([2, 1, grid_size[0], grid_size[1]])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        table = np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
    return table
46
+
47
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-axis grid: half the channels per axis, concatenated."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Axis 0 of the grid fills the first half of the channels, axis 1 the rest.
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
56
+
57
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine-cosine encode a set of scalar positions.

    Args:
        embed_dim: channels per position (must be even).
        pos: array of positions, any shape; it is flattened to (M,).

    Returns:
        (M, embed_dim) array — sines in the first half of the channels,
        cosines in the second.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1/10000^(2i/D), i = 0 .. D/2 - 1.
    omega = np.arange(half, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    flat = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat, omega)  # (M, D/2) outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
76
+
77
def interpolate_pos_embed(model, checkpoint_model, orig_size, new_size):
    '''
    Resize a checkpoint's learned 2D positional embeddings to a new grid.

    Mutates checkpoint_model['pos_embed'] in place; returns nothing. A no-op
    when the checkpoint has no 'pos_embed' key or the sizes already match.

    Input: model: the class being defined for downstream use
           checkpoint_model: pre-train weights (state dict)
           orig_size = patch grid size in the ckpt, as (H, W)
           new_size = patch grid size in the current model, as (H, W)
    '''

    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed'] # 1 x 560 x 768 (1 x num_patches x E)
        embedding_size = pos_embed_checkpoint.shape[-1] # 768

        # number of special tokens (e.g. in this case num_extra_tokens = 1 for the cls token)
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches

        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1]))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] # old positions
            # (1, H*W, E) -> (1, E, H, W) so bicubic interpolation sees a 2D map
            pos_tokens = pos_tokens.reshape(-1, orig_size[0], orig_size[1], embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
            # back to (1, H'*W', E), then re-attach the special tokens in front
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
104
+
105
+
106
+
107
+ # RoPE: https://huggingface.co/thuml/sundial-base-128m/blob/main/modeling_sundial.py
108
# RoPE: https://huggingface.co/thuml/sundial-base-128m/blob/main/modeling_sundial.py
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with a lazily grown cos/sin cache."""

    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies for the even channel indices: base^(-2i/dim).
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables to cover positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device,
                                 dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation (duplicate halves, not interleave).
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer(
            "sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cache on demand when a longer sequence arrives.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(
                seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return cos, sin
146
+
147
def rotate_half(x):
    """Rotate the last dimension: halves (a, b) become (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
151
+
152
+
153
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query/key tensors by their positions' cos/sin table rows.

    ``cos``/``sin`` are cache tables indexed by ``position_ids``;
    ``unsqueeze_dim`` inserts the broadcast axis (the heads axis by default).
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_rot = (q * cos_sel) + (rotate_half(q) * sin_sel)
    k_rot = (k * cos_sel) + (rotate_half(k) * sin_sel)
    return q_rot, k_rot
159
+
160
+ # two dimensional version
161
# two dimensional version
def apply_rotary_pos_emb_2d(q, k,
                            cos_h, sin_h,
                            cos_w, sin_w,
                            pos_h, pos_w,
                            unsqueeze_dim=1):
    """
    Two-axis RoPE: the first half of the head channels is rotated by the
    ``pos_h`` positions, the second half by ``pos_w``.

    q, k: [B, heads, N, Dh]
    cos_h, sin_h: caches from 1D rotary with dim = Dh // 2 for the first axis
    cos_w, sin_w: caches from 1D rotary with dim = Dh // 2 for the second axis
    pos_h, pos_w: [B, N] integer positions for each token along the two axes
    returns q_out, k_out with same shape as q, k
    """
    head_dim = q.shape[-1]
    assert head_dim % 4 == 0, "head dim must be divisible by 4 so each half is even for rotate_half"
    half = head_dim // 2

    # Split the channel dimension into the two per-axis halves.
    q_h, q_w = q.split(half, dim=-1)
    k_h, k_w = k.split(half, dim=-1)

    # Apply 1D RoPE on each half with its own positions.
    pos_h = pos_h.long()
    pos_w = pos_w.long()
    q_h, k_h = apply_rotary_pos_emb(q_h, k_h, cos_h, sin_h, pos_h, unsqueeze_dim=unsqueeze_dim)
    q_w, k_w = apply_rotary_pos_emb(q_w, k_w, cos_w, sin_w, pos_w, unsqueeze_dim=unsqueeze_dim)

    # Re-assemble the full head dimension.
    return torch.cat([q_h, q_w], dim=-1), torch.cat([k_h, k_w], dim=-1)
190
+
191
+
192
def build_2d_position_ids(attention_mask: torch.Tensor,
                          flatten: bool = True):
    """
    Derive (variable, patch) position ids from a padding mask.

    attention_mask: [BS, nvar, num_p] with 1 for valid patches, 0 for padding.
    Padded slots receive position 0; variables that are entirely padded are
    skipped when numbering variables.

    Returns:
        If flatten is True: two LongTensors [BS, nvar*num_p]
        (variable ids, patch ids); otherwise two LongTensors [BS, nvar, num_p].
    """
    assert attention_mask.dim() == 3, "attention_mask must be [BS, nvar, num_p]"
    batch, nvar, num_p = attention_mask.shape
    valid = attention_mask.to(dtype=torch.long)

    # Patch index within each variable, counting only valid entries.
    patch_ids = (valid.cumsum(dim=-1) - 1) * valid  # [B, V, P]

    # Variable index, counting only variables with at least one valid patch.
    has_any = valid.any(dim=-1).to(dtype=torch.long)  # [B, V]
    var_rank = (has_any.cumsum(dim=1) - 1) * has_any  # [B, V]
    var_ids = var_rank.unsqueeze(-1).expand(batch, nvar, num_p) * valid  # [B, V, P]

    if flatten:
        return var_ids.reshape(batch, nvar * num_p).long(), patch_ids.reshape(batch, nvar * num_p).long()

    return var_ids.long(), patch_ids.long()
221
+
222
def build_1d_position_ids(attention_mask: torch.Tensor):
    """
    Build per-(sample, variable) patch position ids.

    Each (batch, variable) pair gets its own position sequence along the
    patch axis; padded slots get position 0.

    Args:
        attention_mask: Tensor [BS, nvar, num_p], 1 for valid, 0 for padding.

    Returns:
        pos_ids: LongTensor [BS * nvar, num_p]
    """
    assert attention_mask.dim() == 3, "attention_mask must be [BS, nvar, num_p]"
    batch, nvar, num_p = attention_mask.shape
    valid = attention_mask.to(dtype=torch.long)

    # Cumulative count of valid patches, zeroed out at padded positions.
    position_ids = (valid.cumsum(dim=-1) - 1) * valid  # [B, V, P]

    # Collapse batch and variable axes into one leading dimension.
    return position_ids.view(batch * nvar, num_p).long()