| |
| """Train + evaluate frame-level future verb_fine forecasting. |
| |
| Outputs per-horizon top-1 frame accuracy on the test set, saved to |
| results.json under <output_dir>. |
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
| THIS = Path(__file__).resolve() |
| sys.path.insert(0, str(THIS.parent)) |
| sys.path.insert(0, str(THIS.parents[1])) |
| try: |
| from experiments.dataset_forecast import ( |
| ForecastDataset, collate_forecast, build_train_test, |
| IDLE_LABEL, NUM_FORECAST_CLASSES, |
| ) |
| from experiments.models_forecast import build_forecast_model |
| except ModuleNotFoundError: |
| from dataset_forecast import ( |
| ForecastDataset, collate_forecast, build_train_test, |
| IDLE_LABEL, NUM_FORECAST_CLASSES, |
| ) |
| from models_forecast import build_forecast_model |
|
|
|
|
def set_seed(seed: int):
    """Seed the random number generators used by this script.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return running metrics.

    Every batch is unpacked as ``(x, y, _)``: ``x`` is a mapping of tensors
    that is moved to ``device`` entry by entry and passed to the model as a
    dict, ``y`` holds the target labels, and the third element is ignored.
    Logits and labels are flattened so the criterion sees one row per label.
    Gradients are clipped to a global norm of 1.0 before the optimizer step.

    Args:
        model: Module called as ``model(x)``; its last output axis is the
            class axis.
        loader: Iterable yielding ``(x, y, _)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(flat_logits, flat_labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the batch losses averaged with each batch
        weighted by its label count, and the fraction of labels whose
        arg-max prediction matched. Both are ``0.0`` for an empty loader.
    """
    model.train()
    loss_sum = 0.0
    labels_seen = 0
    hits = 0

    for inputs, targets, _ in loader:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        targets = targets.to(device)
        batch_labels = targets.numel()

        optimizer.zero_grad()
        logits = model(inputs)
        n_classes = logits.size(-1)
        loss = criterion(logits.reshape(-1, n_classes), targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Metrics use the logits computed before this step's update.
        loss_sum += loss.item() * batch_labels
        labels_seen += batch_labels
        hits += (logits.argmax(-1) == targets).sum().item()

    denom = max(labels_seen, 1)
    return loss_sum / denom, hits / denom
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device, t_fut: int):
    """Compute per-horizon and overall top-1 accuracy of ``model`` on ``loader``.

    Two accuracies are tracked for each of the first ``t_fut`` positions
    along axis 1 of the labels: one over all labels, and one restricted to
    labels different from ``IDLE_LABEL`` (the ``*_action`` entries).

    All counts for a batch are reduced on the device and copied to the host
    once, instead of calling ``.item()`` several times per horizon step.

    Args:
        model: Module called as ``model(x)``; predictions are the arg-max
            over its last output axis.
        loader: Iterable yielding ``(x, y, _)`` batches; ``x`` is a mapping
            of tensors, the third element is ignored.
        device: Device the batch tensors are moved to.
        t_fut: Number of horizon positions (axis 1) to score.

    Returns:
        Dict with ``per_h_acc`` and ``per_h_acc_action`` (lists of length
        ``t_fut``) and the pooled ``frame_acc`` / ``frame_acc_action``
        floats. A horizon with no counted labels reports ``0.0``.

    Raises:
        IndexError: If a batch has fewer than ``t_fut`` positions on axis 1.
    """
    model.eval()

    per_h_correct = np.zeros(t_fut, dtype=np.int64)
    per_h_total = np.zeros(t_fut, dtype=np.int64)
    per_h_correct_action = np.zeros(t_fut, dtype=np.int64)
    per_h_total_action = np.zeros(t_fut, dtype=np.int64)

    for x, y, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        y = y.to(device)
        pred = model(x).argmax(-1)

        if y.size(1) < t_fut or pred.size(1) < t_fut:
            raise IndexError(
                f"batch has {min(y.size(1), pred.size(1))} horizon steps, "
                f"expected at least t_fut={t_fut}"
            )

        # NOTE(review): assumes predictions and labels share one shape,
        # (batch, horizon, ...) -- confirm against the dataset/model.
        y_h = y[:, :t_fut]
        hit = pred[:, :t_fut] == y_h
        active = y_h != IDLE_LABEL

        # Rows: correct, total, correct & non-idle, non-idle. Summing every
        # axis except the stack axis (0) and the horizon axis (2) leaves one
        # count per horizon step, moved to the host in a single transfer.
        stacked = torch.stack([hit, torch.ones_like(hit), hit & active, active])
        reduce_dims = [1] + list(range(3, stacked.dim()))
        counts = stacked.sum(dim=reduce_dims).cpu().numpy()

        per_h_correct += counts[0]
        per_h_total += counts[1]
        per_h_correct_action += counts[2]
        per_h_total_action += counts[3]

    return {
        "per_h_acc": (per_h_correct / np.maximum(per_h_total, 1)).tolist(),
        "per_h_acc_action": (per_h_correct_action / np.maximum(per_h_total_action, 1)).tolist(),
        "frame_acc": float(per_h_correct.sum() / max(per_h_total.sum(), 1)),
        "frame_acc_action": float(per_h_correct_action.sum() / max(per_h_total_action.sum(), 1)),
    }
|
|
|
|
def main():
    """Train a forecasting model from CLI arguments and write its best results.

    Steps: parse arguments, seed the RNGs, build the train/test datasets and
    loaders, build the model, then train for up to ``--epochs`` epochs with
    AdamW and a cosine learning-rate schedule. After every epoch the model
    is evaluated on the test loader. Whenever ``frame_acc_action`` improves,
    the weights are saved to ``<output_dir>/model_best.pt``. Training stops
    early once ``--patience`` consecutive epochs bring no improvement.
    The metrics of the best epoch are written to
    ``<output_dir>/results.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, required=True,
                    choices=["daf", "futr", "deepconvlstm", "rulstm", "avt"])
    ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure",
                    help="Comma-separated modality list")
    ap.add_argument("--t_obs", type=float, default=1.5)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--contact_only", action="store_true",
                    help="Only keep anchors whose past+future window has any "
                         "frame with pressure-sum > threshold (Plan B).")
    ap.add_argument("--contact_threshold_g", type=float, default=5.0)
    ap.add_argument("--epochs", type=int, default=15)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--label_smoothing", type=float, default=0.05)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=5,
                    help="Stop after this many consecutive epochs without a new "
                         "best frame_acc_action; <= 0 disables early stopping.")
    ap.add_argument("--output_dir", type=str, required=True)
    args = ap.parse_args()

    # Without at least one epoch there is no "best" result to report.
    if args.epochs < 1:
        ap.error("--epochs must be >= 1")

    # Tolerate "imu, emg" and trailing commas instead of passing stray
    # whitespace or empty names on to the dataset builder.
    mods = [m.strip() for m in args.modalities.split(",") if m.strip()]
    if not mods:
        ap.error("--modalities must name at least one modality")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device} | seed={args.seed} | model={args.model} "
          f"modalities={args.modalities}")

    train_ds, test_ds = build_train_test(
        modalities=mods,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        contact_only=args.contact_only,
        contact_threshold_g=args.contact_threshold_g,
    )
    print(f"train={len(train_ds)} test={len(test_ds)} "
          f"T_obs={train_ds.T_obs} T_fut={train_ds.T_fut} "
          f"mod_dims={train_ds.modality_dims}")

    tr_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, collate_fn=collate_forecast,
        drop_last=False,
    )
    te_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, collate_fn=collate_forecast,
    )

    model = build_forecast_model(
        args.model, train_ds.modality_dims,
        num_classes=NUM_FORECAST_CLASSES,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        d_model=args.d_model, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.lr * 0.05
    )
    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best = {"frame_acc_action": -1.0, "epoch": 0, "state_dict": None}
    epochs_since_best = 0

    # NOTE(review): the best epoch is selected on the same test loader whose
    # metrics are reported, so the reported numbers are optimistic. Consider
    # a separate validation split for checkpoint selection.
    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_epoch(model, tr_loader, optimizer, criterion, device)
        ev = evaluate(model, te_loader, device, t_fut=train_ds.T_fut)
        sched.step()
        print(f"  E{ep:2d} | tr {tr_loss:.4f}/{tr_acc:.3f} "
              f"| te frame_acc {ev['frame_acc']:.3f} action {ev['frame_acc_action']:.3f} "
              f"| {time.time()-t0:.1f}s")
        if ev["frame_acc_action"] > best["frame_acc_action"]:
            # copy=True forces a real snapshot; a plain .cpu() on a CPU run
            # would alias the live parameters that keep being trained.
            snapshot = {k: v.detach().to("cpu", copy=True)
                        for k, v in model.state_dict().items()}
            best = {**ev, "epoch": ep, "state_dict": snapshot}
            torch.save(snapshot, out_dir / "model_best.pt")
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if args.patience > 0 and epochs_since_best >= args.patience:
                print(f"  early stop at epoch {ep}: no improvement for "
                      f"{epochs_since_best} epochs")
                break

    out = {
        "method": args.model,
        "modalities": mods,
        "seed": args.seed,
        "n_params": n_params,
        "T_obs": train_ds.T_obs,
        "T_fut": train_ds.T_fut,
        "best_epoch": int(best["epoch"]),
        "frame_acc": float(best["frame_acc"]),
        "frame_acc_action": float(best["frame_acc_action"]),
        "per_h_acc": list(map(float, best["per_h_acc"])),
        "per_h_acc_action": list(map(float, best["per_h_acc_action"])),
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\n[done] best frame_acc_action {best['frame_acc_action']:.4f} (epoch {best['epoch']})")
    print(f"per_h_acc_action: {[f'{a:.3f}' for a in best['per_h_acc_action']]}")
    print(f"saved to {out_dir}/results.json")


if __name__ == "__main__":
    main()
|
|