Login2025 commited on
Commit
53ca419
·
verified ·
1 Parent(s): 4fa12ac

Upload 5 files

Browse files
Files changed (5) hide show
  1. ExtLWM_sub16.pth +3 -0
  2. ExtLWM_sub32.pth +3 -0
  3. ExtLWM_sub64.pth +3 -0
  4. lwm_model.py +299 -0
  5. lwm_train.py +259 -0
ExtLWM_sub16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0abccf7087201c67845cc423b62b5bcf1e869a55ab94b7be82a4e52a20804c45
3
+ size 9856811
ExtLWM_sub32.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7c0b2fb2be98455f0adaf73cec80d2b7f268e6bd27a7b00dedaa391580d5e2b
3
+ size 9807787
ExtLWM_sub64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:441dc9114b40607c4fc92828f52ee2b94b6a9aeaa5013b61b6d4a8662d6156df
3
+ size 9832619
lwm_model.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:23:54 2024
4
+
5
+ This script defines the LWM model architecture.
6
+
7
+ @author: Sadjad Alikhani
8
+ """
9
+ #%%
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ from collections import defaultdict
16
+ from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
17
+
18
+
19
+
20
def create_dataloader(grouped_data, batch_size, shuffle, generator=None):
    """Build one DataLoader per sequence-length bucket.

    Args:
        grouped_data: mapping {seq_length: list of (input_ids, masked_tokens,
            masked_pos) samples}, as produced by ``lwm_tokenizer``.
        batch_size: batch size used for every bucket.
        shuffle: whether each DataLoader shuffles its bucket.
        generator: optional ``torch.Generator`` for reproducible shuffling.

    Returns:
        dict {seq_length: DataLoader} yielding (float32 input_ids,
        float32 masked_tokens, int64 masked_pos) batches.
    """
    dataloaders = {}

    for seq_length, group in grouped_data.items():
        print(f"dataloader in progress ...\nkey: {seq_length}")

        ## Uncomment the following line if you run out of memory during pre-training
        # batch_size = batch_size // 8 if seq_length >= 5 else batch_size

        # Unpack samples for the current group
        input_ids, masked_tokens, masked_pos = zip(*group)

        # Stack with NumPy before converting: torch.tensor() on a sequence of
        # ndarrays copies element-by-element (very slow, and warns on recent
        # PyTorch versions). Values and dtypes are identical.
        input_ids_tensor = torch.as_tensor(np.asarray(input_ids), dtype=torch.float32)
        masked_tokens_tensor = torch.as_tensor(np.asarray(masked_tokens), dtype=torch.float32)
        masked_pos_tensor = torch.as_tensor(np.asarray(masked_pos), dtype=torch.long)

        # Create TensorDataset and DataLoader
        dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
        dataloaders[seq_length] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=True,
            generator=generator,
        )

    return dataloaders
42
+
43
+
44
def lwm_tokenizer(manual_data, patch_rows, patch_cols, masking_percent=.40, mask=False, mask_pos=None, seed=42):
    """Turn raw channel data into LWM samples.

    With ``mask=True`` returns a dict of length buckets (keys renumbered
    0, 1, 2, ...), each holding [input_ids, masked_tokens, masked_pos]
    samples. With ``mask=False`` returns a single stacked tensor of
    unmasked input_ids.
    """
    # Patchify the whole batch, then split into one patch array per sample.
    all_patches = list(patch_maker(np.array(manual_data), patch_rows, patch_cols))

    by_length = defaultdict(list)  # masked samples, bucketed by sequence length
    unmasked = []                  # plain tensors when mask=False

    for idx in range(len(all_patches)):
        n_patches, patch_size = all_patches[idx].shape
        n_masks_half = int(masking_percent * n_patches)

        # Special-token rows match the patch width.
        word2id = {
            '[CLS]': 0.2 * np.ones((patch_size)),
            '[MASK]': 0.1 * np.ones((patch_size))
        }

        sample = make_sample(
            idx, all_patches, word2id, n_patches, n_masks_half, mask_pos, mask=mask, seed=seed
        )

        if mask:
            by_length[len(sample[0])].append(sample)
        else:
            unmasked.append(sample)

    if mask:
        # Renumber bucket keys to 0, 1, 2, ... in increasing length order.
        return {i: by_length[key] for i, key in enumerate(sorted(by_length))}
    return torch.stack(unmasked, dim=0)
79
+
80
+
81
+
82
def make_sample(user_idx, patch, word2id, n_patches, n_masks, mask_pos=None, mask=True, seed=None):
    """Build one training sample for a single user.

    Prepends a [CLS] row to the user's patches, picks ``n_masks`` positions
    (or uses the supplied ``mask_pos``), records the original content of
    those positions, and — when ``mask`` is True — overwrites them with the
    [MASK] row.

    Returns [input_ids, masked_tokens, masked_pos] when ``mask`` is True,
    otherwise a plain tensor of the (unmasked) input_ids.
    """
    if seed is not None:
        np.random.seed(seed)

    # [CLS] occupies row 0; user patches follow.
    input_ids = np.vstack((word2id['[CLS]'], patch[user_idx]))

    # Positions are 1-based so [CLS] is never masked.
    tokens_size = int(n_patches)
    if mask_pos is None:
        masked_pos = np.random.choice(range(1, tokens_size + 1), size=n_masks, replace=False)
    else:
        masked_pos = mask_pos

    masked_tokens = []
    for pos in masked_pos:
        # Keep the original row as the reconstruction target, then mask it.
        masked_tokens.append(input_ids[pos].copy())
        if mask:
            input_ids[pos] = word2id['[MASK]']

    if mask:
        return [input_ids, masked_tokens, masked_pos]
    return torch.tensor(input_ids)
113
+
114
+
115
def patch_maker(original_ch, patch_rows=1, patch_cols=16):
    """Split complex channel matrices into flattened real/imag patches.

    Args:
        original_ch: complex array of shape (n_samples, n_rows, n_cols).
        patch_rows: patch height in rows.
        patch_cols: patch width in (complex) columns.

    Returns:
        float32 array of shape (n_samples, n_patches,
        patch_rows * patch_cols * 2), scaled by 1e6.
    """
    n_samples, n_rows, n_cols = original_ch.shape

    # Interleave real/imag parts along the last axis: [r0, i0, r1, i1, ...].
    interleaved = np.empty((n_samples, n_rows, 2 * n_cols), dtype=np.float32)
    interleaved[..., 0::2] = original_ch.real
    interleaved[..., 1::2] = original_ch.imag

    # Zero-pad so rows and columns divide evenly into patches.
    n_patch_rows = int(np.ceil(n_rows / patch_rows))
    n_patch_cols = int(np.ceil(n_cols / patch_cols))
    pad_r = n_patch_rows * patch_rows - n_rows
    pad_c = n_patch_cols * patch_cols - n_cols
    if pad_r > 0 or pad_c > 0:
        # The interleaved axis needs twice the column padding (real + imag).
        interleaved = np.pad(
            interleaved,
            ((0, 0), (0, pad_r), (0, 2 * pad_c)),
            mode='constant',
            constant_values=0,
        )

    total_rows = interleaved.shape[1]
    total_cols = interleaved.shape[2] // 2  # complex columns after padding

    # Cut the padded matrix into blocks, row-major, flattening each block.
    blocks = [
        interleaved[:, r:r + patch_rows, 2 * c:2 * (c + patch_cols)].reshape(n_samples, -1)
        for r in range(0, total_rows, patch_rows)
        for c in range(0, total_cols, patch_cols)
    ]

    # Shape: (n_samples, n_patches, patch_rows * patch_cols * 2); the 1e6
    # factor rescales the (very small) channel gains for training.
    return np.stack(blocks, axis=1) * 1e6
158
+
159
+ #%%
160
class LayerNormalization(nn.Module):
    """Layer normalization with learnable gain and shift.

    Note: eps is added to the standard deviation (not the variance), and
    ``Tensor.std`` applies Bessel's correction — both kept as-is so the
    shipped checkpoints reproduce exactly.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))  # learnable gain
        self.bias = nn.Parameter(torch.zeros(d_model))  # learnable shift

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        normalized = (x - mu) / (sigma + self.eps)
        return self.alpha * normalized + self.bias
171
+
172
+
173
class Embedding(nn.Module):
    """Project flattened patches to d_model and add learned positional
    embeddings, then layer-normalize the sum."""

    def __init__(self, element_length, d_model, max_len=513):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        self.proj = nn.Linear(element_length, d_model)   # patch -> d_model
        self.pos_embed = nn.Embedding(max_len, d_model)  # learned positions
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        # Positions 0..seq_len-1, looked up in the learned position table.
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        token_part = self.proj(x.float())
        combined = token_part + self.pos_embed(positions)
        return self.norm(combined)
189
+
190
+
191
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q·Kᵀ / sqrt(d_k)) · V."""

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k  # key dimension used only for the scaling factor

    def forward(self, Q, K, V):
        scale = np.sqrt(self.d_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights
201
+
202
+
203
class MultiHeadAttention(nn.Module):
    """Multi-head attention that returns residual + dropout(projection).

    NOTE(review): the residual is added here, and EncoderLayer adds the
    input again before its norm, so the input is effectively added twice —
    kept as-is to stay compatible with the trained checkpoints.
    """

    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads
        self.n_heads = n_heads
        # Attribute names are load-bearing for state_dict compatibility.
        self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.linear = nn.Linear(n_heads * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V):
        residual = Q
        bsz = Q.size(0)

        # (batch, seq, n_heads * d) -> (batch, n_heads, seq, d)
        def split_heads(proj, head_dim):
            return proj.view(bsz, -1, self.n_heads, head_dim).transpose(1, 2)

        q_s = split_heads(self.W_Q(Q), self.d_k)
        k_s = split_heads(self.W_K(K), self.d_k)
        v_s = split_heads(self.W_V(V), self.d_v)

        context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
        # Merge heads back: (batch, seq, n_heads * d_v), then project.
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, self.n_heads * self.d_v)
        return residual + self.dropout(self.linear(merged)), attn
226
+
227
+
228
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand
        self.fc2 = nn.Linear(d_ff, d_model)  # contract
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))
237
+
238
+
239
class EncoderLayer(nn.Module):
    """One transformer encoder block: self-attention and feed-forward
    sub-layers, each followed by Add & Norm."""

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
        self.norm1 = LayerNormalization(d_model)
        self.norm2 = LayerNormalization(d_model)

    def forward(self, enc_inputs):
        # Self-attention sub-layer with Add & Norm.
        attn_out, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        x = self.norm1(enc_inputs + attn_out)

        # Position-wise feed-forward sub-layer with Add & Norm.
        x = self.norm2(x + self.pos_ffn(x))
        return x, attn
257
+
258
+
259
class lwm(nn.Module):
    """LWM transformer encoder with a masked-patch reconstruction head."""

    def __init__(self, element_length=32, d_model=128, n_layers=12, max_len=321, n_heads=8, dropout=0.1):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, 4 * d_model, dropout) for _ in range(n_layers)
        )
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        # The decoder maps d_model back to the patch width (the input size
        # of the embedding projection), with a separate bias parameter.
        n_dim = self.embedding.proj.weight.size(1)
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
        """Build a default-configured model and load a checkpoint.

        NOTE(review): the device default assumes CUDA is available; pass
        device='cpu' on CPU-only machines.
        """
        model = cls().to(device)
        model.load_state_dict(torch.load(ckpt_name, map_location=device))
        print(f"Model loaded successfully from {ckpt_name}")
        return model

    def forward(self, input_ids, masked_pos=None):
        """Encode input_ids; optionally decode logits at masked positions.

        Returns (logits_lm, output, attention_maps) when ``masked_pos`` is
        given, otherwise (output, attention_maps).
        """
        hidden = self.embedding(input_ids)

        attention_maps = []
        for layer in self.layers:
            hidden, attn = layer(hidden)
            attention_maps.append(attn)

        if masked_pos is None:
            return hidden, attention_maps

        # Gather the hidden states at the masked positions and decode them
        # back to patch space for the reconstruction loss.
        gather_idx = masked_pos.long().unsqueeze(-1).expand(-1, -1, hidden.size(-1))
        h_masked = torch.gather(hidden, 1, gather_idx)
        h_masked = self.norm(F.relu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        return logits_lm, hidden, attention_maps
lwm_train.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ The original LWM-1.1 implementation is available at:
4
+
5
+ https://huggingface.co/wi-lab/lwm-v1.1/tree/main
6
+
7
+ We extend our highest respect to wi-lab and thank them for their outstanding contributions to original LWM.
8
+ """
9
+
10
+ from tqdm import tqdm
11
+ import os
12
+ import csv
13
+ import torch
14
+ import torch.nn as nn
15
+ from lwm_model import lwm, lwm_tokenizer, create_dataloader
16
+ import numpy as np
17
+ from torch.optim import AdamW
18
+ from collections import defaultdict
19
+
20
+
21
def split_and_save_indices_same_seed(manual_data_list, used_ratio=1.0, train_ratio=0.50, val_ratio=0.50):
    """Shuffle each dataset's sample indices, split into train/val, persist
    them to ``all_indices_{i}_{used_ratio}.npz``, and return them.

    ``used_ratio`` keeps only a leading fraction of the training split, so
    smaller pre-training subsets can be selected reproducibly. Relies on
    the caller having seeded ``np.random``.
    """
    all_indices = {}
    for i, data in enumerate(manual_data_list):
        total_num = data.shape[0]
        shuffled = np.arange(total_num)
        np.random.shuffle(shuffled)

        train_end = int(train_ratio * total_num)
        val_end = int((train_ratio + val_ratio) * total_num)
        keep = int(train_end * used_ratio)

        train_idx = shuffled[:train_end][:keep]  # optionally sub-sample train
        val_idx = shuffled[train_end:val_end]

        np.savez(f"all_indices_{i}_{used_ratio}.npz", train_id=train_idx, val_id=val_idx)
        all_indices[f'array_{i}'] = [train_idx, val_idx]
    return all_indices
37
+
38
+
39
def nmse_loss(y_pred, y_true):
    """Per-sample normalized MSE: ||y_true - y_pred||² / ||y_true||².

    Both inputs are flattened per sample; returns a 1-D tensor of length
    batch_size (the caller sums or averages it).
    """
    pred = y_pred.view(y_pred.size(0), -1)
    true = y_true.view(y_true.size(0), -1)
    err = (true - pred).pow(2).sum(dim=-1)
    denom = true.pow(2).sum(dim=-1)
    return err / denom
45
+
46
+
47
def train_lwm(model, train_loaders, val_loaders, optimizer, save_model, epochs, device, save_dir="models", log_file="training_log.csv"):
    """Masked-patch pre-training loop over length-bucketed dataloaders.

    Args:
        model: the LWM model (possibly DataParallel-wrapped).
        train_loaders / val_loaders: dicts {bucket_key: DataLoader} yielding
            (input_ids, masked_tokens, masked_pos) batches.
        optimizer: optimizer over model.parameters().
        save_model: when True, checkpoint to ``save_dir`` on val improvement.
        epochs: number of training epochs.
        device: torch device batches are moved to.
        save_dir: checkpoint directory (created if missing).
        log_file: CSV file appended with per-epoch metrics.

    Returns:
        The trained model.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize CSV log
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    start_epoch = 0

    for epoch in range(start_epoch, epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                    logits_lm, _, _ = model(input_ids, masked_pos)
                    # NOTE(review): arguments are (y_pred=masked_tokens,
                    # y_true=logits_lm), so nmse_loss normalizes by the model
                    # output rather than the targets — confirm this is intended.
                    loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                    loss.backward()
                    optimizer.step()
                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]
                    t.set_postfix({"nmse": train_nmse / train_samples})

        # Average NMSE across training batches
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        # Validate every epoch (the modulus is a tunable validation interval).
        if epoch % 1 == 0:
            # Validation loop across all buckets
            model.eval()
            val_nmse_list=[]
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                            logits_lm, _, _ = model(input_ids, masked_pos)
                            test = nmse_loss(masked_tokens, logits_lm)
                            loss = torch.sum(test)
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]
                            val_nmse_list.append(test)
                            t.set_postfix({"nmse": val_nmse / val_samples})

            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Save model if validation NMSE improves
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                if save_model:
                    torch.save(model.state_dict(), model_path)
                    print(f"Model saved: {model_path}")
                is_best_model = True

            # Log the results
            print(f" Train NMSE: {train_nmse:.4f}")
            print(f" Validation NMSE: {val_nmse:.4f}")

            # Append to CSV log
            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, optimizer.param_groups[0]['lr'], is_best_model])

    print("Training and validation complete.")
    return model
134
+
135
+
136
def generate_mask_pos(num, total_num, allow_point_num):
    """Return 1-based positions to mask, leaving ``num`` evenly spaced runs
    of ``allow_point_num / num`` consecutive positions unmasked.
    """
    run_len = int(allow_point_num / num)
    stride = int(np.ceil(total_num / num))
    # Unmasked positions: one run of `run_len` points at the start of each stride.
    kept = np.array([1 + offset + block * stride
                     for block in range(num)
                     for offset in range(run_len)])
    return np.setdiff1d(np.arange(1, total_num + 1), kept)
143
+
144
def merge_dicts(dict_list):
    """Merge a sequence of {key: list} dicts, concatenating lists per key."""
    combined = defaultdict(list)
    for mapping in dict_list:
        for key, values in mapping.items():
            combined[key].extend(values)
    return dict(combined)
150
+
151
+
152
+
153
if __name__ == '__main__':

    # Manually adjust the following parameters before running.
    SAVE_DIR = "model"
    LOG_FILE = "training.csv"
    MASK_PERCENT = 0.90
    save_model = False
    scenario_name = "Boston_28G"
    gpu_ids = [0]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # LWM training hyperparameters.
    # NOTE(review): MAX_LEN, N_LAYERS, N_HEADS and TEST are defined but not
    # passed to lwm(...) below — the model uses its own defaults for those.
    EPOCHS = 20000  # 10000-100% 14000-80%
    BATCH_SIZE = 128
    D_MODEL = 128
    MAX_LEN = 513
    N_LAYERS = 12
    WEIGHT_DECAY = 0.05
    BETA1 = 0.9
    BETA2 = 0.999
    N_HEADS = 8
    DROPOUT = 0.1
    BASE_LR = 5e-5
    SEED = 0
    TEST = False

    # Seed all RNGs so the index split, mask positions and dataloader
    # shuffling are reproducible.
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    train_generator = torch.Generator()
    train_generator.manual_seed(SEED)

    manual_data = [np.load(f"./dataset/{scenario_name}.npy")]
    indices_dict = split_and_save_indices_same_seed(manual_data, used_ratio=1.0)

    # One fixed, sorted mask-position set per dataset and per masking step.
    # NOTE(review): positions are drawn from 1..data.shape[1]; confirm this
    # matches the number of patches the tokenizer produces per sample.
    ranges = [data.shape[1] for data in manual_data]
    steps = [MASK_PERCENT]
    mask_pos_list = [[np.sort(np.random.choice(np.arange(1, range_max+1), size=int(range_max*step), replace=False)) for step in steps] for range_max in ranges]

    # Tokenize the training split, one entry per (mask step, dataset) pair.
    pre_train_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_train_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][0]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    # Each tokenizer call yields a dict of length buckets; take bucket 0 and
    # re-key the collection as 0..N-1 for the dataloader builder.
    preprocessed_train_data = {}
    for i in range(len(pre_train_dict)):
        preprocessed_train_data[i] = pre_train_dict[i][0]

    train_loaders = create_dataloader(preprocessed_train_data, batch_size=BATCH_SIZE, shuffle=True, generator=train_generator)

    # Validation pipeline mirrors the training one, using the val indices.
    pre_val_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_val_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][1]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    preprocessed_val_data = {}
    for i in range(len(pre_val_dict)):
        preprocessed_val_data[i] = pre_val_dict[i][0]

    val_loaders = create_dataloader(preprocessed_val_data, batch_size=BATCH_SIZE, shuffle=False)

    # Build the LWM model and warm-start from the sub-16 checkpoint.

    model = lwm(d_model=D_MODEL, dropout=DROPOUT).to(device)
    pretrained_lwm_dict = torch.load("./ExtLWM_sub16.pth", map_location=device)
    # Strip any DataParallel "module." prefix so keys match a bare model;
    # strict=False tolerates missing/extra keys.
    pretrained_lwm_dict = {k.replace("module.", ""): v for k, v in pretrained_lwm_dict.items()}
    model.load_state_dict(pretrained_lwm_dict, strict=False)
    model = nn.DataParallel(model, gpu_ids)

    optimizer = AdamW(
        model.parameters(),
        lr=BASE_LR,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY
    )

    pretrained_model = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        save_model,
        EPOCHS,
        device=device,
        save_dir=SAVE_DIR,
        log_file=LOG_FILE,
    )