P2-ETF-DEEPWAVE-DL / train.py
P2SAMAPA's picture
[auto] Sync code from GitHub
d55eb36 verified
# train.py
# Trains Option A, B, C across all lookback windows (30, 45, 60).
# Auto-selects best lookback per model by validation MSE.
# Saves best weights + training metadata to models/
# Usage:
# python train.py --model all # train A, B, C
# python train.py --model a # train only A
# python train.py --epochs 50 # override epochs
import argparse
import json
import os
from datetime import datetime
import numpy as np
import pandas as pd
import config
import model_a
import model_b
import model_c
from data_download import load_local
from preprocess import run_preprocessing
os.makedirs(config.MODELS_DIR, exist_ok=True)
# ─── Per-model trainer ────────────────────────────────────────────────────────
def train_one_model(module, tag: str, prep: dict, epochs: int) -> dict:
"""Train a single model on a single lookback. Returns val metrics."""
model, history = module.train(prep, epochs=epochs)
val_loss = min(history.history["val_loss"])
val_acc = max(history.history.get("val_accuracy", [0]))
print(f" [{tag}] lb={prep['lookback']} "
f"val_loss={val_loss:.6f} val_acc={val_acc:.4f}")
return {"val_mse": val_loss, "val_acc": val_acc,
"lookback": prep["lookback"], "model": tag}
# ─── Lookback selector ────────────────────────────────────────────────────────
def select_best_lookback(results: list) -> int:
"""Return lookback with lowest val_mse."""
best = min(results, key=lambda r: r["val_mse"])
print(f" Best lookback = {best['lookback']}d "
f"(val_mse={best['val_mse']:.6f})")
return best["lookback"]
# ─── Main ─────────────────────────────────────────────────────────────────────
def run_training(models_to_train: list, epochs: int, start_year: int = None, wavelet: str = None):
print(f"\n{'='*60}")
print(f" Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f" Models: {models_to_train} | Max epochs: {epochs}")
print(f"{'='*60}")
# Load data
data = load_local()
if not data:
raise RuntimeError("No data found. Run data_download.py first.")
# Preprocess for all lookbacks
preps = {}
for lb in config.LOOKBACKS:
preps[lb] = run_preprocessing(data, lb)
training_summary = {}
# ── Option A ──────────────────────────────────────────────────────────────
if "a" in models_to_train:
print("\n" + "─"*50)
print(" OPTION A: Wavelet-CNN-LSTM")
print("─"*50)
results_a = []
for lb in config.LOOKBACKS:
r = train_one_model(model_a, "A", preps[lb], epochs)
results_a.append(r)
best_lb_a = select_best_lookback(results_a)
training_summary["model_a"] = {
"best_lookback": best_lb_a,
"results": results_a,
}
# ── Option B ──────────────────────────────────────────────────────────────
if "b" in models_to_train:
print("\n" + "─"*50)
print(" OPTION B: Wavelet-Attention-CNN-LSTM")
print("─"*50)
results_b = []
for lb in config.LOOKBACKS:
r = train_one_model(model_b, "B", preps[lb], epochs)
results_b.append(r)
best_lb_b = select_best_lookback(results_b)
training_summary["model_b"] = {
"best_lookback": best_lb_b,
"results": results_b,
}
# ── Option C ──────────────────────────────────────────────────────────────
if "c" in models_to_train:
print("\n" + "─"*50)
print(" OPTION C: Wavelet-Parallel-Dual-Stream-CNN-LSTM")
print("─"*50)
results_c = []
for lb in config.LOOKBACKS:
r = train_one_model(model_c, "C", preps[lb], epochs)
results_c.append(r)
best_lb_c = select_best_lookback(results_c)
training_summary["model_c"] = {
"best_lookback": best_lb_c,
"results": results_c,
}
# ── Save training summary ─────────────────────────────────────────────────
training_summary["trained_at"] = datetime.now().isoformat()
training_summary["epochs"] = epochs
training_summary["start_year"] = start_year
training_summary["wavelet"] = wavelet
summary_path = os.path.join(config.MODELS_DIR, "training_summary.json")
with open(summary_path, "w") as f:
json.dump(training_summary, f, indent=2)
print(f"\n Training summary β†’ {summary_path}")
print(f"\n{'='*60}")
print(" Training complete.")
print(f"{'='*60}")
return training_summary
# ─── CLI ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="all",
help="all | a | b | c | a,b | b,c etc.")
parser.add_argument("--epochs", type=int, default=config.MAX_EPOCHS)
parser.add_argument("--start_year", type=int, default=None,
help="Training start year (e.g. 2015). Stamped into summary.")
parser.add_argument("--wavelet", default=None,
help="Wavelet key (e.g. db4). Stamped into summary.")
args = parser.parse_args()
if args.model == "all":
models = ["a", "b", "c"]
else:
models = [m.strip().lower() for m in args.model.split(",")]
run_training(models, args.epochs,
start_year=args.start_year,
wavelet=args.wavelet)