TeX-UNet / scripts /verify_load.py
dccc2025
Release TeX-UNet checkpoints
9ddb370
#!/usr/bin/env python3
"""Verify that a released TeX-UNet safetensors checkpoint loads and runs."""
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import fields
from pathlib import Path
import torch
from safetensors.torch import load_file
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``code_root`` and ``model_dir`` (both ``Path``, required),
        ``height`` and ``width`` (``int``, default 64) and ``device`` (``str``,
        default ``"cpu"``).
    """
    cli = argparse.ArgumentParser(description="Verify TeX-UNet safetensors loading.")
    # Both locations are mandatory: one holds the model code, the other the checkpoint.
    cli.add_argument(
        "--code-root",
        type=Path,
        required=True,
        help="Path to the TeX-1500 GitHub code checkout.",
    )
    cli.add_argument(
        "--model-dir",
        type=Path,
        required=True,
        help="Path to one model directory in this HF repo.",
    )
    # Spatial size of the all-zero dummy input pushed through the network.
    for spatial_flag in ("--height", "--width"):
        cli.add_argument(spatial_flag, type=int, default=64)
    cli.add_argument("--device", default="cpu")
    return cli.parse_args()
def main() -> int:
    """Load one released checkpoint, run a dummy forward pass and print a summary.

    Reads ``config.json`` and ``model.safetensors`` from ``--model-dir`` and builds a
    ``TeXUNet`` from the config's ``model_config`` section. It then loads the weights
    with ``strict=True`` and feeds the model an all-zero input of shape
    ``(1, num_bands, height, width)`` together with an evenly spaced wavelength grid.
    A JSON summary of tensor counts and output shapes is printed to stdout. Any
    ``model_config`` keys the checked-out ``TeXUNetConfig`` does not declare are
    listed on stderr.

    Returns:
        Process exit status (always 0; failures surface as exceptions).
    """
    args = parse_args()
    # The model code lives in a separate checkout, so it is importable only after its
    # root has been put on sys.path.
    sys.path.insert(0, str(args.code_root))
    from tex1500.model import TeXUNet, TeXUNetConfig

    model_dir = args.model_dir
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    raw_model_config = config["model_config"]

    # Keep only the keys the checked-out TeXUNetConfig understands, but report the
    # dropped ones: a silent drop would hide a release/code version mismatch.
    valid = {field.name for field in fields(TeXUNetConfig)}
    ignored = sorted(k for k in raw_model_config if k not in valid)
    if ignored:
        print(
            "warning: ignoring model_config keys not in TeXUNetConfig: "
            + ", ".join(ignored),
            file=sys.stderr,
        )
    model_config = TeXUNetConfig(
        **{k: v for k, v in raw_model_config.items() if k in valid}
    )

    # Weights are read onto the CPU first; strict=True makes any missing or
    # unexpected tensor name an error.
    state = load_file(str(model_dir / "model.safetensors"), device="cpu")
    model = TeXUNet(model_config)
    model.load_state_dict(state, strict=True)
    model.to(args.device)
    model.eval()

    hsi = torch.zeros(
        1, model_config.num_bands, args.height, args.width, device=args.device
    )
    wavelength = torch.linspace(
        model_config.wavelength_min_um,
        model_config.wavelength_max_um,
        model_config.num_bands,
        device=args.device,
    ).unsqueeze(0)  # (num_bands,) -> (1, num_bands)
    with torch.no_grad():
        output = model(hsi, wavelength)

    summary = {
        "model_dir": str(model_dir),
        "tensors": len(state),
        # Sums elements of every tensor in the checkpoint. That includes persistent
        # buffers, if the model has any, not only nn.Parameters.
        "parameters": sum(v.numel() for v in state.values()),
        "T_shape": list(output["T"].shape),
        "e_shape": list(output["e"].shape),
        "X_shape": list(output["X"].shape),
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())