File size: 2,268 Bytes
9ddb370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
"""Verify that a released TeX-UNet safetensors checkpoint loads and runs."""

from __future__ import annotations

import argparse
import json
import sys
from dataclasses import fields
from pathlib import Path

import torch
from safetensors.torch import load_file


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``code_root`` and ``model_dir`` (both ``Path``, required),
        ``height`` and ``width`` (``int``, default 64) and ``device`` (``str``,
        default ``"cpu"``).
    """
    cli = argparse.ArgumentParser(description="Verify TeX-UNet safetensors loading.")

    # The two required locations: the model code and the released checkpoint live apart.
    for flag, text in (
        ("--code-root", "Path to the TeX-1500 GitHub code checkout."),
        ("--model-dir", "Path to one model directory in this HF repo."),
    ):
        cli.add_argument(flag, type=Path, required=True, help=text)

    # Spatial size of the dummy input used for the smoke-test forward pass.
    for flag in ("--height", "--width"):
        cli.add_argument(flag, type=int, default=64)

    cli.add_argument("--device", default="cpu")
    return cli.parse_args()


def main() -> int:
    """Load one released TeX-UNet checkpoint and run a single smoke-test forward pass.

    Reads ``config.json`` and ``model.safetensors`` from ``--model-dir``, builds
    ``TeXUNet`` from the code checkout at ``--code-root``, loads the weights with
    ``strict=True`` and runs an all-zeros input through the model. A JSON summary
    (tensor count, element count, output shapes) is printed to stdout.

    Keys in ``config["model_config"]`` that the checked-out ``TeXUNetConfig`` does
    not declare are dropped before construction. They are listed on stderr so a
    mismatch between checkpoint and code version is visible without polluting the
    JSON on stdout.

    Returns:
        ``0`` after the forward pass and summary succeed. Missing files, missing
        config keys and state-dict mismatches surface as uncaught exceptions.
    """
    args = parse_args()
    # Make the external code checkout importable; prepended so it takes precedence
    # over any other ``tex1500`` already on the path.
    sys.path.insert(0, str(args.code_root))

    from tex1500.model import TeXUNet, TeXUNetConfig

    model_dir = args.model_dir
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    raw_model_config = config["model_config"]

    # Keep only the fields the config dataclass declares, and report whatever had
    # to be discarded instead of dropping it silently.
    valid = {field.name for field in fields(TeXUNetConfig)}
    ignored = sorted(key for key in raw_model_config if key not in valid)
    if ignored:
        print(
            "warning: ignoring model_config keys unknown to TeXUNetConfig: " + ", ".join(ignored),
            file=sys.stderr,
        )
    model_config = TeXUNetConfig(**{k: v for k, v in raw_model_config.items() if k in valid})

    # Weights are read onto the CPU, loaded, and then moved together with the module.
    state = load_file(str(model_dir / "model.safetensors"), device="cpu")
    model = TeXUNet(model_config)
    # strict=True: any missing or unexpected key fails the verification.
    model.load_state_dict(state, strict=True)
    model.to(args.device)
    model.eval()

    # Dummy batch of one: (1, num_bands, height, width) zeros, plus a (1, num_bands)
    # wavelength grid spaced evenly between the configured min and max.
    hsi = torch.zeros(1, model_config.num_bands, args.height, args.width, device=args.device)
    wavelength = torch.linspace(
        model_config.wavelength_min_um,
        model_config.wavelength_max_um,
        model_config.num_bands,
        device=args.device,
    ).unsqueeze(0)
    with torch.no_grad():
        output = model(hsi, wavelength)

    summary = {
        "model_dir": str(model_dir),
        "tensors": len(state),
        # Element count over every checkpoint tensor. Because the load above is
        # strict, this is the full state dict: parameters plus any persistent
        # buffers the model registers.
        "parameters": sum(v.numel() for v in state.values()),
        # Assumes the model returns a mapping holding "T", "e" and "X" tensors.
        "T_shape": list(output["T"].shape),
        "e_shape": list(output["e"].shape),
        "X_shape": list(output["X"].shape),
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    sys.exit(main())