# -*- coding: utf-8 -*-
"""
The original LWM-1.1 implementation is available at:
https://huggingface.co/wi-lab/lwm-v1.1/tree/main

We extend our highest respect to wi-lab and thank them for their outstanding
contributions to the original LWM.
"""
from tqdm import tqdm
import os
import csv
import torch
import torch.nn as nn
from lwm_model import lwm, lwm_tokenizer, create_dataloader
import numpy as np
from torch.optim import AdamW
from collections import defaultdict


def split_and_save_indices_same_seed(manual_data_list, used_ratio=1.0, train_ratio=0.50, val_ratio=0.50):
    """Shuffle each dataset, split its sample indices into train/val sets,
    keep only the first `used_ratio` fraction of the training indices, and
    save both index sets to disk."""
    all_indices = {}
    for i, data in enumerate(manual_data_list):
        total_num = data.shape[0]
        indices = np.arange(total_num)
        np.random.shuffle(indices)
        train_end = int(train_ratio * total_num)
        val_end = int((train_ratio + val_ratio) * total_num)
        used_end = int(train_end * used_ratio)
        train_idx = indices[:train_end]
        train_idx = train_idx[:used_end]  # optionally use only part of the training split
        val_idx = indices[train_end:val_end]
        all_idx_list = [train_idx, val_idx]
        np.savez(f"all_indices_{i}_{used_ratio}.npz", train_id=train_idx, val_id=val_idx)
        all_indices[f'array_{i}'] = all_idx_list
    return all_indices


def nmse_loss(y_pred, y_true):
    """Per-sample normalized MSE: ||y_true - y_pred||^2 / ||y_true||^2."""
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat) ** 2, dim=-1)
    normalization = torch.sum(y_true_flat ** 2, dim=-1)
    return mse / normalization
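
# A quick worked example of nmse_loss (a sketch; the helper name below is ours,
# not part of the original script): when the prediction is exactly half the
# target, the per-sample NMSE is ||y - y/2||^2 / ||y||^2 = 0.25 regardless of
# the target's scale.
def _nmse_loss_example():
    y_true = torch.tensor([[2.0, 4.0]])
    y_pred = 0.5 * y_true
    # ((2-1)^2 + (4-2)^2) / (2^2 + 4^2) = 5 / 20 = 0.25
    return nmse_loss(y_pred, y_true)  # tensor([0.2500])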
def train_lwm(model, train_loaders, val_loaders, optimizer, save_model, epochs, device,
              save_dir="models", log_file="training_log.csv"):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize CSV log
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')
    start_epoch = 0

    for epoch in range(start_epoch, epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all length buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                    logits_lm, _, _ = model(input_ids, masked_pos)
                    # NMSE is normalized by the power of the ground-truth masked tokens
                    loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                    loss.backward()
                    optimizer.step()
                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]
                    t.set_postfix({"nmse": train_nmse / train_samples})

        # Average NMSE across training samples
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        # Validation loop across all length buckets (run every epoch)
        model.eval()
        val_nmse_list = []
        val_nmse = 0.0
        val_samples = 0
        with torch.no_grad():
            print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
            for length, val_loader in val_loaders.items():
                print(f"Processing sequences of length {length}")
                with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                    for batch in t:
                        input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                        logits_lm, _, _ = model(input_ids, masked_pos)
                        batch_nmse = nmse_loss(logits_lm, masked_tokens)
                        loss = torch.sum(batch_nmse)
                        val_nmse += loss.item()
                        val_samples += input_ids.shape[0]
                        val_nmse_list.append(batch_nmse)
                        t.set_postfix({"nmse": val_nmse / val_samples})
        val_nmse /= max(val_samples, 1)
        val_nmse_losses.append(val_nmse)

        # Save model if validation NMSE improves
        is_best_model = False
        if val_nmse < best_val_nmse:
            best_val_nmse = val_nmse
            model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
            if save_model:
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")
            is_best_model = True

        # Log the results
        print(f"  Train NMSE: {train_nmse:.4f}")
        print(f"  Validation NMSE: {val_nmse:.4f}")

        # Append to CSV log
        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch + 1, train_nmse, val_nmse, optimizer.param_groups[0]['lr'], is_best_model])

    print("Training and validation complete.")
    return model


def generate_mask_pos(num, total_num, allow_point_num):
    """Return the masked positions: keep `num` evenly spaced runs of
    consecutive positions (allow_point_num unmasked points in total)
    and mask everything else."""
    total_point = total_num
    all_pos = np.arange(1, total_point + 1)
    init_pos, inter, n, L = 1, int(np.ceil(total_point / num)), num, int(allow_point_num / num)
    un_msk_pos = np.array([init_pos + l + i * inter for i in range(num) for l in range(L)])
    msk_pos = np.setdiff1d(all_pos, un_msk_pos)
    return msk_pos


def merge_dicts(dict_list):
    """Merge a list of dicts of lists, concatenating values that share a key."""
    merged = defaultdict(list)
    for d in dict_list:
        for key, value in d.items():
            merged[key].extend(value)
    return dict(merged)
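
# A small worked example of generate_mask_pos (a sketch; the helper name is
# ours): with num=2 runs, total_num=8 positions, and allow_point_num=4 unmasked
# points, the run spacing is inter = ceil(8/2) = 4 and each run holds
# L = 4/2 = 2 points, so positions {1, 2} and {5, 6} stay unmasked and the
# function returns the masked positions [3, 4, 7, 8].
def _generate_mask_pos_example():
    return generate_mask_pos(num=2, total_num=8, allow_point_num=4)  # array([3, 4, 7, 8])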

if __name__ == '__main__':
    # Manually adjust the following parameters before running
    SAVE_DIR = "model"
    LOG_FILE = "training.csv"
    MASK_PERCENT = 0.90
    save_model = False
    scenario_name = "Boston_28G"
    gpu_ids = [0]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # LWM training hyperparameters
    EPOCHS = 20000  # 10000 for 100% of the data, 14000 for 80%
    BATCH_SIZE = 128
    D_MODEL = 128
    MAX_LEN = 513
    N_LAYERS = 12
    WEIGHT_DECAY = 0.05
    BETA1 = 0.9
    BETA2 = 0.999
    N_HEADS = 8
    DROPOUT = 0.1
    BASE_LR = 5e-5
    SEED = 0
    TEST = False

    torch.manual_seed(SEED)
    np.random.seed(SEED)
    train_generator = torch.Generator()
    train_generator.manual_seed(SEED)

    manual_data = [np.load(f"./dataset/{scenario_name}.npy")]
    indices_dict = split_and_save_indices_same_seed(manual_data, used_ratio=1.0)

    ranges = [data.shape[1] for data in manual_data]
    steps = [MASK_PERCENT]
    mask_pos_list = [
        [np.sort(np.random.choice(np.arange(1, range_max + 1), size=int(range_max * step), replace=False))
         for step in steps]
        for range_max in ranges
    ]

    # Tokenize the training split for every (mask percentage, dataset) pair
    pre_train_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_train_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][0]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    preprocessed_train_data = {}
    for i in range(len(pre_train_dict)):
        preprocessed_train_data[i] = pre_train_dict[i][0]
    train_loaders = create_dataloader(preprocessed_train_data, batch_size=BATCH_SIZE,
                                      shuffle=True, generator=train_generator)

    # Tokenize the validation split the same way
    pre_val_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_val_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][1]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1

    preprocessed_val_data = {}
    for i in range(len(pre_val_dict)):
        preprocessed_val_data[i] = pre_val_dict[i][0]
    val_loaders = create_dataloader(preprocessed_val_data, batch_size=BATCH_SIZE, shuffle=False)

    # Build the LWM model and load the pretrained weights
    model = lwm(d_model=D_MODEL, dropout=DROPOUT).to(device)
    pretrained_lwm_dict = torch.load("./ExtLWM_sub16.pth", map_location=device)
    pretrained_lwm_dict = {k.replace("module.", ""): v for k, v in pretrained_lwm_dict.items()}
    model.load_state_dict(pretrained_lwm_dict, strict=False)
    model = nn.DataParallel(model, gpu_ids)

    optimizer = AdamW(
        model.parameters(),
        lr=BASE_LR,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY
    )

    pretrained_model = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        save_model,
        EPOCHS,
        device=device,
        save_dir=SAVE_DIR,
        log_file=LOG_FILE,
    )
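
    # Optional post-training sanity check (a sketch, gated behind the existing
    # TEST flag; it assumes val_loaders is a non-empty dict of DataLoaders and
    # that the model forward returns (logits_lm, _, _) as used in train_lwm):
    # score one validation batch with the trained model.
    if TEST:
        pretrained_model.eval()
        with torch.no_grad():
            first_loader = next(iter(val_loaders.values()))
            input_ids, masked_tokens, masked_pos = [b.to(device) for b in next(iter(first_loader))]
            logits_lm, _, _ = pretrained_model(input_ids, masked_pos)
            print("Mean per-sample NMSE on one batch:", nmse_loss(logits_lm, masked_tokens).mean().item())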