| """ | |
| This will train one model given a config | |
| """ | |
| import sys | |
| import json | |
| import os | |
| import glob | |
| import pandas as pd | |
| from pathlib import Path | |
| from pprint import pprint | |
| from pytorch_lightning.callbacks import ( | |
| EarlyStopping, | |
| ModelCheckpoint, | |
| ModelSummary, | |
| StochasticWeightAveraging, | |
| LearningRateMonitor, | |
| ) | |
| import yaml | |
| from utils.callbacks import PredTrueDateWriter | |
| from pytorch_lightning.loggers import TensorBoardLogger | |
| from data_provider.data_module import CustomDataModule | |
| from exp.exp_timeseries import ExpTimeseries | |
| from utils.results_analysis import get_metrics, open_results | |
| from utils.tools import dotdict | |
| from utils.ipynb_helpers import bbtest_setting, read_data, setting_from_args | |
| import pytorch_lightning as pl | |
| from pytorch_lightning import Callback | |
| from pytorch_lightning.tuner import Tuner | |
| import torch | |
| # torch.set_float32_matmul_precision("high") | |
class GradHistogramLogger(Callback):
    """Callback that writes per-parameter gradient histograms to the logger.

    Args:
        log_every_n_steps: Histograms are written only on global steps that
            are a multiple of this value.
    """

    def __init__(self, log_every_n_steps=100):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps

    def on_after_backward(self, trainer, pl_module):
        """Log one histogram for each parameter that currently has a gradient."""
        if trainer.global_step % self.log_every_n_steps != 0:
            return
        if trainer.logger is None:
            # Nothing to write to when the trainer runs without a logger.
            return

        experiment = trainer.logger.experiment
        step = trainer.global_step
        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            # Tag each histogram with the parameter name; the previous code
            # wrote every parameter under one shared tag at the same step, so
            # the individual histograms could not be told apart.
            experiment.add_histogram(f"grad_norm_p_his/{name}", param.grad, step)
class GradNormTracker(Callback):
    """Callback that logs the global p-norm of all gradients after backward.

    Args:
        p: Order of the norm. ``float("inf")`` selects the maximum norm.
    """

    def __init__(self, p=2):
        super().__init__()
        self.p = p

    def on_after_backward(self, trainer, pl_module):
        """Compute the total gradient norm and log it as ``grad_norm_p``."""
        if trainer.logger is None:
            # Nothing to write to when the trainer runs without a logger.
            return

        per_param_norms = [
            param.grad.detach().norm(self.p).item()
            for param in pl_module.parameters()
            if param.grad is not None
        ]

        if self.p == float("inf"):
            # The generic formula below degenerates for p = inf
            # (x ** inf followed by ** 0), so take the maximum directly.
            total_norm = max(per_param_norms, default=0.0)
        else:
            total_norm = sum(norm ** self.p for norm in per_param_norms)
            total_norm = total_norm ** (1.0 / self.p)

        trainer.logger.experiment.add_scalar(
            "grad_norm_p", total_norm, trainer.global_step
        )
def _count_devices(devices) -> int:
    """Return how many devices a Trainer-style ``devices`` value selects.

    Lists/tuples count their entries, positive ints are taken as a device
    count, comma-separated strings count their non-empty entries, and numeric
    strings are parsed as ints. Anything else (e.g. ``-1`` or ``"auto"``)
    falls back to the number of visible CUDA devices. The result is never
    below 1 so it is always safe to divide by.
    """
    if isinstance(devices, (list, tuple)):
        return max(1, len(devices))
    if isinstance(devices, str):
        text = devices.strip()
        if "," in text:
            return max(1, sum(1 for part in text.split(",") if part.strip()))
        try:
            devices = int(text)
        except ValueError:
            devices = -1  # non-numeric selector such as "auto"
    if isinstance(devices, int) and devices > 0:
        return devices
    # NOTE(review): assumes CUDA devices because the trainer below hard-codes
    # accelerator="gpu" -- confirm if other GPU backends are used.
    return max(1, torch.cuda.device_count())


def _find_latest_ckpt(log_root) -> str:
    """Return the path of the most recently modified ``last.ckpt``.

    Searches ``<log_root>/<experiment>/version_*/checkpoints/last.ckpt``.

    Raises:
        FileNotFoundError: If ``log_root`` does not exist or contains no
            ``last.ckpt`` file in the expected layout.
    """
    log_root = Path(log_root)
    if not log_root.exists():
        raise FileNotFoundError(f"{log_root} not found")

    found = []
    # 1. Walk every experiment folder (outermost level).
    for exp_dir in log_root.iterdir():
        if not exp_dir.is_dir():
            continue
        # 2. Look for version_* folders inside this experiment.
        for version_dir in exp_dir.glob("version_*"):
            # 3. Every last.ckpt in a version's checkpoints dir is a candidate.
            for ckpt in (version_dir / "checkpoints").glob("last.ckpt"):
                found.append((ckpt.stat().st_mtime, ckpt))

    if not found:
        raise FileNotFoundError(f"No last.ckpt found under {log_root}")

    # 4. Pick the checkpoint with the newest modification time.
    _, newest = max(found, key=lambda item: item[0])
    return str(newest)


def pt_light_experiment(
    args: dotdict,
    devices: list[int] | str | int,
    logger: TensorBoardLogger | None = None,
):
    """Train, test and run prediction for one model described by ``args``.

    Args:
        args: Experiment configuration. This function reads ``seed``,
            ``no_early_stop``, ``max_epochs`` and ``load_ckpt`` and forwards
            the whole object to the data module, the model and the logger.
        devices: Device selection passed to ``pl.Trainer``; also used to
            split the CPU cores into dataloader workers per device.
        logger: Logger to use. When ``None`` a ``TensorBoardLogger`` writing
            to ``lightning_logs/<setting>`` is created.

    Returns:
        A tuple ``(log_dir, test_loop_output)`` holding the trainer's log
        directory and the value returned by ``trainer.test``.

    Raises:
        FileNotFoundError: If ``args.load_ckpt`` is set and no ``last.ckpt``
            can be found under ``lightning_logs``.
    """
    print("Args in experiment:")
    print(args)

    strategy = "ddp"  # ["ddp", "ddp_spawn", "ddp_notebook", "ddp_fork", None]
    # os.cpu_count() may return None; devices may be an int or a string, so
    # count them explicitly instead of calling len() on the raw value.
    cpu_total = os.cpu_count() or 1
    num_workers = max(1, cpu_total // _count_devices(devices))
    if strategy == "ddp_spawn":
        num_workers = 0

    if args.seed is not None:
        pl.seed_everything(seed=args.seed, workers=True)

    # Create Data Module
    data_module = CustomDataModule(args, num_workers)

    # Instantiate Lightning Model
    exp = ExpTimeseries(args)

    # Define Callbacks
    callbacks = []

    # Early Stop
    # if not args.no_early_stop:
    #     callbacks.append(
    #         EarlyStopping(
    #             monitor="val_loss",
    #             min_delta=0.000001,
    #             patience=args.patience,
    #             verbose=True,
    #             mode="min",
    #         )
    #     )

    # Keep the single checkpoint with the highest "v_e" as checkpoint.ckpt.
    # Additionally, checkpoint final model into last.ckpt if args.no_early_stop
    callbacks.append(
        ModelCheckpoint(
            filename="checkpoint",
            save_top_k=1,
            save_last=args.no_early_stop,
            verbose=False,
            # monitor="val_loss",
            monitor="v_e",
            mode="max",
        )
    )

    # Print model details
    callbacks.append(ModelSummary(max_depth=2))

    # Write data on predict
    callbacks.append(PredTrueDateWriter("result", "epoch"))

    # Stochastic Weight Averaging to improve generalization
    # TODO: Research this more
    # callbacks.append(
    #     StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=.8, device=None)
    # )
    # swa_lrs=1e-5 or lr

    # Log learning rate
    callbacks.append(LearningRateMonitor("epoch"))

    # Print all callbacks
    print("Callbacks:", [type(cb).__name__ for cb in callbacks])

    # Logger
    if logger is None:
        setting = bbtest_setting(args)
        print("Setting:", setting)
        logger = TensorBoardLogger(
            "lightning_logs", name=setting, flush_secs=15  # , default_hp_metric=False,
        )

    # Define Trainer Params
    trainer_params = {
        # "auto_lr_find": True,
        # "fast_dev_run": True, # For debugging
        # "profiler": "simple", # For looking for bottlenecks
        # "detect_anomaly": True,
        # "overfit_batches": 1,
        # "track_grad_norm": 2,
        "max_epochs": args.max_epochs,
        "accelerator": "gpu",
        "devices": devices,
        # "auto_select_gpus": True,
        # "strategy": strategy, # Multi GPU
        # "default_root_dir": f"lightning_logs/{setting}",
        "enable_model_summary": False,
        "callbacks": callbacks
        + [
            GradNormTracker(p=2),
            # GradHistogramLogger(log_every_n_steps=1000),
        ],
        "logger": logger,
        "log_every_n_steps": 25,
        # "precision": "bf16-mixed",
        # gradient_clip_val: 1.0,
        # gradient_clip_algorithm: "norm"
    }
    trainer = pl.Trainer(**trainer_params)
    trainer.logger.log_hyperparams(args)

    # # Tune model (noop unless auto_scale_batch_size or auto_lr_find)
    # tuner = Tuner(trainer)
    # new_batch_size = tuner.scale_batch_size(
    #     model=exp,
    #     train_dataloaders=data_module,
    #     mode="power"  # or "binsearch"
    # )
    # print(new_batch_size)
    # while 1:pass

    # Optionally resume from the newest last.ckpt under lightning_logs.
    latest_ckpt = None
    if args.load_ckpt:
        log_root = "lightning_logs"
        latest_ckpt = _find_latest_ckpt(log_root)
        print("✅ Latest ckpt:", latest_ckpt)

    # Train Model
    trainer.fit(exp, data_module, ckpt_path=latest_ckpt)

    # Unless early stopping is disabled, evaluate the best saved checkpoint
    # instead of the weights from the final epoch.
    if not args.no_early_stop:
        exp = ExpTimeseries.load_from_checkpoint(
            os.path.join(trainer.log_dir, "checkpoints/checkpoint.ckpt"), config=args
        )

    # Test Model
    test_loop_output = trainer.test(exp, data_module)

    # Predict and Save Results (the return value is not used here)
    trainer.predict(exp, data_module)

    print("DONE!!!! Logged in:", trainer.log_dir)
    return trainer.log_dir, test_loop_output
if __name__ == "__main__":
    # Script entry point: load a YAML config, run one experiment, then compute
    # metrics for every data group and store them next to the logs.
    args = dotdict()

    # Reference configuration, kept for documentation of the available keys.
    # args.des = "full_1h"
    # args.model = "stockformer" # 'stockformer'
    # args.root_path = "./data/stock/" # root path of data file
    # args.data_path = "full_1h.csv" # data file
    # args.freq = "h" # freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h
    # args.seed = None # Seed to control randomness, None for random seed
    # args.features = "MS" # forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate
    # args.target = "WTI_logpctchange" # target feature in S or MS task
    # args.seq_len = 16 # input sequence length of Informer encoder
    # args.label_len = 0 # start token length of Informer decoder
    # args.pred_len = 1 # prediction sequence length
    # args.cols = [
    #     "XOM_logpctchange",
    #     "CVX_logpctchange",
    #     "COP_logpctchange",
    #     "BP_logpctchange",
    #     "PBR_logpctchange",
    #     "WTI_logpctchange",
    #     "EOG_logpctchange",
    #     "ENB_logpctchange",
    #     "SLB_logpctchange",
    # ] #'C:USDSAR_logpctchange'
    # args.enc_in = len(args.cols) # encoder input size
    # # args.dec_in = len(args.cols) # decoder input size # TODO: Remove
    # args.c_out = 1 if args.features in ["S", "MS"] else args.dec_in # output size
    # args.d_model = 512 # dimension of model; also the dimension of the token embeddings
    # args.n_heads = 512 # num of attention heads
    # args.e_layers = 12 # num of encoder layers
    # args.d_ff = 4096 # dimension of fcn in model
    # args.dropout = 0.5 # dropout
    # args.dropout_emb = 0.0 # dropout for embedding
    # args.t_embed = None # time features encoding, options:[timeF, fixed, learned, None, time2vec_add, time2vec_app]
    # args.activation = "gelu" # activation
    # args.attn = "full" # attention used in encoder, options:[prob, full]
    # args.factor = 5 # probsparse attn factor; doesn't matter unless args.attn==prob
    # args.distil = False # whether to use distilling in encoder
    # args.output_attention = False # whether to output attention in encoder
    # args.mix = False # whether to use mixed attention
    # args.ln_mode = "post"
    # args.batch_size = 128
    # args.learning_rate = 0.00001
    # # What loss function to use, options:["mse", "mae", "stock_dir", "stock_dir-ns", "stock_tanh", "stock_tanhv1", ...]
    # # The logic is messy
    # args.loss = "stock_tanhv1"
    # args.lradj = (
    #     None # What learning rate scheduler to use: ["type2", "type3", None, "type1"]
    # )
    # args.optim = "AdamW" # Adam, AdamW
    # args.max_epochs = 1
    # args.patience = 100 # For early stopping
    # args.scale = True # whether to scale to mean 0, var 1
    # args.no_scale_mean = True # whether to disable the mean scaling
    # args.inverse_output = (
    #     False # whether to invert-scale the model's output before loss is calculated
    # )
    # args.inverse_pred = True # whether to invert-scale the data label
    # # This is for debugging to overfit
    # # When True, patience doesn't matter at all and the model-state that is saved is the one after the last epoch
    # # When False, the model-state that is saved is the one with the highest validation-loss and we can early stop with patience
    # args.no_early_stop = False
    # # Control data split from args, either a date string like "2000-01-30" or None (for default)
    # args.date_start = "2012-01-01" # Train data starts on this date, default is to go back as far as possible
    # args.date_end = "2020-01-01" # Train data starts on this date, default is to go back as far as possible
    # args.date_test = "2019-06-01" # Test data is data after this date, default is to use ~20% of the data as test data
    # args.dont_shuffle_train = True
    # args.load_model_path = "stockformer_custom_ftMS_sl16_ll4_pl1_ei12_di12_co1_iFalse_dm512_nh8_el12_dl4_df2048_atfull_fc5_ebtimeF_dtFalse_mxFalse_pretrain_full_1h_0/checkpoint-pretrain.pth"
    # import yaml
    # with open("configs/stockformer/example.yaml", "w") as file:
    #     yaml.dump(dict(args), file)
    # config_file = "configs/lstm/basic_PEMSBAY.yaml"

    # An optional first CLI argument selects an industry-specific config;
    # without it the general config is used.
    if len(sys.argv) > 1:
        industry = sys.argv[1]
        config_file = f"configs/stockformer/general_{industry}.yaml"
    else:
        config_file = "configs/stockformer/general.yaml"

    devices = [0]

    with open(config_file, "r") as file:
        args = dotdict(yaml.full_load(file))

    def split_dataset(args, segment_months=15):
        """Split [date_start, date_end] into consecutive date segments.

        Each segment spans ``segment_months`` months and ends the day before
        the next one starts; the last segment is clipped to ``date_end``.

        Returns:
            A list of ``{"start": ..., "end": ...}`` dicts holding ISO dates.
        """
        start = pd.Timestamp(args['date_start'])
        end = pd.Timestamp(args['date_end'])
        segments = []
        current_start = start
        while current_start < end:
            current_end = (
                current_start
                + pd.DateOffset(months=segment_months)
                - pd.Timedelta(days=1)
            )
            if current_end > end:
                current_end = end
            segments.append({
                "start": current_start.date().isoformat(),
                "end": current_end.date().isoformat(),
            })
            current_start = current_end + pd.Timedelta(days=1)
        return segments

    df = read_data(os.path.join(args.root_path, args.data_path))

    # Optional walk-forward run over date segments (currently disabled):
    # for idx, seg in enumerate(split_dataset(args, 6)):
    #     if idx < 0:
    #         continue
    #     args['date_start'] = seg["start"]
    #     args['date_end'] = seg['end']
    log_dir, test_loop_output = pt_light_experiment(args, devices)

    tpd_dict = open_results(log_dir, args, df)

    # Compute and print the metrics of every data group.
    metrics = {}
    for data_group in tpd_dict:
        group = tpd_dict[data_group]
        true = group["trues"]
        pred = group["preds"]
        metrics[data_group] = get_metrics(args, pred, true, 0.0)
        print(data_group, end="\t")
        pprint(metrics[data_group], indent=3)

    with open(os.path.join(log_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)