# Tarang_v2 / trainer.py
# Provenance (Hugging Face Hub page): uploaded by unknownfriend00007,
# "Update trainer.py", commit 7282950 (verified).
import os
import json
import time
import random
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from config import config
from data_loader import load_asset_series, make_features, build_windows, save_manifest
from model import TinyTransformerForecaster
def set_seed(seed: int):
    """Seed the `random`, NumPy and PyTorch RNGs so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def _time_split(X, y, val_split: float):
n = len(X)
n_val = max(1, int(n * val_split))
n_train = n - n_val
return (X[:n_train], y[:n_train]), (X[n_train:], y[n_train:])
def build_global_dataset():
    """Assemble one pooled (train, val) window dataset across every asset.

    For each asset: build feature windows, split them chronologically
    (validation = most recent windows), then z-score normalize BOTH splits
    using statistics computed on the *training* windows only. Training
    samples are shuffled after pooling; validation samples keep their order.

    BUGFIX vs. original: the old code shuffled the pooled windows *before*
    calling _time_split (so the "time split" leaked future windows into
    training) and fit mu/sd on all windows including validation (scaler
    leakage). Both leaks are removed here.

    Returns:
        ((X_tr, y_tr), (X_va, y_va), asset_count) — numpy arrays plus the
        number of loaded asset series (same shape of result as before).

    Raises:
        RuntimeError: if no asset yields at least 2 usable windows.
    """
    series = load_asset_series()
    save_manifest(series)
    X_tr_parts, y_tr_parts, X_va_parts, y_va_parts = [], [], [], []
    for s in series:
        feats = make_features(s.df)
        X, y, _ = build_windows(feats, window=config.WINDOW, horizon=config.HORIZON_DAYS)
        if len(X) < 2:
            # not enough windows to carve out both a train and a val sample
            continue
        (X_tr, y_tr), (X_va, y_va) = _time_split(X, y, config.VAL_SPLIT)
        flat_tr = X_tr.reshape(-1, X_tr.shape[-1])
        mu = flat_tr.mean(axis=0, keepdims=True)
        # epsilon keeps constant features from dividing by zero
        sd = flat_tr.std(axis=0, keepdims=True) + 1e-6
        X_tr_parts.append((X_tr - mu) / sd)
        y_tr_parts.append(y_tr)
        X_va_parts.append((X_va - mu) / sd)
        y_va_parts.append(y_va)
    if not X_tr_parts:
        raise RuntimeError("No usable data found in historical_data/. Check folder structure and CSV format.")
    X_tr = np.concatenate(X_tr_parts, axis=0)
    y_tr = np.concatenate(y_tr_parts, axis=0)
    X_va = np.concatenate(X_va_parts, axis=0)
    y_va = np.concatenate(y_va_parts, axis=0)
    # shuffle only the pooled training set; the DataLoader reshuffles per
    # epoch anyway, but this removes any per-asset ordering up front
    perm = np.random.permutation(len(X_tr))
    return (X_tr[perm], y_tr[perm]), (X_va, y_va), len(series)
class TarangTrainer:
    """CPU-only trainer for TinyTransformerForecaster on the pooled dataset.

    Artifacts written under config.ARTIFACT_DIR:
      - model.pt          : state_dict of the best-validation model so far
      - train_status.json : per-epoch training status (sample counts, losses)
    """

    def __init__(self):
        set_seed(config.SEED)
        os.makedirs(config.ARTIFACT_DIR, exist_ok=True)
        self.device = torch.device("cpu")
        self.model = TinyTransformerForecaster().to(self.device)
        self.opt = torch.optim.AdamW(
            self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY
        )
        self.loss_fn = torch.nn.MSELoss()
        self.status_path = os.path.join(config.ARTIFACT_DIR, "train_status.json")
        self.model_path = os.path.join(config.ARTIFACT_DIR, "model.pt")
        self.best_val = float("inf")  # lowest validation MSE observed so far
        self.last_status = None       # status dict from the most recent fit()

    def load_local_if_exists(self):
        """Load checkpointed weights if model.pt exists. Returns True if loaded."""
        if not os.path.exists(self.model_path):
            return False
        # weights_only=True: only tensors are unpickled, never arbitrary
        # objects — hardens against a tampered checkpoint file
        state = torch.load(self.model_path, map_location="cpu", weights_only=True)
        self.model.load_state_dict(state)
        self.model.eval()
        return True

    def save_local(self):
        """Persist the current model weights to model.pt."""
        torch.save(self.model.state_dict(), self.model_path)

    def _make_loader(self, X, y, shuffle):
        """Wrap numpy arrays (X, y) in a float32 DataLoader; y gains a trailing dim."""
        # BUGFIX: cast explicitly to float32 — torch.tensor() on float64
        # numpy arrays yields float64 tensors, which do not match the
        # model's float32 parameters at forward time.
        ds = TensorDataset(
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(y, dtype=torch.float32).unsqueeze(-1),
        )
        return DataLoader(ds, batch_size=config.BATCH_SIZE, shuffle=shuffle, drop_last=False)

    def fit(self, epochs: int):
        """Train for `epochs` epochs, checkpointing whenever validation improves.

        Writes train_status.json after every epoch and returns the final
        status dict (also stored on self.last_status).
        """
        (X_tr, y_tr), (X_va, y_va), asset_count = build_global_dataset()
        tr_loader = self._make_loader(X_tr, y_tr, shuffle=True)
        va_loader = self._make_loader(X_va, y_va, shuffle=False)
        status = {
            "asset_count": int(asset_count),
            "train_samples": int(len(X_tr)),
            "val_samples": int(len(X_va)),
            "best_val": float(self.best_val) if np.isfinite(self.best_val) else None,
            "last_val": None,
            "trained_at": None,
        }
        for epoch in range(1, epochs + 1):
            t0 = time.time()
            # --- training pass ---
            self.model.train()
            tr_loss = 0.0
            for xb, yb in tr_loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                pred = self.model(xb)
                loss = self.loss_fn(pred, yb)
                self.opt.zero_grad()
                loss.backward()
                # clip to stabilize the small transformer on noisy targets
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                tr_loss += loss.item() * len(xb)  # sample-weighted loss sum
            tr_loss /= max(1, len(X_tr))
            # --- validation pass ---
            self.model.eval()
            va_loss = 0.0
            with torch.no_grad():
                for xb, yb in va_loader:
                    xb, yb = xb.to(self.device), yb.to(self.device)
                    pred = self.model(xb)
                    loss = self.loss_fn(pred, yb)
                    va_loss += loss.item() * len(xb)
            va_loss /= max(1, len(X_va))
            status["last_val"] = float(va_loss)
            # checkpoint only on improvement
            if va_loss < self.best_val:
                self.best_val = va_loss
                status["best_val"] = float(self.best_val)
                self.save_local()
            status["trained_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            with open(self.status_path, "w", encoding="utf-8") as f:
                json.dump(status, f, indent=2)
            dt = time.time() - t0
            print(f"[train] epoch={epoch}/{epochs} train={tr_loss:.6f} val={va_loss:.6f} dt={dt:.1f}s")
        self.last_status = status
        return status