tactile-vae / script /inference.py
WitneyWW's picture
Initial upload of tactile_vae (code, model, config, inference)
3770c94 verified
Raw
History Blame Contribute Delete
8.07 kB
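"""Lightning-based reconstruction inference script for TactileVAE.

Features:
- Load a Lightning `.ckpt` checkpoint (loaded strictly against the model).
- Load a config YAML (its `model` and `data` sections are used).
- Randomly select `N` samples (seeded) from a split (`train` / `val` / `test`).
- Run reconstruction inference and save per-sample and mean MAE/MSE metrics
  (`summary.json`, `per_sample_metrics.json`) plus an input/reconstruction
  image grid (`reconstruction_grid.png`).
"""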
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any
import numpy as np
import pytorch_lightning as pl
import torch
import yaml
from PIL import Image
from torch.utils.data import DataLoader, Subset
# Directory two levels above this script's folder; it is prepended to
# sys.path (once) so the `tactile_vae.*` imports below resolve even when the
# script is run directly from another working directory.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))
from tactile_vae.dataset import TactileParquetDataset
from tactile_vae.model import TactileVAE
# Site-specific defaults pointing at the `vae_baseline_3` run; override them
# with --config / --ckpt / --out-dir when running anywhere else.
_DEFAULT_RUN_DIR = Path("/group2/ct/weihanx/tactile_world_model/runs/vae_baseline_3")
DEFAULT_CONFIG = _DEFAULT_RUN_DIR / "config.snapshot.yaml"
DEFAULT_CKPT = _DEFAULT_RUN_DIR / "checkpoints" / "last.ckpt"
DEFAULT_OUT_DIR = Path("/group2/ct/weihanx/tactile_world_model/tactile_vae/inference/vae_baseline_3")
def _resolve_path(p: str | Path) -> Path:
path = Path(p)
return path if path.is_absolute() else (_REPO_ROOT / path).resolve()
def load_config(path: Path) -> dict:
    """Load a YAML config and make its data paths absolute.

    ``data.root`` and, when set to a truthy value, ``data.splits_path`` are
    rewritten in place to absolute path strings via ``_resolve_path``.

    Args:
        path: Location of the YAML file (a ``str`` is also accepted).

    Returns:
        The parsed config mapping, with the data paths rewritten.

    Raises:
        ValueError: if the YAML document is not a mapping.
        KeyError: if ``data`` or ``data.root`` is missing.
    """
    path = Path(path)
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with path.open(encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(f"invalid config: {path}")
    data_cfg = cfg["data"]
    data_cfg["root"] = str(_resolve_path(data_cfg["root"]))
    if data_cfg.get("splits_path"):
        data_cfg["splits_path"] = str(_resolve_path(data_cfg["splits_path"]))
    return cfg
def pick_device(spec: str) -> torch.device:
    """Translate a device spec string into a ``torch.device``.

    ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` is true and
    CPU otherwise. Any other value (``"cpu"``, ``"cuda:0"``, ...) is handed
    to ``torch.device`` unchanged, so invalid specs raise there.
    """
    if spec != "auto":
        return torch.device(spec)
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class InferenceModule(pl.LightningModule):
    """Thin LightningModule wrapper that holds a :class:`TactileVAE`.

    Used so ``load_from_checkpoint(..., strict=True)`` can restore a Lightning
    checkpoint. The VAE is registered as the ``model`` submodule, so its
    parameters are expected under ``model.*`` keys in the checkpoint's state
    dict.
    """

    def __init__(self, config: dict):
        super().__init__()
        vae_kwargs = config["model"]
        self.config = config
        self.model = TactileVAE(**vae_kwargs)

    def forward(self, x, **kwargs):
        # Delegate directly to the wrapped VAE.
        return self.model(x, **kwargs)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour; passing an
            explicit list lets other code and tests drive the parser.

    Returns:
        The parsed options. Path-valued options are ``pathlib.Path`` objects.
    """
    p = argparse.ArgumentParser(
        description="Run TactileVAE reconstruction inference on random samples of a dataset split."
    )
    p.add_argument("--config", type=Path, default=DEFAULT_CONFIG, help="config yaml")
    p.add_argument("--ckpt", type=Path, default=DEFAULT_CKPT, help="Lightning checkpoint .ckpt")
    p.add_argument("--out-dir", type=Path, default=DEFAULT_OUT_DIR, help="output directory")
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--num-samples", type=int, default=50, help="number of random samples from the split")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--num-workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--device", type=str, default="auto", help="auto / cuda / cpu / cuda:0 ...")
    p.add_argument("--max-grid", type=int, default=16, help="max samples shown in saved reconstruction grid")
    return p.parse_args(argv)
def build_dataset(cfg: dict, split: str) -> TactileParquetDataset:
    """Construct the dataset for ``split`` from the ``data`` section of ``cfg``.

    ``root`` and ``image_size`` are required config keys. ``splits_path`` is
    optional (``None`` if absent), and ``cache_files`` falls back to 1.
    ``color_jitter=None`` is passed explicitly, presumably to disable colour
    augmentation at inference time (confirm against TactileParquetDataset).
    """
    data_cfg = cfg["data"]
    dataset_kwargs = {
        "root": data_cfg["root"],
        "split": split,
        "splits_path": data_cfg.get("splits_path"),
        "image_size": data_cfg["image_size"],
        "cache_files": data_cfg.get("cache_files", 1),
        "color_jitter": None,
    }
    return TactileParquetDataset(**dataset_kwargs)
def select_subset(ds: TactileParquetDataset, n: int, seed: int) -> tuple[Subset, list[int]]:
    """Draw ``n`` distinct random indices from ``ds`` and wrap them in a Subset.

    ``n`` is clamped to ``[1, len(ds)]``. Indices are sampled without
    replacement from ``np.random.default_rng(seed)``, so the draw is
    reproducible for a given seed. The returned list is in draw order (not
    sorted), which is also the order in which the Subset yields items.

    Returns:
        ``(subset, indices)``, where ``indices`` are plain Python ints into ``ds``.

    Raises:
        ValueError: if ``ds`` is empty.
    """
    total = len(ds)
    if total == 0:
        # An empty subset would otherwise surface later as an opaque
        # DataLoader "batch_size=0" error.
        raise ValueError("cannot select samples from an empty dataset")
    n = min(max(1, int(n)), total)
    rng = np.random.default_rng(seed)
    idx = rng.choice(total, size=n, replace=False).tolist()
    return Subset(ds, idx), idx
def _per_sample_errors(x: torch.Tensor, x_hat: torch.Tensor) -> tuple[list[float], list[float]]:
    """Return per-sample (MAE, MSE) lists, averaging over every non-batch dimension."""
    # Flatten all non-batch dims so inputs of any rank >= 1 work.
    diff = (x - x_hat).reshape(x.shape[0], -1)
    mae = diff.abs().mean(dim=1)
    mse = diff.pow(2).mean(dim=1)
    # One device->host transfer per batch instead of one `.item()` sync per sample.
    return mae.cpu().tolist(), mse.cpu().tolist()


@torch.no_grad()
def run_inference(
    module: InferenceModule,
    ds: TactileParquetDataset,
    subset_idx: list[int],
    loader: DataLoader,
    device: torch.device,
) -> tuple[list[dict[str, Any]], float, float, list[tuple[torch.Tensor, torch.Tensor]]]:
    """Reconstruct every batch in ``loader`` and collect reconstruction errors.

    ``loader`` must iterate ``Subset(ds, subset_idx)`` in order
    (``shuffle=False``), so that the k-th sample it yields corresponds to
    ``subset_idx[k]``. Reconstructions come from
    ``module.model(x, sample=False)["x_hat"]``.

    Args:
        module: Wrapper whose ``model`` attribute produces reconstructions;
            it is switched to eval mode and moved to ``device``.
        ds: Underlying dataset, used only for ``ds.sample_id(index)``.
        subset_idx: Dataset indices in the same order the loader yields them.
        loader: Yields input tensors batched along dim 0.
        device: Device on which inference runs.

    Returns:
        ``(per_sample, mean_mae, mean_mse, vis_pairs)``:
        - ``per_sample``: one dict per sample with keys ``subset_rank``,
          ``dataset_index``, ``sample_id``, ``mae`` and ``mse``.
        - ``mean_mae`` / ``mean_mse``: means of the per-sample values
          (0.0 if the loader is empty).
        - ``vis_pairs``: CPU ``(input, reconstruction)`` tensors for every sample.

    Raises:
        ValueError: if a reconstruction's shape differs from its input's shape.
    """
    module.eval().to(device)
    per_sample: list[dict[str, Any]] = []
    vis_pairs: list[tuple[torch.Tensor, torch.Tensor]] = []
    mae_total = 0.0
    mse_total = 0.0
    cursor = 0
    for x in loader:
        x = x.to(device, non_blocking=True)
        x_hat = module.model(x, sample=False)["x_hat"]
        if x_hat.shape != x.shape:
            # Broadcasting would otherwise produce silently meaningless metrics.
            raise ValueError(
                f"reconstruction shape {tuple(x_hat.shape)} does not match input shape {tuple(x.shape)}"
            )
        mae_batch, mse_batch = _per_sample_errors(x, x_hat)
        x_cpu = x.detach().cpu()
        x_hat_cpu = x_hat.detach().cpu()
        for i, (mae_i, mse_i) in enumerate(zip(mae_batch, mse_batch)):
            rank = cursor + i
            gidx = subset_idx[rank]
            per_sample.append(
                {
                    "subset_rank": rank,
                    "dataset_index": int(gidx),
                    "sample_id": ds.sample_id(gidx),
                    "mae": mae_i,
                    "mse": mse_i,
                }
            )
            vis_pairs.append((x_cpu[i], x_hat_cpu[i]))
            mae_total += mae_i
            mse_total += mse_i
        cursor += x.shape[0]
    n_total = len(per_sample)
    mae_mean = mae_total / max(1, n_total)
    mse_mean = mse_total / max(1, n_total)
    return per_sample, mae_mean, mse_mean, vis_pairs
def _to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
    """Convert a ``(C, H, W)`` or ``(H, W)`` float image in [0, 1] to an ``(H, W, C)`` uint8 array."""
    # .float() so half/bfloat16 tensors can go through .numpy(); clamp keeps values in range.
    img = img.detach().cpu().float().clamp(0, 1)
    if img.ndim == 2:
        img = img.unsqueeze(0)
    # Round rather than truncate so e.g. 0.999 maps to 255, not 254.
    return np.rint(img.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)


def save_grid(pairs: list[tuple[torch.Tensor, torch.Tensor]], out_path: Path, n_show: int, image_size: int) -> None:
    """Save a two-row PNG: inputs on the top row, reconstructions on the bottom.

    The first ``min(n_show, len(pairs))`` pairs are placed left to right.
    Nothing is written, and no directory is created, when that count is <= 0.

    Args:
        pairs: ``(input, reconstruction)`` tensors, each ``(C, H, W)`` with
            C in {1, 3}, or ``(H, W)``, with values expected in [0, 1]
            (they are clamped). All tiles must share the first input's H x W.
        out_path: Destination PNG; parent directories are created as needed.
        n_show: Maximum number of pairs to draw.
        image_size: Kept for backward compatibility only. The tile size is
            taken from the tensors, so non-int config values
            (e.g. ``[H, W]``) no longer break rendering.
    """
    n = min(n_show, len(pairs))
    if n <= 0:
        return
    h, w = map(int, pairs[0][0].shape[-2:])
    canvas = np.zeros((2 * h, n * w, 3), dtype=np.uint8)
    for i, (src, rec) in enumerate(pairs[:n]):
        cols = slice(i * w, (i + 1) * w)
        # A single-channel tile broadcasts across the three RGB canvas channels.
        canvas[:h, cols] = _to_uint8_hwc(src)
        canvas[h:, cols] = _to_uint8_hwc(rec)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(canvas).save(out_path)
def _write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as 2-space-indented JSON, UTF-8 encoded."""
    with path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> None:
    """Run the inference pipeline from command-line arguments.

    The pipeline does the following:
    - Parse options, load the config, build the requested split and draw a
      seeded random subset.
    - Load the Lightning checkpoint (strict, mapped to CPU first) and run
      reconstruction.
    - Write ``reconstruction_grid.png``, ``summary.json`` and
      ``per_sample_metrics.json`` into ``--out-dir``.

    Raises:
        SystemExit: if the requested split contains no samples.
    """
    args = parse_args()
    cfg = load_config(args.config)
    device = pick_device(args.device)
    args.out_dir.mkdir(parents=True, exist_ok=True)
    print(f"config: {args.config}")
    print(f"ckpt: {args.ckpt}")
    print(f"split: {args.split}")
    print(f"num_samples: {args.num_samples}")
    print(f"device: {device}")
    print(f"out_dir: {args.out_dir}")

    ds = build_dataset(cfg, split=args.split)
    if len(ds) == 0:
        # Fail fast with a readable message. Otherwise the empty subset leads to
        # DataLoader(batch_size=0), which raises an unrelated-looking error.
        raise SystemExit(f"error: split '{args.split}' is empty; nothing to run inference on")
    subset, subset_idx = select_subset(ds, args.num_samples, args.seed)
    print(f"split_size={len(ds)} selected={len(subset_idx)}")
    print(f"preview_sample_ids={[ds.sample_id(i) for i in subset_idx[:5]]}")

    # shuffle=False keeps loader order aligned with subset_idx, which
    # run_inference relies on to attribute metrics to dataset indices.
    loader = DataLoader(
        subset,
        batch_size=min(max(1, args.batch_size), len(subset)),
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        drop_last=False,
        persistent_workers=args.num_workers > 0,
    )
    module = InferenceModule.load_from_checkpoint(
        str(args.ckpt),
        config=cfg,
        strict=True,
        map_location="cpu",
    )
    per_sample, mae_mean, mse_mean, vis_pairs = run_inference(
        module=module, ds=ds, subset_idx=subset_idx, loader=loader, device=device
    )

    grid_path = args.out_dir / "reconstruction_grid.png"
    save_grid(vis_pairs, out_path=grid_path, n_show=args.max_grid, image_size=cfg["data"]["image_size"])
    summary = {
        "config": str(args.config),
        "checkpoint": str(args.ckpt),
        "split": args.split,
        "seed": args.seed,
        "selected_num_samples": len(subset_idx),
        "mean_mae": mae_mean,
        "mean_mse": mse_mean,
        "grid_path": str(grid_path),
    }
    summary_path = args.out_dir / "summary.json"
    per_sample_path = args.out_dir / "per_sample_metrics.json"
    _write_json(summary_path, summary)
    _write_json(per_sample_path, per_sample)

    print(f"mean_mae={mae_mean:.6f} mean_mse={mse_mean:.6f}")
    print(f"saved: {summary_path}")
    print(f"saved: {per_sample_path}")
    print(f"saved: {grid_path}")
# Entry point: run the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()