ccloud0525 committed on
Commit
d5cfa8f
·
1 Parent(s): ed6e4db

feat: first commit

Browse files
__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ '''
2
+ * @author: EmpyreanMoon
3
+ *
4
+ * @create: 2025-07-17 19:20
5
+ *
6
+ * @description:
7
+ '''
bert_config/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0.dev0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 30522
23
+ }
bert_config/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
bert_config/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "model_max_length": 512}
bert_config/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "aurora_base",
3
+ "architectures": [
4
+ "AuroraForPrediction"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_aurora.AuroraConfig",
8
+ "AutoModelForCausalLM": "modeling_aurora.AuroraForPrediction"
9
+ },
10
+ "dropout_rate": 0.2,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 256,
13
+ "token_len": 48,
14
+ "intermediate_size": 512,
15
+ "max_position_embeddings": 10000,
16
+ "model_type": "aurora",
17
+ "num_attention_heads": 8,
18
+ "num_enc_layers": 1,
19
+ "num_dec_layers": 9,
20
+ "rope_theta": 10000,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.40.1",
23
+ "num_sampling_steps": 50,
24
+ "flow_loss_depth": 3,
25
+ "diffusion_batch_mul": 4,
26
+ "threshold_ratio": [0.2, 0.3, 0.4, 0.5],
27
+ "mask_ratio": 0.5,
28
+ "norm_mode": "batch",
29
+ "num_prototypes": 1000,
30
+ "num_retriever_enc_layers": 1,
31
+ "num_retriever_dec_layers": 1,
32
+ "num_text_cross_layers": 1,
33
+ "num_vision_cross_layers": 1,
34
+ "num_text_connect_layers": 1,
35
+ "num_vision_connect_layers": 1,
36
+ "num_distill": 10
37
+ }
configuration_aurora.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class AuroraConfig(PretrainedConfig):
    """Configuration for Aurora time-series models.

    Holds the architecture hyper-parameters consumed by
    ``AuroraForPrediction``: patching (``token_len``), transformer sizes,
    the flow-matching head, the frequency-masking augmentation, and the
    multimodal retriever/connector stacks.
    """

    model_type = "aurora"

    def __init__(
        self,
        token_len: int = 48,
        hidden_size: int = 512,
        intermediate_size: int = 1024,
        num_enc_layers: int = 12,
        num_dec_layers: int = 12,
        num_attention_heads: int = 8,
        hidden_act: str = "silu",
        rope_theta: int = 10000,
        dropout_rate: float = 0.2,
        max_position_embeddings: int = 10000,
        num_sampling_steps: int = 50,
        flow_loss_depth: int = 3,
        diffusion_batch_mul: int = 4,
        threshold_ratio: "list[float] | None" = None,
        mask_ratio: float = 0.5,
        norm_mode: str = 'batch',
        num_prototypes: int = 1024,
        num_retriever_enc_layers: int = 1,
        num_retriever_dec_layers: int = 1,
        num_text_cross_layers: int = 1,
        num_vision_cross_layers: int = 1,
        num_text_connect_layers: int = 1,
        num_vision_connect_layers: int = 1,
        num_distill: int = 10,
        **kwargs,
    ):
        """
        :param token_len: patch (token) length used to segment the series.
        :param hidden_size: model embedding dimension.
        :param intermediate_size: FFN inner dimension.
        :param threshold_ratio: frequency-mask cut ratios used by the
            masking augmentation; defaults to ``[0.2, 0.3, 0.4, 0.5]``
            when ``None``.
        :param norm_mode: 'batch' for BatchNorm-based blocks, anything
            else selects LayerNorm.
        :param kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        self.token_len = token_len
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_enc_layers = num_enc_layers
        self.num_dec_layers = num_dec_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.dropout_rate = dropout_rate
        self.max_position_embeddings = max_position_embeddings
        self.num_sampling_steps = num_sampling_steps
        self.flow_loss_depth = flow_loss_depth
        self.diffusion_batch_mul = diffusion_batch_mul
        # Fix: the original signature used a mutable list literal as the
        # default argument, which is shared across every instance that
        # relies on the default. Use a None sentinel and fill in the same
        # default value here instead (behavior-compatible for callers).
        self.threshold_ratio = [0.2, 0.3, 0.4, 0.5] if threshold_ratio is None else threshold_ratio
        self.mask_ratio = mask_ratio
        self.norm_mode = norm_mode
        self.num_prototypes = num_prototypes
        self.num_retriever_enc_layers = num_retriever_enc_layers
        self.num_retriever_dec_layers = num_retriever_dec_layers
        self.num_text_cross_layers = num_text_cross_layers
        self.num_vision_cross_layers = num_vision_cross_layers
        self.num_text_connect_layers = num_text_connect_layers
        self.num_vision_connect_layers = num_vision_connect_layers
        self.num_distill = num_distill

        super().__init__(
            **kwargs,
        )
flow_loss.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .util_functions import resample
7
+
8
+
9
class FlowLoss(nn.Module):
    """Flow-matching loss head.

    Trains a small adaLN MLP (``SimpleMLPAdaLN``) to predict the clean
    target from a linearly noised sample, conditioned on a latent ``z``.
    At inference, ``sample`` draws outputs by Euler-integrating the learned
    flow from noise (optionally offset by a prototype).
    """

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps):
        super(FlowLoss, self).__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels,
            z_channels=z_channels,
            num_res_blocks=depth
        )
        self.num_sampling_steps = num_sampling_steps

    def forward(self, target, z, prototype=None, mask=None, eps=1e2):
        """Compute the training loss.

        :param target: clean targets, shape [N, target_channels].
        :param z: conditioning latents, shape [N, z_channels].
        :param prototype: optional per-sample offset added to the noise
            (shifts the flow's source distribution).
        :param mask: optional per-channel mask; when given, the squared
            error is averaged only over masked channels.
        :param eps: outlier threshold (default 100): per-sample losses
            >= eps are dropped before averaging.
        """
        noise = torch.randn_like(target)
        # Random interpolation time per sample, t in [0, 1).
        t = torch.rand(target.shape[0], device=target.device)

        # Linear interpolation between the source point at t=0 (noise,
        # optionally shifted by the prototype) and the clean target at t=1.
        if prototype is not None:
            noised_target = t[:, None] * target + (1 - t[:, None]) * (prototype + noise)
        else:
            noised_target = t[:, None] * target + (1 - t[:, None]) * noise

        # Timestep is scaled to [0, 1000) for the sinusoidal embedding.
        predict_v = self.net(noised_target, t * 1000, z)

        # NOTE(review): the regression target is the clean sample itself
        # (endpoint prediction); this is consistent with the sampler below,
        # which integrates x += (pred - start_point) * dt.
        loss = ((predict_v - target) ** 2)
        if mask is not None:
            loss = (loss * mask).sum(dim=-1) / mask.sum(dim=-1)

        # Drop outlier samples before averaging.
        # NOTE(review): if every per-sample loss exceeds eps, value_mask.sum()
        # is 0 and this divides by zero — confirm upstream values stay bounded.
        value_mask = loss < eps
        loss = loss[value_mask].sum() / value_mask.sum()

        return loss.mean()

    def sample(self, z, prototype=None, num_samples=1, inference_token_len=48):
        """Draw ``num_samples`` trajectories per condition via Euler integration.

        :param z: conditioning latents, shape [B, z_channels].
        :param prototype: optional offset added to the initial noise.
        :param num_samples: number of sampled trajectories per condition.
        :param inference_token_len: in eval mode, the channel dimension is
            resampled from ``in_channels`` to this length.
        :return: tensor of shape [B, num_samples, token_len].
        """
        # Stack the sample copies along the batch dimension.
        z = z.repeat(num_samples, 1)
        noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
        if prototype is not None:
            prototype = prototype.repeat(num_samples, 1)
            start_point = noise + prototype
            x = noise + prototype
        else:
            start_point = noise
            x = noise
        dt = 1.0 / self.num_sampling_steps
        for i in range(self.num_sampling_steps):
            t = (torch.ones((x.shape[0])) * i /
                 self.num_sampling_steps).to(x.device)
            pred = self.net(x, t * 1000, z)
            # Euler step toward the predicted endpoint.
            x = x + (pred - start_point) * dt

        if not self.training:
            # Resample the channel dimension to the requested token length by
            # re-projecting through a resampled identity matrix.
            old_weight = torch.eye(self.in_channels).to(x.device)
            new_weight = resample(old_weight, inference_token_len).T
            x = F.linear(x, new_weight)
            x = x.reshape(num_samples, -1, inference_token_len).transpose(0, 1)
            return x

        x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1)
        return x
70
+
71
+
72
def modulate(x, shift, scale):
    """Apply an adaLN affine modulation: scale the input, then shift it."""
    scaled = x * (scale + 1)
    return scaled + shift
74
+
75
+
76
class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to vector embeddings.

    A fixed sinusoidal frequency embedding followed by a learned two-layer
    SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        # Keep the Sequential layout (mlp.0 / mlp.2 Linear) so checkpoint
        # state-dict keys stay stable.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0,
                                                 end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output dim: pad one zero column so the width matches.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` (shape [N]) into [N, hidden_size] vectors."""
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
117
+
118
+
119
class ResBlock(nn.Module):
    """
    Pre-norm residual MLP block with adaLN (shift/scale/gate) conditioning.
    :param channels: the number of input (and output) channels.
    """

    def __init__(
        self,
        channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        # Keep Sequential layouts (mlp.0/mlp.2, adaLN_modulation.1) so
        # checkpoint state-dict keys stay stable.
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """Condition on ``y`` and apply the gated residual MLP to ``x``."""
        modulation = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = modulation.chunk(3, dim=-1)
        hidden = self.mlp(modulate(self.in_ln(x), shift_mlp, scale_mlp))
        return x + gate_mlp * hidden
150
+
151
+
152
class FinalLayer(nn.Module):
    """
    The final layer adopted from DiT: adaLN-modulated LayerNorm followed by
    a bias-free linear projection to the output channels.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        # elementwise_affine=False: the affine transform comes entirely from
        # the adaLN modulation below.
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=False)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        """Modulate the normalized input with condition ``c``, then project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
172
+
173
+
174
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss: input projection, a stack of adaLN residual
    blocks conditioned on (timestep + latent), and a DiT-style final layer.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks

        # Timestep and condition embeddings are summed into one vector ``y``
        # that conditions every block and the output head (see forward).
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        res_blocks = []
        for i in range(num_res_blocks):
            res_blocks.append(ResBlock(
                model_channels,
            ))

        self.res_blocks = nn.ModuleList(res_blocks)
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all weights.

        Order matters: a blanket Xavier init runs first, then specific
        overrides — small-normal timestep MLP, and zeroed adaLN modulation
        and output head so every residual block (and the head) starts as an
        identity-like map.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers: each block's gate starts at 0,
        # so the block initially contributes nothing beyond the residual.
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)
        # Single conditioning vector shared by all blocks and the head.
        y = t + c

        for block in self.res_blocks:
            x = block(x, y)

        return self.final_layer(x, y)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.40.1"
4
+ }
modality_connector.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import einops
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision.transforms import Resize
8
+ from transformers import ViTImageProcessor, ViTModel, BertModel, ViTConfig, BertConfig
9
+
10
+ from .configuration_aurora import AuroraConfig
11
+
12
+
13
class VisionEncoder(nn.Module):
    """Frozen ViT backbone that distills patch features into a fixed number
    of learnable vision tokens via cross-attention.
    """

    # Directory holding the ViT config shipped next to this module.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.processor = UnifiedImageProcessor(config)
        # Backbone built from the local config (weights are expected to come
        # from the surrounding checkpoint); kept frozen during training.
        self.model = ViTModel(ViTConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
        for param in self.model.parameters():
            param.requires_grad = False
        self.hidden_size = self.model.config.hidden_size
        self.output_dim = config.hidden_size
        self.num_distill = config.num_distill

        # Project ViT features into the Aurora hidden size.
        self.projection = nn.Linear(self.hidden_size, self.output_dim)

        # Learnable query tokens the cross-attention distills into.
        self.target_vision_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))

        # Cross-attention layer: queries = target tokens, memory = patches.
        self.cross_vision = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_vision_cross_layers,
        )

    def extract_vit_features(self, image_tensor):
        """
        Extract image features using ViT
        Args:
            image_tensor: Preprocessed image tensor with shape [batch_size, 3, H, W]
        Returns:
            cls_feature: [CLS] token feature with shape [batch_size, hidden_size]
            patch_features: Features of all patches with shape [batch_size, num_patches, hidden_size]
        """
        outputs = self.model(pixel_values=image_tensor)

        last_hidden_state = outputs.last_hidden_state

        cls_feature = last_hidden_state[:, 0, :]  # [batch_size, hidden_size]

        patch_features = last_hidden_state[:, 1:, :]  # [batch_size, num_patches, hidden_size]

        return cls_feature, patch_features

    def forward(self, x, type='pseudo'):
        """Preprocess ``x`` (time series or real images, per ``type``), encode
        with the frozen ViT, and distill to ``num_distill`` tokens.
        """
        x = self.processor(x, type=type)
        _, patch_features = self.extract_vit_features(x)
        patch_features = self.projection(patch_features)
        # Broadcast the learnable queries over the batch.
        target_vision_tokens = self.target_vision_tokens.unsqueeze(0).repeat(patch_features.shape[0], 1, 1)
        output_tokens = self.cross_vision(target_vision_tokens, patch_features)
        return output_tokens  # [batch_size, num_distill, output_dim]
68
+
69
+
70
class UnifiedImageProcessor(nn.Module):
    """Converts inputs into ViT-ready image tensors.

    Handles two modalities: real images (delegated to ``ViTImageProcessor``)
    and "pseudo-images" built by folding a 1-D time series into a 2-D layout
    along its dominant period.
    """

    # Directory holding the ViT preprocessor config shipped next to this module.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        # Load ViT preprocessor to get pretrained normalization parameters and target size
        self.vit_processor = ViTImageProcessor.from_json_file(os.path.join(self.config_path, 'preprocessor_config.json'))
        self.target_size = self.vit_processor.size["height"]  # e.g., 224 (default ViT input size)

        # Define resizer for pseudo-images (matches real image target size)
        self.pseudo_resizer = Resize((self.target_size, self.target_size))

        # Fallback period used when no valid dominant period is found.
        self.token_len = config.token_len

    def process_real_image(self, images):
        """Process real images: automatic resizing, cropping, and normalization"""
        # Directly use ViTImageProcessor to ensure consistency with pretraining pipeline
        inputs = self.vit_processor(images=images, return_tensors="pt")
        return inputs["pixel_values"]  # Shape: [batch_size, 3, H, W]

    def _period_search(self, x):
        """Estimate the dominant period of ``x`` ([batch, length]) as
        ``length // argmax-frequency`` of the batch-averaged rFFT amplitude.
        Returns a length-1 array (see the ``list(...)[0]`` unwrap below).
        """
        xf = torch.fft.rfft(x, dim=-1)
        # find period by amplitudes
        frequency_list = abs(xf).mean(0)
        frequency_list[0] = 0  # ignore the DC component
        _, top_list = torch.topk(frequency_list, 1)
        top_list = top_list.detach().cpu().numpy()
        period = x.shape[1] // top_list
        return period

    def process_pseudo_image(self, x):
        """Process pseudo-images (converted from time series): ensure consistent normalization with real images"""

        # Segmentation: fold the series into columns of one period each.
        input_length = x.shape[-1]
        period = list(self._period_search(x))[0]
        # Fall back to token_len when the detected period is degenerate.
        period = period if 0 < period < input_length else self.token_len
        if period > input_length:
            # Only reachable when token_len itself exceeds the input length.
            period = input_length

        # Left-pad so the length divides evenly into whole periods.
        padding_length = (period - (input_length %
                                    period)) % period
        x_pad = F.pad(x, (padding_length, 0))
        # [batch, len] -> [batch, 1, period, num_periods] single-channel image.
        x_2d = einops.rearrange(x_pad, 'b (p f) -> b 1 f p', f=period)

        # 3. Render & Alignment: resize to ViT input size, tile to 3 channels.
        x_resize = self.pseudo_resizer(x_2d)
        image_input = einops.repeat(x_resize, 'b 1 h w -> b c h w', c=3)
        return image_input

    def forward(self, x, type='pseudo'):
        """Dispatch on ``type``: 'pseudo' folds a time series into an image;
        anything else is treated as real images.
        """
        if type == 'pseudo':
            return self.process_pseudo_image(x)
        else:
            return self.process_real_image(x)
124
+
125
+
126
class TextEncoder(nn.Module):
    """Frozen BERT backbone that distills token features into a fixed number
    of learnable text tokens via cross-attention.
    """

    # Directory holding the BERT config shipped next to this module.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        # Backbone built from the local config (weights are expected to come
        # from the surrounding checkpoint); kept frozen during training.
        self.model = BertModel(BertConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
        for param in self.model.parameters():
            param.requires_grad = False
        self.hidden_size = self.model.config.hidden_size
        self.output_dim = config.hidden_size
        self.num_distill = config.num_distill
        # Fixed number of token slots kept per sample after cleaning.
        self.max_length = 125

        self.projection = nn.Linear(self.hidden_size, self.output_dim)

        # Define learnable target tokens (shape: [num_distill_tokens, hidden_size])
        self.target_text_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))

        # Cross-attention: queries = target tokens, memory = text features.
        self.cross_text = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_text_cross_layers,
        )

    def extract_bert_features(self, input_dict):
        """Extract and clean BERT features with fixed output shape.

        :param input_dict: tokenizer output (``input_ids``,
            ``attention_mask``, ...), forwarded to the BERT model.
        :return: (cls_feature, raw token features, fixed-shape cleaned
            features [batch, max_length, hidden], per-sample valid counts).
        """
        outputs = self.model(**input_dict)

        last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        cls_feature = last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
        token_features = last_hidden_state

        # Create mask to exclude [CLS], [SEP], and padding tokens
        attention_mask = input_dict["attention_mask"]  # [batch_size, seq_len]
        batch_size, seq_len = attention_mask.shape
        # NOTE(review): this starts from all-ones rather than from the
        # attention mask, so padding positions are NOT actually excluded
        # (pad-token embeddings are generally non-zero and survive the
        # zero-sum filter below). Confirm whether `attention_mask.clone()`
        # was intended — changing it would alter the pretrained behavior.
        valid_mask = torch.ones_like(attention_mask)
        valid_mask[:, 0] = 0  # Exclude [CLS]

        for i in range(batch_size):
            # Last attended position is assumed to be [SEP].
            sep_pos = torch.where(attention_mask[i] == 1)[0][-1]
            valid_mask[i, sep_pos] = 0  # Exclude [SEP]

        # Apply mask and get valid tokens
        valid_token_mask = valid_mask.unsqueeze(-1).expand(-1, -1, self.hidden_size)
        clean_token_features = token_features * valid_token_mask

        # Convert to fixed shape [batch_size, max_valid_tokens, hidden_size]
        fixed_features = torch.zeros(batch_size, self.max_length, self.hidden_size,
                                     device=clean_token_features.device)
        valid_counts = []

        for i in range(batch_size):
            # Get valid tokens (excluding zeros)
            valid_tokens = clean_token_features[i][clean_token_features[i].sum(dim=1) != 0]
            valid_count = valid_tokens.shape[0]
            valid_counts.append(valid_count)

            # Truncate if longer than max_length, else pad with zeros
            if valid_count > self.max_length:
                fixed_features[i] = valid_tokens[:self.max_length]
            else:
                fixed_features[i, :valid_count] = valid_tokens

        return cls_feature, token_features, fixed_features, valid_counts

    def forward(self, texts):
        """Return fixed-shape token features [batch_size, max_valid_tokens, hidden_size]"""
        _, _, fixed_features, _ = self.extract_bert_features(texts)
        fixed_features = self.projection(fixed_features)

        # Broadcast learnable queries over the batch and distill.
        target_text_tokens = self.target_text_tokens.unsqueeze(0).repeat(fixed_features.shape[0], 1, 1)

        output_tokens = self.cross_text(target_text_tokens, fixed_features)
        return output_tokens
205
+
206
+
207
class ModalityConnector(nn.Module):
    """Injects distilled text/vision tokens into the time-series stream.

    Two independent cross-attention stacks let the time-series tokens
    (queries) attend to the text and vision memories respectively.
    """

    def __init__(self, config: AuroraConfig):
        """
        Args:
            config: Aurora configuration supplying hidden size, head count,
                FFN size, dropout, and the per-modality layer counts.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attribute names (connect_text / connect_vision) are part of the
        # checkpoint state-dict layout and must not change.
        self.connect_text = self._build_connector(config, config.num_text_connect_layers)
        self.connect_vision = self._build_connector(config, config.num_vision_connect_layers)

    @staticmethod
    def _build_connector(config, num_layers):
        """One cross-attention stack: series queries attend to modality tokens."""
        layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout_rate,
            batch_first=True,
        )
        return nn.TransformerDecoder(
            layer,
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=num_layers,
        )

    def forward(self, x, text_features, vision_features):
        """
        Fuse modality features into the time-series stream.
        Args:
            x: Time Series with shape [batch_size, n, hidden_size] (n is time series token count)
            text_features: Text features with shape [batch_size, T, hidden_size] (T is text token count)
            vision_features: Vision features with shape [batch_size, V, hidden_size] (V is vision token count)
        Returns:
            Tuple of (text-conditioned tokens or None when no text is given,
            vision-conditioned tokens), each [batch_size, n, hidden_size].
        """
        from_text = None if text_features is None else self.connect_text(x, text_features)
        from_vision = self.connect_vision(x, vision_features)
        return from_text, from_vision
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df2fb96852a59515a14552d5bddc35c03588b6a8bea69355984b3dd926a72b58
3
+ size 843564328
modeling_aurora.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch import nn
9
+ from transformers import PreTrainedModel
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
12
+
13
+ from .configuration_aurora import AuroraConfig
14
+ from .flow_loss import FlowLoss
15
+ from .modality_connector import ModalityConnector, VisionEncoder, TextEncoder
16
+ from .prototype_retriever import PrototypeRetriever
17
+ from .ts_generation_mixin import TSGenerationMixin
18
+ from .util_functions import resample, Transpose, causal_attention_mask, RoPE_decoder
19
+
20
+
21
class AuroraPatchEmbedding(nn.Module):
    """Tokenizes a 1-D series into patch embeddings.

    Training: embeds both the raw series and several frequency-masked
    augmentations. Inference: supports an arbitrary token length by
    resampling the trained projection weights.
    """

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.proj_layer = nn.Linear(config.token_len, config.hidden_size, bias=False)
        self.token_len = config.token_len
        self.threshold_ratio = config.threshold_ratio
        self.mask_ratio = config.mask_ratio

    def _freq_masking(self, x):
        """Build one augmented copy of ``x`` per threshold ratio by zeroing
        either the low or the high rFFT band (a random high-/low-pass filter).

        :param x: [batch, length] real-valued series (already padded).
        :return: [num_ratios * batch, length] stacked augmentations.
        """
        x_fft = torch.fft.rfft(x, dim=-1)
        x_ifft_list = []
        for ratio in self.threshold_ratio:
            temp = x_fft.clone()
            truncation = int(temp.shape[-1] * ratio)
            if random.random() > self.mask_ratio:
                # High-pass: drop the lowest `truncation` frequency bins.
                temp[:, :truncation] = 0
            else:
                # Low-pass: drop everything above `truncation`.
                temp[:, truncation:] = 0

            # NOTE(review): irfft without n= reconstructs an even length;
            # assumes the padded length is even — confirm token_len is even.
            x_ifft = torch.fft.irfft(temp, dim=-1)
            x_ifft_list.append(x_ifft)
        x_ifft = torch.stack(x_ifft_list, dim=0)
        # Fold the augmentation axis into the batch axis.
        return rearrange(x_ifft, 's b l -> (s b) l')

    def _predict(self, x, inference_token_len=48):
        """Inference-time patching with an arbitrary token length.

        Left-pads ``x`` to a multiple of ``inference_token_len``, unfolds it
        into non-overlapping patches, and embeds them with the trained
        projection weights resampled to the new patch length.
        """
        input_length = x.shape[-1]
        padding_length = (inference_token_len - (input_length %
                                                 inference_token_len)) % inference_token_len
        x = F.pad(x, (padding_length, 0))
        x = x.unfold(dimension=-1, size=inference_token_len,
                     step=inference_token_len)

        # Adapt the trained [hidden, token_len] weights to the new patch length.
        resampled_weight = resample(old=self.proj_layer.weight.data, new_patch_len=inference_token_len)

        output = F.linear(x, resampled_weight)

        # Second element mirrors the (origin, masked) training return shape.
        return output, None

    def forward(self, x, inference_token_len=48):
        """Return (embeddings of the raw series, embeddings of the
        frequency-masked augmentations). The second element is None in
        eval mode.
        """
        if not self.training:
            return self._predict(x, inference_token_len)

        input_length = x.shape[-1]
        # Left-pad to a multiple of token_len.
        padding_length = (self.token_len - (input_length %
                                            self.token_len)) % self.token_len
        x = F.pad(x, (padding_length, 0))

        x_masked = self._freq_masking(x)

        # Non-overlapping patches of token_len samples each.
        x_origin = x.unfold(dimension=-1, size=self.token_len,
                            step=self.token_len)
        output_origin = self.proj_layer(x_origin)

        x_masked = x_masked.unfold(dimension=-1, size=self.token_len,
                                   step=self.token_len)
        output_masked = self.proj_layer(x_masked)

        return output_origin, output_masked
79
+
80
+
81
class AuroraAttention(nn.Module):
    """Multi-head attention used for both self- and cross-attention.

    When ``key_embedding``/``value_embedding`` are omitted in ``forward``,
    this acts as self-attention over ``hidden_states``. Rotary position
    embeddings are applied to Q/K when ``rope=True``.
    """

    def __init__(self, config: AuroraConfig, layer_idx: Optional[int] = None, rope: bool = False):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # NOTE(review): assumes hidden_size is divisible by num_attention_heads.
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = config.dropout_rate
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rope = rope

    def _scaled_dot_product_attention(self, Q, K, V, bias=None, attn_mask=None):
        """Scaled dot-product attention with optional mask and additive bias.

        ``attn_mask`` may be boolean (True = masked out) or additive floats.
        Returns (weighted values, scores). Note the returned scores are the
        pre-softmax logits (after mask/bias), not the softmax weights.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))
        attn_scores = attn_scores / math.sqrt(Q.size(-1))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_scores = attn_scores.masked_fill(attn_mask, float('-inf'))
            else:
                attn_scores = attn_scores + attn_mask

        if bias is not None:
            # Tile the bias over the batch when the batch was expanded
            # (e.g. by the frequency-masking augmentation).
            if attn_scores.shape[0] > bias.shape[0]:
                bias = bias.repeat(attn_scores.shape[0] // bias.shape[0], 1, 1, 1)
            attn_scores += bias

        attn_weights = F.softmax(attn_scores, dim=-1)

        # Dropout on the attention weights, training time only.
        if self.attention_dropout > 0.0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)

        attn_output = torch.matmul(attn_weights, V)

        return attn_output, attn_scores

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_embedding: torch.Tensor = None,
        value_embedding: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        bias: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention with queries from ``hidden_states`` and keys/values
        from ``key_embedding``/``value_embedding`` (defaulting to
        self-attention). Returns (output, scores or None).
        """
        bsz, q_len, _ = hidden_states.size()

        # Self-attention unless explicit K/V sources are provided.
        if key_embedding is None:
            key_embedding = hidden_states
        if value_embedding is None:
            value_embedding = hidden_states

        _, k_len, _ = key_embedding.size()
        _, v_len, _ = value_embedding.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(key_embedding)
        value_states = self.v_proj(value_embedding)

        # Split into heads: [bsz, num_heads, len, head_dim].
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(
            bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.rope:
            query_states, key_states = RoPE_decoder(query_states, key_states)

        attn_output, attn_scores = self._scaled_dot_product_attention(
            Q=query_states, K=key_states, V=value_states, bias=bias,
            attn_mask=attention_mask)

        # Merge heads back and project out.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_scores = None

        return attn_output, attn_scores
165
+
166
+
167
class AuroraFFN(nn.Module):
    """Position-wise feed-forward network: expand, activate, contract.

    The activation is looked up in the transformers ``ACT2FN`` registry by
    its configured name (e.g. ``"silu"``).
    """

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        stages = [
            nn.Linear(hidden_size, intermediate_size),
            ACT2FN[hidden_act],
            nn.Linear(intermediate_size, hidden_size),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, hidden_state):
        """Apply the two-layer MLP along the last dimension."""
        return self.ffn(hidden_state)
176
+
177
+
178
class AuroraDecoderLayer(nn.Module):
    """Decoder block: causal self-attention over forecast tokens, RoPE
    cross-attention into the encoder sequence, then a position-wise FFN.

    Residuals use the post-norm form ``residual + norm(sublayer_out)``.
    """

    def __init__(self, config: AuroraConfig, layer_idx: int):
        super().__init__()
        # Self-attention runs without RoPE; cross-attention applies RoPE.
        self.self_attn = AuroraAttention(config, layer_idx, rope=False)
        self.cross_attn = AuroraAttention(config, layer_idx, rope=True)

        self.ffn_layer = AuroraFFN(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act
        )
        if config.norm_mode == 'batch':
            # BatchNorm1d normalizes over channels, hence the transposes.
            self.norm1 = nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(config.hidden_size), Transpose(1, 2))
            self.norm2 = nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(config.hidden_size), Transpose(1, 2))
            self.norm3 = nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(config.hidden_size), Transpose(1, 2))
        else:
            self.norm1 = torch.nn.LayerNorm(config.hidden_size)
            self.norm2 = torch.nn.LayerNorm(config.hidden_size)
            self.norm3 = torch.nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run one decoder block.

        Args:
            hidden_states: forecast-token states (batch, tokens, hidden).
            cross_states: encoder outputs used as keys/values.
            output_attentions: when False both returned weight tensors are None.

        Returns:
            (output_states, self_attn_weights, cross_attn_weights).
        """
        residual = hidden_states

        # Causal mask so each forecast token only attends to earlier tokens.
        num_token = hidden_states.shape[1]
        attention_mask = causal_attention_mask(num_token).to(hidden_states.device)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        x_attn = residual + self.norm1(hidden_states)

        # Cross Attention into the encoder outputs.
        # BUGFIX: propagate output_attentions — previously the attention
        # module defaulted to False, so cross_attn_weights was always None
        # even when attentions were requested by the caller.
        x_cross, cross_attn_weights = self.cross_attn(hidden_states=x_attn, key_embedding=cross_states,
                                                      value_embedding=cross_states,
                                                      output_attentions=output_attentions)
        x_cross = self.norm2(x_cross) + x_attn

        # Fully Connected
        output_states = self.ffn_layer(x_cross)
        output_states = self.norm3(output_states) + x_cross

        if not output_attentions:
            self_attn_weights = None
            cross_attn_weights = None

        return output_states, self_attn_weights, cross_attn_weights
231
+
232
+
233
class AuroraEncoderLayer(nn.Module):
    """Encoder block: (non-causal) self-attention with an optional additive
    bias, followed by a position-wise FFN; each sublayer is wrapped in
    dropout + residual + norm."""

    def __init__(self, config: "AuroraConfig", layer_idx: int):
        super().__init__()

        def build_norm():
            # BatchNorm normalizes over the channel axis, hence the transposes;
            # otherwise plain LayerNorm over the hidden dimension.
            if config.norm_mode == 'batch':
                return nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(config.hidden_size), Transpose(1, 2))
            return torch.nn.LayerNorm(config.hidden_size)

        self.self_attn = AuroraAttention(config, layer_idx, rope=False)
        self.ffn_layer = AuroraFFN(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

        self.norm1 = build_norm()
        self.norm2 = build_norm()

        self.dropout_1 = nn.Dropout(config.dropout_rate)
        self.dropout_2 = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        bias: torch.Tensor = None,
        **kwargs
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run one encoder block.

        Args:
            hidden_states: (batch, tokens, hidden) input states.
            output_attentions: when False, attention weights are dropped.
            bias: optional additive attention bias forwarded to self-attention.

        Returns:
            (output_states, self_attn_weights); weights are None unless requested.
        """
        skip = hidden_states

        # Self Attention
        attn_out, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
            bias=bias
        )
        attended = self.norm1(skip + self.dropout_1(attn_out))

        # Fully Connected
        transformed = self.ffn_layer(attended)
        transformed = self.norm2(self.dropout_2(transformed) + attended)

        return transformed, (attn_weights if output_attentions else None)
277
+
278
+
279
class AuroraPredictHead(nn.Module):
    """Linear read-out mapping hidden states to time-series patches.

    Training uses the fixed (hidden_size -> token_len) projection with
    dropout; evaluation resamples the projection weights so the head can
    emit patches of an arbitrary ``inference_token_len``.
    """

    def __init__(self, config: "AuroraConfig"):
        super().__init__()
        self.output_proj = nn.Linear(config.hidden_size, config.token_len, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def _predict(self, hidden_states: torch.Tensor, inference_token_len=48):
        # Resample the learned projection to the requested patch length and
        # apply it as a bias-free linear map.
        base_weight = self.output_proj.weight.data.T
        adapted_weight = resample(old=base_weight, new_patch_len=inference_token_len).T
        return F.linear(hidden_states, adapted_weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        inference_token_len: int = 48,
        **kwargs
    ) -> torch.FloatTensor:
        """Project hidden states to patch values.

        Uses the dropout-regularized training projection in train mode and
        the resampled projection otherwise.
        """
        if self.training:
            return self.output_proj(self.dropout(hidden_states))
        return self._predict(hidden_states, inference_token_len)
300
+
301
+
302
class AuroraPreTrainedModel(PreTrainedModel):
    """Base class wiring Aurora modules into the transformers ecosystem
    (weight init, from_pretrained/save_pretrained, device placement)."""

    # Config class consumed by from_pretrained.
    config_class = AuroraConfig
    # Attribute name of the backbone inside the task model.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layers that must not be split across devices by accelerate.
    _no_split_modules = ["AuroraEncoderLayer", "AuroraDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = False
310
+
311
+
312
class AuroraModel(nn.Module):
    """Aurora backbone.

    Patch-embeds the input series, encodes it with an optional
    vision/text-guided attention bias, fuses modality features, then decodes
    ``predict_token_num`` forecast-token states with a decay-initialized
    decoder input.

    ``last_hidden_state`` carries the tuple
    ``(x_rec, x_dec, from_text, from_vision)``: reconstruction states for the
    masked copies (or None), decoder states for the forecast tokens, and the
    two modality-connector projections.
    """

    def __init__(self, config: AuroraConfig):
        super().__init__()
        # BUGFIX: forward() reads self.config (output_attentions,
        # output_hidden_states, use_return_dict, token_len) but the config
        # was never stored, which raised AttributeError on every call.
        self.config = config

        self.embed_layer = AuroraPatchEmbedding(config)
        self.enc_layers = nn.ModuleList(
            [AuroraEncoderLayer(config, layer_idx)
             for layer_idx in range(config.num_enc_layers)]
        )
        self.dec_layers = nn.ModuleList(
            [AuroraDecoderLayer(config, layer_idx)
             for layer_idx in range(config.num_dec_layers)]
        )
        # One masked copy of the batch per threshold ratio.
        self.mask_num = len(config.threshold_ratio)
        self.gradient_checkpointing = False

        self.VisionEncoder = VisionEncoder(config)
        self.TextEncoder = TextEncoder(config)
        self.ModalityConnector = ModalityConnector(config)

        # Guider attentions produce cross-modal attention maps used to build
        # the encoder's additive bias.
        self.VisionGuider = AuroraAttention(config)
        self.TextGuider = AuroraAttention(config)

        # Bilinear coupling between vision and text attention maps.
        self.W = nn.Parameter(torch.eye(config.num_distill))
        self.fuse = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        text_input_ids: Optional[torch.FloatTensor] = None,
        text_attention_mask: Optional[torch.FloatTensor] = None,
        text_token_type_ids: Optional[torch.FloatTensor] = None,
        vision_ids: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        predict_token_num: Optional[int] = None,
        inference_token_len: Optional[int] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        """Encode the series plus optional text/vision context and decode
        forecast-token states.

        NOTE(review): ``attention_mask`` is accepted for API compatibility
        but is not used in this forward pass.
        """
        # input_ids is the input of time series, its shape is [batch_size, seq_len]
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds")
        if inference_token_len is None:
            inference_token_len = self.config.token_len

        # Patch-embed the raw series; the embed layer also returns embeddings
        # of masked copies used for the reconstruction objective.
        masked_embeds = None
        if inputs_embeds is None:
            inputs_embeds, masked_embeds = self.embed_layer(input_ids, inference_token_len)

        if masked_embeds is None:
            x_enc = inputs_embeds
        else:
            # Stack masked copies along the batch axis so the encoder
            # processes originals and masked variants in one pass.
            x_enc = torch.concat([inputs_embeds, masked_embeds], dim=0)

        # Vision features: real images when provided, otherwise a pseudo
        # rendering derived from the raw series.
        if vision_ids is not None:
            vision_features = self.VisionEncoder(vision_ids, type='real')
        else:
            vision_features = self.VisionEncoder(input_ids, type='pseudo')

        _, attn_vision = self.VisionGuider(
            inputs_embeds,
            vision_features,
            vision_features,
            output_attentions=True
        )

        if text_input_ids is not None:
            text_features = self.TextEncoder({'input_ids': text_input_ids, 'attention_mask': text_attention_mask,
                                              'token_type_ids': text_token_type_ids})
            _, attn_text = self.TextGuider(
                inputs_embeds,
                text_features,
                text_features,
                output_attentions=True
            )
        else:
            text_features = None
            attn_text = None

        # Bilinear fusion of the vision/text attention maps produces the
        # additive bias injected into encoder self-attention.
        if attn_text is not None:
            guided_bias = torch.einsum("bhik,kl,bhjl->bhij", attn_vision, self.W, attn_text)
        else:
            guided_bias = None

        # encoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for encoder_layer in self.enc_layers:
            if output_hidden_states:
                all_hidden_states += (x_enc,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    x_enc,
                    output_attentions,
                    guided_bias
                )
            else:
                layer_outputs = encoder_layer(
                    x_enc,
                    output_attentions=output_attentions,
                    bias=guided_bias
                )

            x_enc = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Split encoded masked copies back off the batch axis and average
        # them into a single reconstruction stream.
        if x_enc.shape[0] > batch_size:
            x_enc, x_rec = torch.split(x_enc, [batch_size, x_enc.shape[0] - batch_size], dim=0)
            x_rec = rearrange(x_rec, '(s b) n d -> s b n d', s=self.mask_num)
            x_rec = x_rec.mean(0)
        else:
            x_rec = None

        # Decoder input: last encoder token replicated with geometric decay
        # (0.5 ** position) over the forecast horizon.
        decay_weights = 0.5 ** torch.arange(predict_token_num)
        decay_weights = decay_weights.unsqueeze(0).unsqueeze(-1).to(x_enc.device)

        from_text, from_vision = self.ModalityConnector(x_enc, text_features, vision_features)
        if from_text is not None:
            x_enc = x_enc + self.fuse(from_vision + from_text)
        else:
            x_enc = x_enc + self.fuse(from_vision)

        last_token = x_enc[:, -1:, :]
        x_dec = decay_weights * last_token.repeat(1, predict_token_num, 1)

        # decoder layers
        for decoder_layer in self.dec_layers:
            if output_hidden_states:
                all_hidden_states += (x_dec,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    x_dec,
                    x_enc,
                    output_attentions=output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    x_dec,
                    x_enc,
                    output_attentions=output_attentions
                )

            x_dec = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (x_dec,)

        if not return_dict:
            return tuple(
                v
                for v in [x_dec, all_hidden_states, all_self_attns]
                if v is not None
            )

        output_states = (x_rec, x_dec, from_text, from_vision)

        return MoeModelOutputWithPast(
            last_hidden_state=output_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
501
+
502
+
503
class AuroraForPrediction(AuroraPreTrainedModel, TSGenerationMixin):
    """Aurora forecasting head.

    Wraps :class:`AuroraModel` with a linear point-forecast head, a
    flow-matching loss/sampler, and a prototype retriever. In training mode
    (``labels`` given) it returns a combined point + flow loss; otherwise it
    samples ``num_samples`` forecast trajectories.
    """

    def __init__(self, config: AuroraConfig):
        super().__init__(config)
        self.config = config
        self.model = AuroraModel(config)
        # Element-wise MSE so outlier elements can be clipped before reduction.
        self.point_loss = torch.nn.MSELoss(reduction='none')
        self.flow_match = FlowLoss(config.token_len, config.hidden_size, config.flow_loss_depth, config.hidden_size,
                                   config.num_sampling_steps)
        self.linear_head = AuroraPredictHead(config)

        self.retriever = PrototypeRetriever(config)

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        text_input_ids: torch.FloatTensor = None,
        text_attention_mask: torch.FloatTensor = None,
        text_token_type_ids: torch.FloatTensor = None,
        vision_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        loss_masks: Optional[torch.FloatTensor] = None,
        mask_y: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        max_output_length: Optional[int] = None,
        revin: Optional[bool] = True,
        num_samples: Optional[int] = 1,
        inference_token_len: Optional[int] = 48,
    ):
        """Compute the training loss or sample forecasts.

        Args:
            input_ids: raw time series, shape [batch_size, seq_len].
            labels: future values; presence switches to training mode.
            revin: apply instance normalization (RevIN-style) to inputs and
                de-normalize outputs.
            num_samples: number of sampled trajectories at inference.

        NOTE(review): ``loss_masks`` and ``mask_y`` are accepted but unused
        in this implementation.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Number of forecast tokens needed to cover the output horizon.
        if labels is not None:
            if max_output_length is None:
                max_output_length = labels.shape[1]
            predict_token_num = math.ceil(max_output_length / self.config.token_len)
        else:
            predict_token_num = math.ceil(max_output_length / inference_token_len)

        # Instance normalization over the time axis.
        if revin:
            means = input_ids.mean(1, keepdim=True).detach()
            stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach() + 1e-5
            input_ids = (input_ids - means) / stdev

        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            text_token_type_ids=text_token_type_ids,
            vision_ids=vision_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            predict_token_num=predict_token_num,
            inference_token_len=inference_token_len
        )

        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        x_rec, x_dec, from_text, from_vision = hidden_states

        # Synthesize per-token prototype curves from the fused modality state.
        if from_text is not None:
            generated_prototypes = self.retriever(from_text + from_vision, predict_token_num)
        else:
            generated_prototypes = self.retriever(from_vision, predict_token_num)

        loss = None
        predictions = None
        # eps doubles as the outlier-clipping threshold for point losses and
        # the eps passed through to the flow-matching loss.
        eps = 1e2
        mask = None
        if labels is not None:
            # BUGFIX: origin_labels was only assigned inside `if revin:`, so
            # forward crashed with NameError when labels were given with
            # revin=False. Keep the un-normalized labels unconditionally.
            origin_labels = labels
            if revin:
                labels = (labels - means) / stdev

            # Right-pad labels to a whole number of tokens and build a mask
            # marking the real (unpadded) positions.
            origin_length = labels.shape[-1]
            target_length = predict_token_num * self.config.token_len
            if origin_length < target_length:
                pad_length = target_length - origin_length
                labels = F.pad(labels, (0, pad_length))
                mask = torch.tensor([1] * origin_length + [0] * pad_length, device=labels.device)
                mask = mask.unsqueeze(0)

            # Point forecasts for reconstruction and forecasting streams.
            reco = rearrange(self.linear_head(x_rec), 'b n p -> b (n p)')
            fore = rearrange(self.linear_head(x_dec), 'b n p -> b (n p)')
            if revin:
                fore = fore * stdev + means

            # Element-wise MSE with outlier elements (>= eps) discarded.
            reco_loss = self.point_loss(reco[:, :input_ids.shape[-1]], input_ids)
            fore_loss = self.point_loss(fore[:, :origin_length], origin_labels)
            reco_loss = reco_loss[reco_loss < eps]
            fore_loss = fore_loss[fore_loss < eps]
            point_loss = reco_loss.mean() + fore_loss.mean()

            # Flow-matching loss per forecast token, replicated
            # diffusion_batch_mul times for extra noise samples.
            shift_labels = labels.unfold(
                dimension=-1, size=self.config.token_len, step=self.config.token_len)
            bsz, L, _ = shift_labels.shape
            shift_labels = shift_labels.reshape(
                bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
            x_dec = x_dec.reshape(
                bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
            protos = generated_prototypes.reshape(bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
            flow_loss = self.flow_match(target=shift_labels, z=x_dec.detach(), prototype=protos, eps=eps, mask=mask)
            loss = point_loss + flow_loss

        else:
            # Inference: sample trajectories from the flow model conditioned
            # on decoder states and prototypes, then trim to the horizon.
            predictions = self.flow_match.sample(z=rearrange(x_dec, 'b n d -> (b n) d'),
                                                 prototype=rearrange(generated_prototypes, 'b n p -> (b n) p'),
                                                 num_samples=num_samples,
                                                 inference_token_len=inference_token_len)
            predictions = rearrange(predictions, '(b n) s p -> b s (n p)', n=predict_token_num)[:, :,
                          :max_output_length]

            # De-normalize each sampled trajectory.
            if revin:
                stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
                means = means.unsqueeze(1).repeat(1, num_samples, 1)
                predictions = (predictions * stdev) + means

        # BUGFIX: with return_dict=False the backbone returns a plain tuple,
        # so outputs.hidden_states/.attentions raised AttributeError; guard.
        return MoeCausalLMOutputWithPast(
            loss=loss,
            logits=predictions,
            hidden_states=outputs.hidden_states if return_dict else None,
            attentions=outputs.attentions if return_dict else None,
        )
prototype_retriever.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .configuration_aurora import AuroraConfig
6
+ from .util_functions import sinusoidal_position_embedding, causal_attention_mask
7
+
8
+
9
class PrototypeRetriever(nn.Module):
    """Learnable bank of time-series prototypes plus a retriever network
    that mixes them into per-forecast-token synthetic prototype curves."""

    def __init__(self, config: "AuroraConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_prototypes = config.num_prototypes
        self.token_len = config.token_len

        # Learnable prototype container; values are filled in by
        # _initialize_prototypes() below.
        self.prototypes = nn.Parameter(torch.empty(self.num_prototypes, self.token_len))
        self._initialize_prototypes()

        self.retriever = Retriever(config)

    def _initialize_prototypes(self, random_seed=42):
        """Fill ``self.prototypes`` with diverse synthetic curves.

        Each prototype is drawn from a family of generators (sin/cos, log,
        exponential, linear, or a random convex combination of them) plus
        Gaussian noise, evaluated on x in [0, 10].

        BUGFIX: uses a local ``np.random.RandomState`` instead of
        ``np.random.seed``, which used to clobber NumPy's *global* RNG state
        as a side effect of model construction. The seeded legacy
        RandomState stream is identical to the seeded global stream, so the
        generated prototype values are unchanged.
        """
        rng = np.random.RandomState(random_seed)

        length = self.token_len
        # Shared sample grid for all generators.
        x = np.linspace(0, 10, length)

        prototypes_list = []

        # --- Internal generation functions ---
        def generate_sin():
            """Sine wave with random frequency, amplitude and phase."""
            freq = rng.uniform(0.3, 2.0)
            amp = rng.uniform(0.5, 2.0)
            phase = rng.uniform(0, np.pi)
            return amp * np.sin(freq * x + phase)

        def generate_cos():
            """Cosine wave with random frequency, amplitude and phase."""
            freq = rng.uniform(0.3, 2.0)
            amp = rng.uniform(0.5, 2.0)
            phase = rng.uniform(0, np.pi)
            return amp * np.cos(freq * x + phase)

        def generate_log():
            """Logarithmic trend; the random shift keeps the argument positive."""
            x_log = x + rng.uniform(0.5, 2.0)
            slope = rng.uniform(0.3, 1.5)
            offset = rng.uniform(-2.0, 2.0)
            return slope * np.log(x_log) + offset

        def generate_exponential():
            """Exponential trend; growth may be negative (decay)."""
            growth = rng.uniform(-0.3, 0.3)
            amp = rng.uniform(0.5, 2.0)
            return amp * np.exp(growth * x)

        def generate_linear():
            """Straight-line trend."""
            slope = rng.uniform(-1.0, 1.0)
            intercept = rng.uniform(-2.0, 2.0)
            return slope * x + intercept

        def generate_combination():
            """Convex mix (Dirichlet weights) of sine, line, and exp-or-log."""
            weights = rng.dirichlet(np.ones(3))
            func1 = generate_sin()
            func2 = generate_linear()
            func3 = generate_exponential() if rng.random_sample() > 0.5 else generate_log()
            return weights[0] * func1 + weights[1] * func2 + weights[2] * func3

        # Generator families with their sampling probabilities.
        functions = [
            (generate_sin, 0.2),
            (generate_cos, 0.2),
            (generate_log, 0.15),
            (generate_exponential, 0.15),
            (generate_linear, 0.1),
            (generate_combination, 0.2)
        ]
        funcs, probs = zip(*functions)

        # --- Prototype generation loop ---
        for _ in range(self.num_prototypes):
            func = rng.choice(funcs, p=probs)
            prototype = func()

            # Perturb each curve with a randomly scaled Gaussian noise.
            noise_level = rng.uniform(0.05, 0.2)
            prototype = prototype + rng.normal(0, noise_level, length)

            prototypes_list.append(prototype)

        prototypes_np = np.array(prototypes_list)

        # Copy into the Parameter in place (keeps it registered for
        # gradients); numpy defaults to float64, torch modules use float32.
        tensor_data = torch.from_numpy(prototypes_np).float()
        self.prototypes.data.copy_(tensor_data)

    def forward(self, x, output_token_len):
        """
        Args:
            x: Input representation with shape [B, k, d]
            output_token_len: number of forecast tokens F to synthesize.
        Returns:
            synthetic_protos: [B, F, p] (instance-normalized)
        """
        # Prototype mixture weights [B, F, M].
        dist = self.retriever(x, output_token_len)

        # Weighted combination of prototypes [B, F, p].
        synthetic_protos = torch.matmul(dist, self.prototypes)

        # Instance-normalize each synthesized curve; the generators span
        # wide value ranges with additive noise, so this keeps outputs stable.
        mean = synthetic_protos.mean(dim=-1, keepdim=True).detach()
        std = synthetic_protos.std(dim=-1, keepdim=True).detach() + 1e-5
        synthetic_protos = (synthetic_protos - mean) / std

        return synthetic_protos
143
+
144
+
145
class Retriever(nn.Module):
    """Encoder-decoder network that maps a token representation sequence to
    a softmax distribution over the prototype bank, one distribution per
    requested output token."""

    def __init__(self, config: AuroraConfig):
        super().__init__()
        # Pre-norm input projection.
        self.input_emb = nn.Sequential(nn.LayerNorm(config.hidden_size),
                                       nn.Linear(config.hidden_size, config.hidden_size))
        # Causally-masked TransformerEncoder over the input tokens.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_retriever_enc_layers,
        )
        # "Decoder" is another masked TransformerEncoder run over the
        # repeated summary token plus positional embeddings.
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_retriever_dec_layers,
        )

        self.head = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),  # Combine context and position information
            nn.LayerNorm(config.intermediate_size),
            nn.SiLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.intermediate_size, config.num_prototypes),  # Predict prototype distribution
            nn.Softmax(dim=-1)
        )

        self.hidden_size = config.hidden_size

    def forward(self, x, output_token_len):
        """Return per-token prototype distributions.

        Args:
            x: input representations, shape [B, k, d].
            output_token_len: number of output tokens F.

        Returns:
            dist: [B, F, M] softmax weights over the M prototypes.
        """
        x_encoded = self.input_emb(x)
        # Squeeze the (1, 1, k, k) causal mask down to the 2-D (k, k) form
        # expected by nn.TransformerEncoder.
        enc_attn_mask = causal_attention_mask(x.shape[1]).to(x.device)
        enc_output = self.encoder(x_encoded, mask=enc_attn_mask.squeeze(0).squeeze(0))  # Shape: [B, k, d]

        # Keep only the last (summary) token.
        enc_output = enc_output[:, -1:, :]

        # Replicate it across the output horizon.
        dec = enc_output.repeat(1, output_token_len, 1)

        # Positional information distinguishes the replicated tokens.
        pos_embeds = sinusoidal_position_embedding(
            batch_size=dec.shape[0], num_heads=1,
            max_len=output_token_len, output_dim=self.hidden_size,
            device=dec.device).squeeze(1)

        embeds = dec + pos_embeds

        dec_attn_mask = causal_attention_mask(output_token_len).to(x.device)
        dec_output = self.decoder(embeds, mask=dec_attn_mask.squeeze(0).squeeze(0))

        dist = self.head(dec_output)  # Shape: [B, F, M]

        return dist
ts_generation_mixin.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Union, Callable
3
+
4
+ import torch
5
+ from transformers import BertTokenizer
6
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
7
+ from transformers.generation.utils import GenerationConfig, GenerateOutput
8
+ from transformers.utils import ModelOutput
9
+
10
+
11
class TSGenerationMixin(GenerationMixin):
    """Generation mixin for time-series forecasting.

    Overrides transformers' ``generate`` with a single forward pass that
    normalizes the input series (RevIN-style), optionally tokenizes raw text
    prompts, and de-normalizes the sampled forecasts.
    """

    # NOTE(review): the tokenizer is loaded at class-definition (import)
    # time from the bundled bert_config directory — this does disk I/O on
    # import; consider lazy-loading if import cost becomes an issue.
    tokenizer = BertTokenizer.from_pretrained(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config'))

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        text_inputs=None,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        num_samples: Optional[int] = 1,
        max_output_length: Optional[int] = 96,
        inference_token_len: Optional[int] = None,
        max_text_token_length: Optional[int] = 125,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        """Generate forecasts for a batch of series.

        Args:
            inputs: time series of shape [batch_size, seq_len].
            text_inputs: raw text prompts; tokenized here when given and
                overriding any pre-tokenized text_* tensors.
            revin: normalize inputs here and de-normalize the predictions
                (the model itself is called with revin=False).
            num_samples: number of sampled trajectories per series.
            max_output_length: forecast horizon in time steps.

        Returns:
            Predictions tensor (de-normalized when revin=True).

        Raises:
            ValueError: if ``inputs`` is not 2-D.
        """
        if len(inputs.shape) != 2:
            raise ValueError('Input shape must be: [batch_size, seq_len]')
        # Instance normalization over the time axis; statistics are kept to
        # de-normalize the predictions at the end.
        if revin:
            means = inputs.mean(dim=-1, keepdim=True)
            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev
        if text_inputs is not None:
            tokenized_text = self._tokenize(text_inputs, max_length=max_text_token_length)
            text_input_ids = tokenized_text['input_ids'].squeeze(0)
            text_attention_mask = tokenized_text['attention_mask'].squeeze(0)
            # Fall back to all-zero token types if the tokenizer returns none.
            text_token_type_ids = tokenized_text.get('token_type_ids', torch.zeros_like(text_input_ids)).squeeze(0)

        model_inputs = self.prepare_inputs_for_generation(
            inputs,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            text_token_type_ids=text_token_type_ids,
            vision_inputs=vision_inputs,
            generation_config=generation_config,
            max_output_length=max_output_length,
            inference_token_len=inference_token_len,
            **kwargs
        )

        # revin=False: normalization already applied above.
        outputs = self(**model_inputs, return_dict=True, revin=False, num_samples=num_samples)

        predictions = outputs.logits

        # De-normalize each of the num_samples trajectories.
        if revin:
            stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
            means = means.unsqueeze(1).repeat(1, num_samples, 1)
            predictions = (predictions * stdev) + means

        return predictions

    def prepare_inputs_for_generation(
        self,
        inputs: torch.Tensor,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        max_output_length: Optional[int] = None,
        inference_token_len: Optional[int] = None,
        **kwargs
    ):
        """Pack generate() arguments into the keyword dict expected by forward()."""
        return {
            "input_ids": inputs,
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "text_token_type_ids": text_token_type_ids,
            "vision_ids": vision_inputs,
            "max_output_length": max_output_length,
            "inference_token_len": inference_token_len,
            **kwargs
        }

    def _tokenize(self, texts, max_length):
        """Tokenize text prompts to fixed-length padded tensors."""
        return self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """No-op: generation is a single forward pass, so kwargs never change."""
        return model_kwargs
util_functions.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
def resize(x_tensor, new_shape):
    """Linearly interpolate a 2-D tensor along its last dimension.

    A leading batch axis is added because F.interpolate with mode='linear'
    expects 3-D (N, C, L) input, and removed again afterwards.
    """
    batched = x_tensor.unsqueeze(0)
    stretched = F.interpolate(batched, size=new_shape, mode='linear')
    return stretched.squeeze(0)
11
+
12
+
13
def resample(old: torch.Tensor, new_patch_len: int):
    """Resample a (d_model, patch_size) weight matrix to a new patch length.

    The mapping is built by linearly resizing an identity basis and applying
    its pseudo-inverse, scaled by sqrt(new/old) to preserve magnitude.
    Returns ``old`` unchanged when the patch length already matches.
    """
    assert old.dim() == 2, "the size of input tensor should be (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    kernels = old.T
    src_len = kernels.size(0)
    scale = new_patch_len / src_len

    # Resize an identity basis to obtain the interpolation operator, then
    # pseudo-invert it to project the old kernels onto the new length.
    identity = torch.eye(src_len, dtype=torch.get_default_dtype(), device=kernels.device)
    projection = resize(identity, new_patch_len).T
    projection_pinv = torch.linalg.pinv(projection.T)

    adapted = projection_pinv @ kernels * math.sqrt(scale)

    return adapted.T
29
+
30
+
31
def RoPE(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    # Get the shape information of the input tensors
    batch_size, num_heads, max_len, output_dim = query.shape
    # Generate sinusoidal position embeddings
    pos_emb = sinusoidal_position_embedding(batch_size, num_heads, max_len, output_dim, query.device, factor=1)

    # Extract cosine and sine position embeddings; repeat_interleave duplicates
    # each value so cos/sin align with the (even, odd) channel pairs below.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # Apply RoPE to the query tensor: build the rotated companion
    # (-q_odd, q_even) interleaved back to the original layout, then combine
    # as q*cos + rot(q)*sin (the standard 2-D rotation per channel pair).
    query_rot = torch.stack([-query[..., 1::2], query[..., ::2]], dim=-1).reshape(query.shape)
    query = query * cos_pos + query_rot * sin_pos

    # Apply RoPE to the key tensor (same rotation as for the query)
    key_rot = torch.stack([-key[..., 1::2], key[..., ::2]], dim=-1).reshape(key.shape)
    key = key * cos_pos + key_rot * sin_pos

    return query, key
60
+
61
+
62
def RoPE_decoder(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors in the decoder.

    Query and key may have different lengths; positions are laid out as
    [key positions | query positions], so the query takes the final q_len
    slots of a combined (k_len + q_len) position range.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, q_max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, k_max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, heads, q_len, dim = query.shape
    k_len = key.size(2)

    pos_emb = sinusoidal_position_embedding(bs, heads, k_len + q_len, dim, query.device,
                                            factor=1)

    # Duplicate each cos/sin entry so it lines up with the interleaved channel pairs.
    cos_part = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_part = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate_half_pairs(t: torch.Tensor) -> torch.Tensor:
        # (-x1, x0, -x3, x2, ...): 90-degree rotation of every 2-D channel pair.
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    # Query uses the trailing q_len positions, key the leading k_len positions.
    query = query * cos_part[:, :, -q_len:, :] + _rotate_half_pairs(query) * sin_part[:, :, -q_len:, :]
    key = key * cos_part[:, :, :k_len, :] + _rotate_half_pairs(key) * sin_part[:, :, :k_len, :]

    return query, key
93
+
94
+
95
def sinusoidal_position_embedding(
        batch_size: int,
        num_heads: int,
        max_len: int,
        output_dim: int,
        device: torch.device,
        factor: float = 1.0,
        base: float = 10000.0,
) -> torch.Tensor:
    """
    Generate sinusoidal position embeddings for rotary attention.

    Args:
        batch_size (int): Batch size.
        num_heads (int): Number of attention heads.
        max_len (int): Maximum sequence length.
        output_dim (int): Output dimension (assumed even; channels come in sin/cos pairs).
        device (torch.device): Device on which the embeddings are created.
        factor (float, optional): Scaling factor. Values > 1 oversample the
            position axis and then subsample back to max_len. Defaults to 1.0.
        base (float, optional): Frequency base of the sinusoids (the RoPE
            "theta"). Defaults to 10000.0, the classic Transformer value.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, num_heads, max_len, output_dim)
        with the last dimension interleaved as [sin, cos, sin, cos, ...].
    """
    # Position indices; built directly on the target device to avoid a CPU->device copy.
    position = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float, device=device).unsqueeze(-1)
    # Inverse frequencies: base^(-2i / output_dim) for each channel pair i.
    ids = torch.arange(0, output_dim // 2, dtype=torch.float, device=device)
    theta = torch.pow(base, -2 * ids / output_dim)

    # Angles per (position, frequency); stack sin/cos as the trailing pair dimension.
    angles = position * theta
    embeddings = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Broadcast to (batch_size, num_heads, positions, output_dim).
    embeddings = embeddings.repeat((batch_size, num_heads, *([1] * len(embeddings.shape))))
    embeddings = torch.reshape(embeddings, (batch_size, num_heads, -1, output_dim))

    # With factor > 1 the position axis was oversampled; subsample back to max_len.
    if factor > 1.0:
        interpolation_indices = torch.linspace(0, embeddings.shape[2] - 1, max_len).long()
        embeddings = embeddings[:, :, interpolation_indices, :]

    return embeddings
138
+
139
+
140
def causal_attention_mask(seq_length):
    """Build an additive causal mask of shape (1, 1, seq_length, seq_length).

    Entries strictly above the diagonal are -inf (future positions blocked);
    all other entries are 0, so the mask can be added to attention scores.
    """
    neg_inf = torch.full((seq_length, seq_length), float('-inf'))
    causal = torch.triu(neg_inf, diagonal=1)
    return causal[None, None, :, :]
143
+
144
+
145
class Transpose(nn.Module):
    """nn.Module wrapper around ``torch.Tensor.transpose``.

    Lets a dimension swap be placed inside ``nn.Sequential`` pipelines.

    Args:
        *dims: Dimensions to swap, forwarded verbatim to ``transpose``.
        contiguous: When True, call ``.contiguous()`` on the result.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        transposed = x.transpose(*self.dims)
        return transposed.contiguous() if self.contiguous else transposed
vit_config/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/vit-base-patch16-224-in21k",
3
+ "architectures": [
4
+ "ViTModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.0,
9
+ "hidden_size": 768,
10
+ "image_size": 224,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 3072,
13
+ "layer_norm_eps": 1e-12,
14
+ "model_type": "vit",
15
+ "num_attention_heads": 12,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 12,
18
+ "patch_size": 16,
19
+ "qkv_bias": true,
20
+ "transformers_version": "4.13.0.dev0"
21
+ }
vit_config/preprocessor_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "do_resize": true,
4
+ "image_mean": [
5
+ 0.5,
6
+ 0.5,
7
+ 0.5
8
+ ],
9
+ "image_std": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "size": 224
15
+ }