Spaces:
Running
Running
# train.py
# Trains Option A, B, C across all lookback windows (30, 45, 60).
# Auto-selects best lookback per model by validation MSE.
# Saves best weights + training metadata to models/
# Usage:
#   python train.py --model all    # train A, B, C
#   python train.py --model a      # train only A
#   python train.py --epochs 50    # override epochs
# Standard library
import argparse
import json
import os
from datetime import datetime

# Third-party
import numpy as np
import pandas as pd

# Project-local
import config
import model_a
import model_b
import model_c
from data_download import load_local
from preprocess import run_preprocessing

# Make sure the output directory for weights/metadata exists up front.
os.makedirs(config.MODELS_DIR, exist_ok=True)
# ─── Per-model trainer ────────────────────────────────────────────────────────
def train_one_model(module, tag: str, prep: dict, epochs: int) -> dict:
    """Fit one model variant on one preprocessed lookback window.

    Args:
        module: model module exposing ``train(prep, epochs=...)`` that
            returns ``(model, history)`` (Keras-style ``History`` object).
        tag: short display label for the model ("A", "B", "C").
        prep: preprocessed-data dict; must contain a "lookback" key.
        epochs: maximum number of training epochs.

    Returns:
        dict with the best validation loss ("val_mse") and accuracy
        ("val_acc") seen during training, plus lookback and model tag.
    """
    # The fitted model itself is not needed here — only its history.
    _, history = module.train(prep, epochs=epochs)
    curves = history.history
    best_loss = min(curves["val_loss"])
    # Classification-style runs log val_accuracy; default to 0 otherwise.
    best_acc = max(curves.get("val_accuracy", [0]))
    print(f" [{tag}] lb={prep['lookback']} "
          f"val_loss={best_loss:.6f} val_acc={best_acc:.4f}")
    return {
        "val_mse": best_loss,
        "val_acc": best_acc,
        "lookback": prep["lookback"],
        "model": tag,
    }
# ─── Lookback selector ────────────────────────────────────────────────────────
def select_best_lookback(results: list) -> int:
    """Return the lookback whose result dict has the lowest "val_mse".

    Ties resolve to the earliest entry (stable sort, same as ``min``).
    """
    winner = sorted(results, key=lambda entry: entry["val_mse"])[0]
    print(f" Best lookback = {winner['lookback']}d "
          f"(val_mse={winner['val_mse']:.6f})")
    return winner["lookback"]
# ─── Main ─────────────────────────────────────────────────────────────────────
def run_training(models_to_train: list, epochs: int, start_year: int = None, wavelet: str = None):
    """Train the requested model options across all configured lookbacks.

    For each selected option ("a"/"b"/"c"), every lookback in
    ``config.LOOKBACKS`` is trained once and the lookback with the lowest
    validation MSE is recorded as best. A JSON summary is written to
    ``config.MODELS_DIR/training_summary.json``.

    Args:
        models_to_train: subset of ["a", "b", "c"].
        epochs: maximum training epochs per model.
        start_year: optional metadata stamped into the summary (may be None).
        wavelet: optional metadata stamped into the summary (may be None).

    Returns:
        The training-summary dict that was also written to disk.

    Raises:
        RuntimeError: if no local data is available.
    """
    print(f"\n{'='*60}")
    print(f" Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f" Models: {models_to_train} | Max epochs: {epochs}")
    print(f"{'='*60}")

    data = load_local()
    if not data:
        raise RuntimeError("No data found. Run data_download.py first.")

    # Preprocess once per lookback so every model reuses the same arrays.
    preps = {lb: run_preprocessing(data, lb) for lb in config.LOOKBACKS}

    # (module, tag, banner title) per option — replaces three copy-pasted
    # per-option sections with a single data-driven loop.
    specs = {
        "a": (model_a, "A", "Wavelet-CNN-LSTM"),
        "b": (model_b, "B", "Wavelet-Attention-CNN-LSTM"),
        "c": (model_c, "C", "Wavelet-Parallel-Dual-Stream-CNN-LSTM"),
    }

    training_summary = {}
    # Iterate in fixed a→b→c order regardless of the caller's list order,
    # matching the original sequential if-blocks.
    for key in ("a", "b", "c"):
        if key not in models_to_train:
            continue
        module, tag, title = specs[key]
        print("\n" + "─" * 50)
        print(f" OPTION {tag}: {title}")
        print("─" * 50)
        results = [train_one_model(module, tag, preps[lb], epochs)
                   for lb in config.LOOKBACKS]
        training_summary[f"model_{key}"] = {
            "best_lookback": select_best_lookback(results),
            "results": results,
        }

    # Stamp run metadata and persist the summary alongside the weights.
    training_summary["trained_at"] = datetime.now().isoformat()
    training_summary["epochs"] = epochs
    training_summary["start_year"] = start_year
    training_summary["wavelet"] = wavelet
    summary_path = os.path.join(config.MODELS_DIR, "training_summary.json")
    with open(summary_path, "w") as f:
        json.dump(training_summary, f, indent=2)
    print(f"\n Training summary → {summary_path}")

    print(f"\n{'='*60}")
    print(" Training complete.")
    print(f"{'='*60}")
    return training_summary
# ─── CLI ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="all",
                        help="all | a | b | c | a,b | b,c etc.")
    parser.add_argument("--epochs", type=int, default=config.MAX_EPOCHS)
    parser.add_argument("--start_year", type=int, default=None,
                        help="Training start year (e.g. 2015). Stamped into summary.")
    parser.add_argument("--wavelet", default=None,
                        help="Wavelet key (e.g. db4). Stamped into summary.")
    args = parser.parse_args()

    # "all" expands to every option; otherwise accept a comma-separated list.
    selected = (["a", "b", "c"] if args.model == "all"
                else [token.strip().lower() for token in args.model.split(",")])

    run_training(selected, args.epochs,
                 start_year=args.start_year,
                 wavelet=args.wavelet)