# -*- coding: utf-8 -*-
"""
The original LWM-1.1 implementation is available at:
https://huggingface.co/wi-lab/lwm-v1.1/tree/main
We extend our highest respect to wi-lab and thank them for their outstanding
contributions to the original LWM.
"""
import csv
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm

from lwm_model import lwm, lwm_tokenizer, create_dataloader
def split_and_save_indices_same_seed(manual_data_list, used_ratio=1.0, train_ratio=0.50, val_ratio=0.50):
    """Shuffle each dataset's row indices, split them into train/val, and save
    the split to disk. `used_ratio` optionally keeps only a leading fraction of
    the training indices. The global NumPy seed is set by the caller, so the
    split is reproducible across runs."""
    all_indices = {}
    for i, data in enumerate(manual_data_list):
        total_num = data.shape[0]
        indices = np.arange(total_num)
        np.random.shuffle(indices)
        train_end = int(train_ratio * total_num)
        val_end = int((train_ratio + val_ratio) * total_num)
        used_end = int(train_end * used_ratio)
        train_idx = indices[:train_end]
        train_idx = train_idx[:used_end]  # keep only the first used_ratio fraction of training indices
        val_idx = indices[train_end:val_end]
        all_idx_list = [train_idx, val_idx]
        np.savez(f"all_indices_{i}_{used_ratio}.npz", train_id=train_idx, val_id=val_idx)
        all_indices[f'array_{i}'] = all_idx_list
    return all_indices
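
# Example (with the defaults above): for a dataset of 100 rows, the shuffled
# indices are split 50/50 into train and val; used_ratio=0.8 would then keep
# only the first int(50 * 0.8) = 40 of the 50 training indices.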

def nmse_loss(y_pred, y_true):
    """Per-sample normalized MSE: ||y_true - y_pred||^2 / ||y_true||^2."""
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat) ** 2, dim=-1)
    normalization = torch.sum(y_true_flat ** 2, dim=-1)
    return mse / normalization
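
# Example (hypothetical values): each sample's error is normalized by its own
# target power, so
#   y_true = torch.tensor([[1., 2., 2.], [3., 0., 4.]])
#   y_pred = torch.tensor([[1., 2., 0.], [3., 0., 0.]])
#   nmse_loss(y_pred, y_true)  # -> tensor([4/9, 16/25])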

def train_lwm(model, train_loaders, val_loaders, optimizer, save_model, epochs, device, save_dir="models", log_file="training_log.csv"):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize the CSV log with a header row on first run
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')
    start_epoch = 0

    for epoch in range(start_epoch, epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all length buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                    logits_lm, _, _ = model(input_ids, masked_pos)
                    # Pass predictions first, targets second, so the NMSE is
                    # normalized by the true masked-token power
                    loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                    loss.backward()
                    optimizer.step()
                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]
                    t.set_postfix({"nmse": train_nmse / train_samples})

        # Average NMSE per training sample
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        # Validate every epoch
        if epoch % 1 == 0:
            model.eval()
            val_nmse_list = []  # per-sample NMSE tensors, kept for optional inspection
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                            logits_lm, _, _ = model(input_ids, masked_pos)
                            per_sample_nmse = nmse_loss(logits_lm, masked_tokens)
                            loss = torch.sum(per_sample_nmse)
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]
                            val_nmse_list.append(per_sample_nmse)
                            t.set_postfix({"nmse": val_nmse / val_samples})
            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Checkpoint whenever validation NMSE improves
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                is_best_model = True
                if save_model:
                    model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                    torch.save(model.state_dict(), model_path)
                    print(f"Model saved: {model_path}")

            # Log the results
            print(f"  Train NMSE: {train_nmse:.4f}")
            print(f"  Validation NMSE: {val_nmse:.4f}")

            # Append to the CSV log
            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, optimizer.param_groups[0]['lr'], is_best_model])

    print("Training and validation complete.")
    return model
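
# Usage sketch (hypothetical checkpoint name): a checkpoint saved above can
# later be restored with, e.g.:
#   state = torch.load("models/lwm_epoch12_train0.0210_val0.0234.pth", map_location=device)
#   state = {k.replace("module.", ""): v for k, v in state.items()}  # strip DataParallel prefix
#   model.load_state_dict(state)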

def generate_mask_pos(num, total_num, allow_point_num):
    """Return the 1-based positions to mask, leaving `allow_point_num` points
    unmasked in `num` evenly spaced runs of length allow_point_num // num."""
    total_point = total_num
    all_pos = np.arange(1, total_point + 1)
    init_pos, inter, n, L = 1, int(np.ceil(total_point / num)), num, int(allow_point_num / num)
    un_msk_pos = np.array([init_pos + l + i * inter for i in range(num) for l in range(L)])
    msk_pos = np.setdiff1d(all_pos, un_msk_pos)
    return msk_pos
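
# Example: generate_mask_pos(4, 16, 4) keeps positions [1, 5, 9, 13] unmasked
# (one per stride of ceil(16/4) = 4) and returns the remaining twelve
# positions [2, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16] to mask.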

def merge_dicts(dict_list):
    """Merge a list of dicts of lists, concatenating values that share a key."""
    merged = defaultdict(list)
    for d in dict_list:
        for key, value in d.items():
            merged[key].extend(value)
    return dict(merged)
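
# Example: merge_dicts([{"a": [1]}, {"a": [2], "b": [3]}]) -> {"a": [1, 2], "b": [3]}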

if __name__ == '__main__':
    # Edit the following parameters manually
    SAVE_DIR = "model"
    LOG_FILE = "training.csv"
    MASK_PERCENT = 0.90
    save_model = False
    scenario_name = "Boston_28G"
    gpu_ids = [0]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # LWM training hyperparameters
    EPOCHS = 20000  # 10000-100% 14000-80%
    BATCH_SIZE = 128
    D_MODEL = 128
    MAX_LEN = 513
    N_LAYERS = 12
    WEIGHT_DECAY = 0.05
    BETA1 = 0.9
    BETA2 = 0.999
    N_HEADS = 8
    DROPOUT = 0.1
    BASE_LR = 5e-5
    SEED = 0
    TEST = False

    # Fix all random seeds for reproducibility
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    train_generator = torch.Generator()
    train_generator.manual_seed(SEED)

    # Load the scenario data and split it into train/val indices
    manual_data = [np.load(f"./dataset/{scenario_name}.npy")]
    indices_dict = split_and_save_indices_same_seed(manual_data, used_ratio=1.0)
    ranges = [data.shape[1] for data in manual_data]
    steps = [MASK_PERCENT]
    # For each dataset and each masking ratio, draw a fixed, sorted set of
    # 1-based positions to mask
    mask_pos_list = [
        [np.sort(np.random.choice(np.arange(1, range_max + 1), size=int(range_max * step), replace=False))
         for step in steps]
        for range_max in ranges
    ]
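    # For example, with a single dataset whose rows have 32 columns and
    # MASK_PERCENT=0.90, mask_pos_list[0][0] is a sorted array of
    # int(32 * 0.9) = 28 distinct positions drawn from 1..32.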

    # Tokenize the training split once per (masking ratio, dataset) pair
    pre_train_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_train_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][0]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1
    preprocessed_train_data = {}
    for i in range(len(pre_train_dict)):
        preprocessed_train_data[i] = pre_train_dict[i][0]
    train_loaders = create_dataloader(preprocessed_train_data, batch_size=BATCH_SIZE, shuffle=True, generator=train_generator)
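    # create_dataloader appears to return a dict keyed by sequence length;
    # train_lwm iterates these buckets so every batch holds same-length sequences.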

    # Tokenize the validation split with the same fixed mask positions
    pre_val_dict = {}
    key_counter = 0
    for mask_idx in range(len(steps)):
        for data_idx in range(len(ranges)):
            pre_val_dict[key_counter] = lwm_tokenizer(
                manual_data=manual_data[data_idx][indices_dict[f'array_{data_idx}'][1]],
                patch_rows=1,
                patch_cols=16,
                mask=True,
                seed=None,
                masking_percent=MASK_PERCENT,
                mask_pos=mask_pos_list[data_idx][mask_idx]
            )
            key_counter += 1
    preprocessed_val_data = {}
    for i in range(len(pre_val_dict)):
        preprocessed_val_data[i] = pre_val_dict[i][0]
    val_loaders = create_dataloader(preprocessed_val_data, batch_size=BATCH_SIZE, shuffle=False)

    # Build the LWM model, warm-start from the pretrained checkpoint, and train
    model = lwm(d_model=D_MODEL, dropout=DROPOUT).to(device)
    pretrained_lwm_dict = torch.load("./ExtLWM_sub16.pth", map_location=device)
    # Strip the "module." prefix left by DataParallel before loading
    pretrained_lwm_dict = {k.replace("module.", ""): v for k, v in pretrained_lwm_dict.items()}
    model.load_state_dict(pretrained_lwm_dict, strict=False)
    model = nn.DataParallel(model, gpu_ids)

    optimizer = AdamW(
        model.parameters(),
        lr=BASE_LR,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY
    )

    pretrained_model = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        save_model,
        EPOCHS,
        device=device,
        save_dir=SAVE_DIR,
        log_file=LOG_FILE,
    )