| |
| """Verify that a released TeX-UNet safetensors checkpoint loads and runs.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from dataclasses import fields |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        A namespace with:
          * ``code_root`` / ``model_dir`` -- required ``Path`` values,
          * ``height`` / ``width`` -- ``int`` sizes of the dummy input (default 64),
          * ``device`` -- torch device string (default ``"cpu"``).
    """
    cli = argparse.ArgumentParser(description="Verify TeX-UNet safetensors loading.")

    # Required directory arguments, registered in the same order as the usage text shows.
    required_dirs = (
        ("--code-root", "Path to the TeX-1500 GitHub code checkout."),
        ("--model-dir", "Path to one model directory in this HF repo."),
    )
    for flag, help_text in required_dirs:
        cli.add_argument(flag, type=Path, required=True, help=help_text)

    # Spatial size of the all-zero probe input.
    for size_flag in ("--height", "--width"):
        cli.add_argument(size_flag, type=int, default=64)

    cli.add_argument("--device", default="cpu")
    return cli.parse_args()
|
|
|
|
def main() -> int:
    """Load one released TeX-UNet checkpoint, run a dummy forward pass, and print a summary.

    Steps:
      1. Prepend ``--code-root`` to ``sys.path`` so ``tex1500.model`` can be imported.
      2. Build a ``TeXUNetConfig`` from the ``"model_config"`` entry of
         ``<model-dir>/config.json``.
      3. Load ``<model-dir>/model.safetensors`` into a fresh ``TeXUNet`` with
         ``strict=True``, so any missing or unexpected tensor name raises.
      4. Run an all-zero input of ``--height`` x ``--width`` pixels, plus an evenly
         spaced wavelength grid, through the model without gradients.
      5. Print a JSON summary (tensor count, element count, output shapes) to stdout.

    Returns:
        0 on success; every failure surfaces as an exception.
    """
    args = parse_args()
    sys.path.insert(0, str(args.code_root))

    # Imported here, not at module level: the package only becomes importable
    # after the sys.path insertion above.
    from tex1500.model import TeXUNet, TeXUNetConfig

    model_dir = args.model_dir
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    raw_model_config = config["model_config"]

    # Pass only names the dataclass constructor accepts. Fields declared with
    # ``init=False`` are listed by ``fields()`` (and emitted by ``asdict()``),
    # but passing them as keyword arguments would raise ``TypeError``.
    init_names = {field.name for field in fields(TeXUNetConfig) if field.init}
    ignored_keys = sorted(key for key in raw_model_config if key not in init_names)
    if ignored_keys:
        # Surface config/code version drift instead of dropping keys silently.
        # Use stderr so the JSON summary on stdout stays machine-readable.
        print(
            "warning: ignoring model_config keys not accepted by TeXUNetConfig: "
            + ", ".join(ignored_keys),
            file=sys.stderr,
        )
    model_config = TeXUNetConfig(
        **{k: v for k, v in raw_model_config.items() if k in init_names}
    )

    # Load on CPU first, then move the assembled model to the requested device.
    state = load_file(str(model_dir / "model.safetensors"), device="cpu")
    model = TeXUNet(model_config)
    model.load_state_dict(state, strict=True)
    model.to(args.device)
    model.eval()

    # Probe inputs: shape (1, num_bands, height, width) of zeros, plus a
    # (1, num_bands) wavelength row spanning the configured range.
    hsi = torch.zeros(1, model_config.num_bands, args.height, args.width, device=args.device)
    wavelength = torch.linspace(
        model_config.wavelength_min_um,
        model_config.wavelength_max_um,
        model_config.num_bands,
        device=args.device,
    ).unsqueeze(0)
    with torch.no_grad():
        output = model(hsi, wavelength)

    summary = {
        "model_dir": str(model_dir),
        "tensors": len(state),
        # Total element count over every tensor stored in the checkpoint.
        # NOTE: this includes persistent buffers (if the model has any), not only
        # learnable parameters.
        "parameters": sum(v.numel() for v in state.values()),
        "T_shape": list(output["T"].shape),
        "e_shape": list(output["e"].shape),
        "X_shape": list(output["X"].shape),
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
|
|
|
|
if __name__ == "__main__":
    # Run the check and hand its integer status to the interpreter as the exit code.
    exit_status = main()
    raise SystemExit(exit_status)
|
|