import os
import json
import time
import random

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

from config import config
from data_loader import load_asset_series, make_features, build_windows, save_manifest
from model import TinyTransformerForecaster


def set_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _time_split(X, y, val_split: float):
    """Split arrays in index order: the leading slice trains, the trailing slice validates."""
    n = len(X)
    n_val = max(1, int(n * val_split))
    n_train = n - n_val
    return (X[:n_train], y[:n_train]), (X[n_train:], y[n_train:])
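# Illustrative example (not part of the training flow): with 10 windows and
# val_split=0.2, _time_split keeps the first 8 windows for training and the
# last 2 for validation, since n_val = max(1, int(10 * 0.2)) = 2.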
def build_global_dataset():
    """Pool fixed-length windows from every asset into one global training set."""
    series = load_asset_series()
    save_manifest(series)
    X_all, y_all = [], []
    for s in series:
        feats = make_features(s.df)
        X, y, _ = build_windows(feats, window=config.WINDOW, horizon=config.HORIZON_DAYS)
        # Standardize each series with its own feature-wise mean and std.
        mu = X.reshape(-1, X.shape[-1]).mean(axis=0, keepdims=True)
        sd = X.reshape(-1, X.shape[-1]).std(axis=0, keepdims=True) + 1e-6
        X = (X - mu) / sd
        X_all.append(X)
        y_all.append(y)
    if not X_all:
        raise RuntimeError("No usable data found in historical_data/. Check folder structure and CSV format.")
    X_all = np.concatenate(X_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)
    # Split before shuffling so validation windows are not mixed into training;
    # then shuffle only the training windows for SGD.
    (X_tr, y_tr), (X_va, y_va) = _time_split(X_all, y_all, config.VAL_SPLIT)
    idx = np.random.permutation(len(X_tr))
    X_tr, y_tr = X_tr[idx], y_tr[idx]
    return (X_tr, y_tr), (X_va, y_va), len(series)
class TarangTrainer:
    def __init__(self):
        set_seed(config.SEED)
        os.makedirs(config.ARTIFACT_DIR, exist_ok=True)
        self.device = torch.device("cpu")
        self.model = TinyTransformerForecaster().to(self.device)
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        self.loss_fn = torch.nn.MSELoss()
        self.status_path = os.path.join(config.ARTIFACT_DIR, "train_status.json")
        self.model_path = os.path.join(config.ARTIFACT_DIR, "model.pt")
        self.best_val = float("inf")
        self.last_status = None

    def load_local_if_exists(self):
        """Load a previously saved checkpoint if present; return True when loaded."""
        if os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
            self.model.eval()
            return True
        return False

    def save_local(self):
        torch.save(self.model.state_dict(), self.model_path)
    def fit(self, epochs: int):
        (X_tr, y_tr), (X_va, y_va), asset_count = build_global_dataset()
        # Cast to float32 so the tensors match the model's default parameter dtype.
        tr_loader = DataLoader(
            TensorDataset(torch.tensor(X_tr, dtype=torch.float32),
                          torch.tensor(y_tr, dtype=torch.float32).unsqueeze(-1)),
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            drop_last=False,
        )
        va_loader = DataLoader(
            TensorDataset(torch.tensor(X_va, dtype=torch.float32),
                          torch.tensor(y_va, dtype=torch.float32).unsqueeze(-1)),
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            drop_last=False,
        )
        status = {
            "asset_count": int(asset_count),
            "train_samples": int(len(X_tr)),
            "val_samples": int(len(X_va)),
            "best_val": float(self.best_val) if np.isfinite(self.best_val) else None,
            "last_val": None,
            "trained_at": None,
        }
        for epoch in range(1, epochs + 1):
            t0 = time.time()
            self.model.train()
            tr_loss = 0.0
            for xb, yb in tr_loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                pred = self.model(xb)
                loss = self.loss_fn(pred, yb)
                self.opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                tr_loss += loss.item() * len(xb)
            tr_loss /= max(1, len(X_tr))

            self.model.eval()
            va_loss = 0.0
            with torch.no_grad():
                for xb, yb in va_loader:
                    xb, yb = xb.to(self.device), yb.to(self.device)
                    pred = self.model(xb)
                    loss = self.loss_fn(pred, yb)
                    va_loss += loss.item() * len(xb)
            va_loss /= max(1, len(X_va))

            status["last_val"] = float(va_loss)
            # Checkpoint only when validation loss improves.
            if va_loss < self.best_val:
                self.best_val = va_loss
                status["best_val"] = float(self.best_val)
                self.save_local()
            status["trained_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            with open(self.status_path, "w", encoding="utf-8") as f:
                json.dump(status, f, indent=2)
            dt = time.time() - t0
            print(f"[train] epoch={epoch}/{epochs} train={tr_loss:.6f} val={va_loss:.6f} dt={dt:.1f}s")
        self.last_status = status
        return status
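

# Minimal usage sketch (an assumption, not part of the original module): build the
# trainer, reuse a saved checkpoint if one already exists, then run training. The
# epoch count below is purely illustrative; the real value may live on `config`.
if __name__ == "__main__":
    trainer = TarangTrainer()
    if trainer.load_local_if_exists():
        print("[train] loaded existing checkpoint from", trainer.model_path)
    final_status = trainer.fit(epochs=20)  # illustrative epoch count
    print("[train] done:", json.dumps(final_status, indent=2))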