"""
The original LWM-1.1 implementation is available at:

    https://huggingface.co/wi-lab/lwm-v1.1/tree/main

We extend our highest respect to wi-lab and thank them for their outstanding
contributions to the original LWM.
"""

import os
import csv
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm

from lwm_model import lwm, lwm_tokenizer, create_dataloader


def split_and_save_indices_same_seed(manual_data_list, used_ratio=1.0,
                                     train_ratio=0.50, val_ratio=0.50):
    """Shuffle each dataset, split it into train/val index sets, keep only a
    used_ratio fraction of the training indices, and save both sets to disk."""
    all_indices = {}
    for i, data in enumerate(manual_data_list):
        total_num = data.shape[0]
        indices = np.arange(total_num)
        np.random.shuffle(indices)

        train_end = int(train_ratio * total_num)
        val_end = int((train_ratio + val_ratio) * total_num)
        used_end = int(train_end * used_ratio)  # number of training indices actually used

        train_idx = indices[:train_end][:used_end]
        val_idx = indices[train_end:val_end]

        np.savez(f"all_indices_{i}_{used_ratio}.npz", train_id=train_idx, val_id=val_idx)
        all_indices[f'array_{i}'] = [train_idx, val_idx]
    return all_indices

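# Illustrative usage (shapes are hypothetical): a single array of 1000 samples
# with the defaults yields a 50/50 split and writes "all_indices_0_1.0.npz"
# to the working directory:
#   idx = split_and_save_indices_same_seed([np.zeros((1000, 32, 32))])
#   train_idx, val_idx = idx['array_0']  # 500 shuffled indices each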

def nmse_loss(y_pred, y_true):
    """Per-sample normalized MSE: ||y_true - y_pred||^2 / ||y_true||^2."""
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat) ** 2, dim=-1)
    normalization = torch.sum(y_true_flat ** 2, dim=-1)
    return mse / normalization

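# Illustrative sanity check: when the prediction is exactly half the target,
# each sample's NMSE is ||t - t/2||^2 / ||t||^2 = 0.25:
#   t = 2 * torch.ones(2, 4)
#   nmse_loss(t / 2, t)  # -> tensor([0.2500, 0.2500])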

def train_lwm(model, train_loaders, val_loaders, optimizer, save_model, epochs, device,
              save_dir="models", log_file="training_log.csv"):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Write the CSV header once; one row is appended per epoch below.
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    start_epoch = 0
    for epoch in range(start_epoch, epochs):
        # ---- Training ----
        model.train()
        train_nmse = 0.0
        train_samples = 0

        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                    logits_lm, _, _ = model(input_ids, masked_pos)
                    # Prediction first, target second, so the NMSE is normalized
                    # by the energy of the true masked tokens.
                    loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                    loss.backward()
                    optimizer.step()
                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]
                    t.set_postfix({"nmse": train_nmse / train_samples})

        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        # ---- Validation (run every epoch) ----
        model.eval()
        val_nmse = 0.0
        val_samples = 0
        with torch.no_grad():
            print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
            for length, val_loader in val_loaders.items():
                print(f"Processing sequences of length {length}")
                with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                    for batch in t:
                        input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                        logits_lm, _, _ = model(input_ids, masked_pos)
                        loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                        val_nmse += loss.item()
                        val_samples += input_ids.shape[0]
                        t.set_postfix({"nmse": val_nmse / val_samples})

        val_nmse /= max(val_samples, 1)
        val_nmse_losses.append(val_nmse)

        # ---- Checkpointing ----
        is_best_model = False
        if val_nmse < best_val_nmse:
            best_val_nmse = val_nmse
            is_best_model = True
            if save_model:
                model_path = os.path.join(
                    save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")

        print(f"  Train NMSE: {train_nmse:.4f}")
        print(f"  Validation NMSE: {val_nmse:.4f}")

        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch + 1, train_nmse, val_nmse, optimizer.param_groups[0]['lr'], is_best_model])

    print("Training and validation complete.")
    return model

def generate_mask_pos(num, total_num, allow_point_num):
    """Keep `num` evenly spaced runs of `allow_point_num // num` consecutive
    positions (1-indexed) unmasked and return all remaining positions as masked."""
    all_pos = np.arange(1, total_num + 1)
    init_pos = 1
    inter = int(np.ceil(total_num / num))
    L = int(allow_point_num / num)
    un_msk_pos = np.array([init_pos + l + i * inter for i in range(num) for l in range(L)])
    msk_pos = np.setdiff1d(all_pos, un_msk_pos)
    return msk_pos

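# Illustrative example: generate_mask_pos(num=2, total_num=8, allow_point_num=4)
# keeps two evenly spaced runs [1, 2] and [5, 6] unmasked and returns
# array([3, 4, 7, 8]) as the masked positions.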

def merge_dicts(dict_list):
    """Concatenate list values across a list of dicts, keyed by dict key."""
    merged = defaultdict(list)
    for d in dict_list:
        for key, value in d.items():
            merged[key].extend(value)
    return dict(merged)

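# Illustrative example:
#   merge_dicts([{'a': [1]}, {'a': [2], 'b': [3]}])  # -> {'a': [1, 2], 'b': [3]}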


if __name__ == '__main__':

    # Run configuration
    SAVE_DIR = "model"
    LOG_FILE = "training.csv"
    MASK_PERCENT = 0.90
    save_model = False
    scenario_name = "Boston_28G"
    gpu_ids = [0]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    EPOCHS = 20000
    BATCH_SIZE = 128
    D_MODEL = 128
    MAX_LEN = 513
    N_LAYERS = 12
    WEIGHT_DECAY = 0.05
    BETA1 = 0.9
    BETA2 = 0.999
    N_HEADS = 8
    DROPOUT = 0.1
    BASE_LR = 5e-5
    SEED = 0
    TEST = False

    # Fix the random seeds so the index split, mask positions, and shuffling
    # are reproducible.
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    train_generator = torch.Generator()
    train_generator.manual_seed(SEED)

    # Load the raw dataset and split it into train/val index sets.
    manual_data = [np.load(f"./dataset/{scenario_name}.npy")]
    indices_dict = split_and_save_indices_same_seed(manual_data, used_ratio=1.0)

    # For each dataset, draw the patch positions to mask (1-indexed,
    # MASK_PERCENT of the sequence length), shared between train and val.
    ranges = [data.shape[1] for data in manual_data]
    steps = [MASK_PERCENT]
    mask_pos_list = [
        [np.sort(np.random.choice(np.arange(1, range_max + 1),
                                  size=int(range_max * step), replace=False))
         for step in steps]
        for range_max in ranges
    ]

    # Tokenize the training split, applying masking at the fixed positions.
    pre_train_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_train_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][0]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    preprocessed_train_data = {}
    for i in range(len(pre_train_dict)):
        preprocessed_train_data[i] = pre_train_dict[i][0]

    # create_dataloader returns one DataLoader per sequence length, which
    # train_lwm iterates over.
    train_loaders = create_dataloader(preprocessed_train_data, batch_size=BATCH_SIZE,
                                      shuffle=True, generator=train_generator)

    # Tokenize the validation split with the same mask positions.
    pre_val_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_val_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][1]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    preprocessed_val_data = {}
    for i in range(len(pre_val_dict)):
        preprocessed_val_data[i] = pre_val_dict[i][0]

    val_loaders = create_dataloader(preprocessed_val_data, batch_size=BATCH_SIZE, shuffle=False)

    # Build the model and warm-start it from a pretrained checkpoint. The
    # "module." prefix is stripped in case the checkpoint was saved from a
    # DataParallel-wrapped model.
    model = lwm(d_model=D_MODEL, dropout=DROPOUT).to(device)
    pretrained_lwm_dict = torch.load("./ExtLWM_sub16.pth", map_location=device)
    pretrained_lwm_dict = {k.replace("module.", ""): v for k, v in pretrained_lwm_dict.items()}
    model.load_state_dict(pretrained_lwm_dict, strict=False)
    model = nn.DataParallel(model, gpu_ids)

    optimizer = AdamW(
        model.parameters(),
        lr=BASE_LR,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY
    )

    pretrained_model = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        save_model,
        EPOCHS,
        device=device,
        save_dir=SAVE_DIR,
        log_file=LOG_FILE,
    )