#!/usr/bin/env python3 """Verify that a released TeX-UNet safetensors checkpoint loads and runs.""" from __future__ import annotations import argparse import json import sys from dataclasses import fields from pathlib import Path import torch from safetensors.torch import load_file def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Verify TeX-UNet safetensors loading.") parser.add_argument("--code-root", type=Path, required=True, help="Path to the TeX-1500 GitHub code checkout.") parser.add_argument("--model-dir", type=Path, required=True, help="Path to one model directory in this HF repo.") parser.add_argument("--height", type=int, default=64) parser.add_argument("--width", type=int, default=64) parser.add_argument("--device", default="cpu") return parser.parse_args() def main() -> int: args = parse_args() sys.path.insert(0, str(args.code_root)) from tex1500.model import TeXUNet, TeXUNetConfig model_dir = args.model_dir config = json.loads((model_dir / "config.json").read_text(encoding="utf-8")) raw_model_config = config["model_config"] valid = {field.name for field in fields(TeXUNetConfig)} model_config = TeXUNetConfig(**{k: v for k, v in raw_model_config.items() if k in valid}) state = load_file(str(model_dir / "model.safetensors"), device="cpu") model = TeXUNet(model_config) model.load_state_dict(state, strict=True) model.to(args.device) model.eval() hsi = torch.zeros(1, model_config.num_bands, args.height, args.width, device=args.device) wavelength = torch.linspace( model_config.wavelength_min_um, model_config.wavelength_max_um, model_config.num_bands, device=args.device, ).unsqueeze(0) with torch.no_grad(): output = model(hsi, wavelength) summary = { "model_dir": str(model_dir), "tensors": len(state), "parameters": sum(v.numel() for v in state.values()), "T_shape": list(output["T"].shape), "e_shape": list(output["e"].shape), "X_shape": list(output["X"].shape), } print(json.dumps(summary, indent=2, sort_keys=True)) return 0 if __name__ == "__main__": raise SystemExit(main())