WaveLSFromer / run_once.py
ducheng678
Initial WaveLSFromer project
093b0a5
Raw
History Blame Contribute Delete
13.4 kB
"""
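Train, test, and run prediction for a single model described by a YAML config,
then write the resulting metrics to ``metrics.json`` in the run's log directory.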
"""
import sys
import json
import os
import glob
import pandas as pd
from pathlib import Path
from pprint import pprint
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
ModelSummary,
StochasticWeightAveraging,
LearningRateMonitor,
)
import yaml
from utils.callbacks import PredTrueDateWriter
from pytorch_lightning.loggers import TensorBoardLogger
from data_provider.data_module import CustomDataModule
from exp.exp_timeseries import ExpTimeseries
from utils.results_analysis import get_metrics, open_results
from utils.tools import dotdict
from utils.ipynb_helpers import bbtest_setting, read_data, setting_from_args
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.tuner import Tuner
import torch
# torch.set_float32_matmul_precision("high")
class GradHistogramLogger(Callback):
    """Log a histogram of every parameter's gradient to the trainer's logger.

    Histograms are written after a backward pass whenever
    ``trainer.global_step`` is a multiple of ``log_every_n_steps``.

    Args:
        log_every_n_steps: Step interval between histogram dumps.
    """

    def __init__(self, log_every_n_steps=100):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps

    def on_after_backward(self, trainer, pl_module):
        """Write one gradient histogram per parameter that has a gradient."""
        if trainer.global_step % self.log_every_n_steps != 0:
            return
        # Nothing to write to when the Trainer was built without a logger.
        if trainer.logger is None:
            return
        # assumes the logger's ``experiment`` exposes ``add_histogram``
        # (true for TensorBoard's SummaryWriter) — TODO confirm for other loggers
        writer = trainer.logger.experiment
        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            # One tag per parameter: a single shared tag would have every
            # parameter's histogram collide at the same (tag, step) pair.
            writer.add_histogram(
                f"grad_norm_p_his/{name}", param.grad, trainer.global_step
            )
class GradNormTracker(Callback):
    """Log the global p-norm of all parameter gradients after each backward pass.

    The value is written as the scalar ``"grad_norm_p"`` at
    ``trainer.global_step``.

    Args:
        p: Order of the norm. ``float("inf")`` yields the largest per-parameter
            infinity norm.
    """

    def __init__(self, p=2):
        super().__init__()
        self.p = p

    def on_after_backward(self, trainer, pl_module):
        """Compute the combined gradient norm and send it to the logger."""
        # Nothing to write to when the Trainer was built without a logger.
        if trainer.logger is None:
            return

        is_inf_norm = self.p == float("inf")
        total_norm = 0.0
        for param in pl_module.parameters():
            if param.grad is None:
                continue
            # detach() replaces the deprecated ``.data`` access.
            param_norm = param.grad.detach().norm(self.p).item()
            if is_inf_norm:
                # The inf-norm of a concatenation is the max of the parts;
                # the power/root formula below is meaningless for p = inf.
                total_norm = max(total_norm, param_norm)
            else:
                total_norm += param_norm ** self.p
        if not is_inf_norm:
            total_norm = total_norm ** (1.0 / self.p)

        trainer.logger.experiment.add_scalar(
            "grad_norm_p", total_norm, trainer.global_step
        )
def _count_devices(devices):
    """Return how many devices a Trainer ``devices`` spec refers to (>= 1).

    Handles the three shapes allowed by ``pt_light_experiment``'s signature:
    a sequence of device indices, an int, or a string.

    NOTE(review): the string / non-positive-int handling follows this
    reviewer's reading of the Lightning ``Trainer(devices=...)`` docs
    ("0,1" is an index list, "2" is a count, -1/"auto" means all visible
    devices) — confirm against the installed Lightning version.
    """
    if isinstance(devices, str):
        spec = devices.strip()
        if "," in spec:
            return max(1, sum(1 for tok in spec.split(",") if tok.strip()))
        try:
            devices = int(spec)
        except ValueError:
            devices = -1  # e.g. "auto": fall through to "all visible devices"
    if isinstance(devices, int):
        if devices > 0:
            return devices
        return max(1, torch.cuda.device_count())
    return max(1, len(devices))


def _find_latest_ckpt(log_root):
    """Return the most recently modified ``last.ckpt`` below ``log_root``.

    Searches ``<log_root>/<experiment>/version_*/checkpoints/last.ckpt``.

    Raises:
        FileNotFoundError: if ``log_root`` does not exist or no ``last.ckpt``
            is found under it.
    """
    log_root = Path(log_root)
    if not log_root.exists():
        raise FileNotFoundError(f"{log_root} not found")

    # One glob covers: every experiment folder -> every version_* folder ->
    # its checkpoints/last.ckpt.
    candidates = list(log_root.glob("*/version_*/checkpoints/last.ckpt"))
    if not candidates:
        raise FileNotFoundError(f"No last.ckpt found under {log_root}/")

    # Pick the checkpoint with the newest modification time.
    return str(max(candidates, key=lambda ckpt: ckpt.stat().st_mtime))


def pt_light_experiment(
    args: dotdict,
    devices: list[int] | str | int,
    logger: TensorBoardLogger | None = None,
):
    """Train, test, and run prediction for one model described by ``args``.

    Args:
        args: Experiment configuration. Fields read here: ``seed``,
            ``no_early_stop``, ``max_epochs``, ``load_ckpt``; the whole object
            is also handed to the data module, the model, and the logger.
        devices: Device spec forwarded to ``pl.Trainer(devices=...)``.
        logger: Logger to use. When ``None`` a ``TensorBoardLogger`` writing to
            ``lightning_logs/<setting>`` is created.

    Returns:
        A ``(log_dir, test_loop_output)`` tuple: the trainer's log directory
        and whatever ``trainer.test`` returned.

    Raises:
        FileNotFoundError: if ``args.load_ckpt`` is truthy and no ``last.ckpt``
            can be found under ``lightning_logs``.
    """
    print("Args in experiment:")
    print(args)

    strategy = "ddp"  # ["ddp", "ddp_spawn", "ddp_notebook", "ddp_fork", None]
    # Split the CPUs between devices. os.cpu_count() may return None, and
    # `devices` may be an int or str, so neither can be used blindly.
    if strategy == "ddp_spawn":
        # presumably dataloader workers are disabled under spawn — original
        # code forced 0 here without stating why
        num_workers = 0
    else:
        cpu_total = os.cpu_count() or 1
        num_workers = max(1, cpu_total // _count_devices(devices))

    if args.seed is not None:
        pl.seed_everything(seed=args.seed, workers=True)

    # Create Data Module
    data_module = CustomDataModule(args, num_workers)

    # Instantiate Lightning Model
    exp = ExpTimeseries(args)

    # Define Callbacks
    callbacks = []

    # Early Stop
    # if not args.no_early_stop:
    #     callbacks.append(
    #         EarlyStopping(
    #             monitor="val_loss",
    #             min_delta=0.000001,
    #             patience=args.patience,
    #             verbose=True,
    #             mode="min",
    #         )
    #     )

    # Checkpoint the model with the highest "v_e" into checkpoint.ckpt.
    # Additionally, checkpoint the final model into last.ckpt if
    # args.no_early_stop.
    checkpoint_cb = ModelCheckpoint(
        filename="checkpoint",
        save_top_k=1,
        save_last=args.no_early_stop,
        verbose=False,
        # monitor="val_loss",
        monitor="v_e",
        mode="max",
    )
    callbacks.append(checkpoint_cb)

    # Print model details
    callbacks.append(ModelSummary(max_depth=2))

    # Write data on predict
    callbacks.append(PredTrueDateWriter("result", "epoch"))

    # Stochastic Weight Averaging to improve generalization
    # TODO: Research this more
    # callbacks.append(
    #     StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=.8, device=None)
    # )
    # swa_lrs=1e-5 or lr

    # Log learning rate
    callbacks.append(LearningRateMonitor("epoch"))

    # Print all callbacks (class names only)
    print("Callbacks:", [type(cb).__name__ for cb in callbacks])

    # Logger
    if logger is None:
        setting = bbtest_setting(args)
        print("Setting:", setting)
        logger = TensorBoardLogger(
            "lightning_logs", name=setting, flush_secs=15  # , default_hp_metric=False,
        )

    # Define Trainer Params
    trainer_params = {
        # "auto_lr_find": True,
        # "fast_dev_run": True,  # For debugging
        # "profiler": "simple",  # For looking for bottlenecks
        # "detect_anomaly": True,
        # "overfit_batches": 1,
        # "track_grad_norm": 2,
        "max_epochs": args.max_epochs,
        "accelerator": "gpu",
        "devices": devices,
        # "auto_select_gpus": True,
        # "strategy": strategy,  # Multi GPU
        # "default_root_dir": f"lightning_logs/{setting}",
        "enable_model_summary": False,
        "callbacks": callbacks
        + [
            GradNormTracker(p=2),
            # GradHistogramLogger(log_every_n_steps=1000),
        ],
        "logger": logger,
        "log_every_n_steps": 25,
        # "precision": "bf16-mixed",
        # gradient_clip_val: 1.0,
        # gradient_clip_algorithm: "norm"
    }
    trainer = pl.Trainer(**trainer_params)
    trainer.logger.log_hyperparams(args)

    # # Tune model (noop unless auto_scale_batch_size or auto_lr_find)
    # tuner = Tuner(trainer)
    # new_batch_size = tuner.scale_batch_size(
    #     model=exp,
    #     train_dataloaders=data_module,
    #     mode="power"  # or "binsearch"
    # )
    # print(new_batch_size)
    # while 1:pass

    # Optionally resume from the newest last.ckpt of any previous run.
    latest_ckpt = None
    if args.load_ckpt:
        latest_ckpt = _find_latest_ckpt("lightning_logs")
        print("✅ Latest ckpt:", latest_ckpt)

    # Train Model
    trainer.fit(exp, data_module, ckpt_path=latest_ckpt)

    if not args.no_early_stop:
        # Evaluate the best checkpoint rather than the final weights. Ask the
        # callback where it actually saved it; fall back to the conventional
        # location if it reports nothing.
        best_ckpt = checkpoint_cb.best_model_path or os.path.join(
            trainer.log_dir, "checkpoints/checkpoint.ckpt"
        )
        exp = ExpTimeseries.load_from_checkpoint(best_ckpt, config=args)

    # Test Model
    test_loop_output = trainer.test(exp, data_module)

    # Predict and Save Results. The return value is not needed here;
    # presumably the PredTrueDateWriter callback persists the predictions.
    trainer.predict(exp, data_module)

    print("DONE!!!! Logged in:", trainer.log_dir)
    return trainer.log_dir, test_loop_output
def split_dataset(args, segment_months: int = 15) -> list[dict[str, str]]:
    """Cut ``[args['date_start'], args['date_end']]`` into consecutive segments.

    Each segment spans ``segment_months`` calendar months (inclusive end date);
    the last one is clamped to ``date_end``. Segments are contiguous: each one
    starts the day after the previous one ends.

    Args:
        args: Mapping with ``'date_start'`` and ``'date_end'`` entries that
            ``pd.Timestamp`` can parse.
        segment_months: Length of each segment in months; must be >= 1.

    Returns:
        A list of ``{"start": <ISO date>, "end": <ISO date>}`` dicts. Empty when
        ``date_start`` is not strictly before ``date_end``.

    Raises:
        ValueError: if ``segment_months`` is smaller than 1 (the loop below
            would never advance).
    """
    if segment_months < 1:
        raise ValueError(f"segment_months must be >= 1, got {segment_months}")

    range_end = pd.Timestamp(args["date_end"])
    seg_start = pd.Timestamp(args["date_start"])
    one_day = pd.Timedelta(days=1)
    step = pd.DateOffset(months=segment_months)

    segments = []
    while seg_start < range_end:
        # Inclusive end: the day before the next segment would begin,
        # never past the overall end date.
        seg_end = min(seg_start + step - one_day, range_end)
        segments.append(
            {
                "start": seg_start.date().isoformat(),
                "end": seg_end.date().isoformat(),
            }
        )
        seg_start = seg_end + one_day
    return segments


def main() -> None:
    """Load a YAML config, run one experiment, and dump metrics to JSON.

    The optional first command-line argument selects
    ``configs/stockformer/general_<arg>.yaml``; without it
    ``configs/stockformer/general.yaml`` is used.
    """
    # ------------------------------------------------------------------
    # Reference: the config keys, shown as they would be set in code.
    # Kept commented out; the live configuration comes from the YAML file.
    # ------------------------------------------------------------------
    # args = dotdict()
    # args.des = "full_1h"
    # args.model = "stockformer"  # 'stockformer'
    # args.root_path = "./data/stock/"  # root path of data file
    # args.data_path = "full_1h.csv"  # data file
    # args.freq = "h"  # freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h
    # args.seed = None  # Seed to control randomness, None for random seed
    # args.features = "MS"  # forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate
    # args.target = "WTI_logpctchange"  # target feature in S or MS task
    # args.seq_len = 16  # input sequence length of Informer encoder
    # args.label_len = 0  # start token length of Informer decoder
    # args.pred_len = 1  # prediction sequence length
    # args.cols = [
    #     "XOM_logpctchange",
    #     "CVX_logpctchange",
    #     "COP_logpctchange",
    #     "BP_logpctchange",
    #     "PBR_logpctchange",
    #     "WTI_logpctchange",
    #     "EOG_logpctchange",
    #     "ENB_logpctchange",
    #     "SLB_logpctchange",
    # ]  # 'C:USDSAR_logpctchange'
    # args.enc_in = len(args.cols)  # encoder input size
    # # args.dec_in = len(args.cols)  # decoder input size  # TODO: Remove
    # args.c_out = 1 if args.features in ["S", "MS"] else args.dec_in  # output size
    # args.d_model = 512  # dimension of model; also the dimension of the token embeddings
    # args.n_heads = 512  # num of attention heads
    # args.e_layers = 12  # num of encoder layers
    # args.d_ff = 4096  # dimension of fcn in model
    # args.dropout = 0.5  # dropout
    # args.dropout_emb = 0.0  # dropout for embedding
    # args.t_embed = None  # time features encoding, options:[timeF, fixed, learned, None, time2vec_add, time2vec_app]
    # args.activation = "gelu"  # activation
    # args.attn = "full"  # attention used in encoder, options:[prob, full]
    # args.factor = 5  # probsparse attn factor; doesn't matter unless args.attn==prob
    # args.distil = False  # whether to use distilling in encoder
    # args.output_attention = False  # whether to output attention in encoder
    # args.mix = False  # whether to use mixed attention
    # args.ln_mode = "post"
    # args.batch_size = 128
    # args.learning_rate = 0.00001
    # # What loss function to use, options:["mse", "mae", "stock_dir", "stock_dir-ns", "stock_tanh", "stock_tanhv1", ...]
    # # The logic is messy
    # args.loss = "stock_tanhv1"
    # args.lradj = (
    #     None  # What learning rate scheduler to use: ["type2", "type3", None, "type1"]
    # )
    # args.optim = "AdamW"  # Adam, AdamW
    # args.max_epochs = 1
    # args.patience = 100  # For early stopping
    # args.scale = True  # whether to scale to mean 0, var 1
    # args.no_scale_mean = True  # whether to disable the mean scaling
    # args.inverse_output = (
    #     False  # whether to invert-scale the model's output before loss is calculated
    # )
    # args.inverse_pred = True  # whether to invert-scale the data label
    # # This is for debugging to overfit
    # # When True, patience doesn't matter at all and the model-state that is saved is the one after the last epoch
    # # When False, the model-state that is used is the best checkpoint according to the
    # # monitored validation metric, and we can early stop with patience
    # args.no_early_stop = False
    # # Control data split from args, either a date string like "2000-01-30" or None (for default)
    # args.date_start = "2012-01-01"  # Train data starts on this date, default is to go back as far as possible
    # args.date_end = "2020-01-01"  # Train data ends on this date (NOTE(review): original comment was copy-pasted from date_start — confirm the default)
    # args.date_test = "2019-06-01"  # Test data is data after this date, default is to use ~20% of the data as test data
    # args.dont_shuffle_train = True
    # args.load_model_path = "stockformer_custom_ftMS_sl16_ll4_pl1_ei12_di12_co1_iFalse_dm512_nh8_el12_dl4_df2048_atfull_fc5_ebtimeF_dtFalse_mxFalse_pretrain_full_1h_0/checkpoint-pretrain.pth"
    # import yaml
    # with open("configs/stockformer/example.yaml", "w") as file:
    #     yaml.dump(dict(args), file)
    # config_file = "configs/lstm/basic_PEMSBAY.yaml"

    # Only a missing CLI argument should select the generic config; a bare
    # ``except`` here used to hide every other error as well.
    if len(sys.argv) > 1:
        industry = sys.argv[1]
        config_file = f"configs/stockformer/general_{industry}.yaml"
    else:
        config_file = "configs/stockformer/general.yaml"

    devices = [0]

    # NOTE(review): yaml.full_load resolves more tags than yaml.safe_load;
    # only point this at trusted config files.
    with open(config_file, "r") as file:
        args = dotdict(yaml.full_load(file))

    df = read_data(os.path.join(args.root_path, args.data_path))

    # To run one experiment per time segment instead of a single run:
    # for idx, seg in enumerate(split_dataset(args, 6)):
    #     if idx < 0:
    #         continue
    #     args['date_start'] = seg["start"]
    #     args['date_end'] = seg['end']
    log_dir, test_loop_output = pt_light_experiment(args, devices)

    # Collect the written predictions and score each data group.
    tpd_dict = open_results(log_dir, args, df)
    metrics = {}
    for data_group, group_data in tpd_dict.items():
        true = group_data["trues"]
        pred = group_data["preds"]
        metrics[data_group] = get_metrics(args, pred, true, 0.0)
        print(data_group, end="\t")
        pprint(metrics[data_group], indent=3)

    with open(os.path.join(log_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)


if __name__ == "__main__":
    main()