File size: 3,189 Bytes
52e5c4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc7ad5f
 
52e5c4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Minimal inference example for the GB-LSR Hugging Face release.

Loads a released model from its safetensors + config.json and runs it.

  # arbitrary-scale SR on an image, 4x
  python infer.py --model gblsr-scalar-asr --image in.png --scale 4 --out sr.png

  # native reconstruction of a 256x256 image
  python infer.py --model gblsr-scalar --image in256.png --out recon.png

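With no --image, runs on a random tensor (64x64 for ASR models, 256x256 for
gblsr-scalar). The script always prints the input and output shapes; --out additionally saves the result as an image.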
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from safetensors.torch import load_file

# Release root: the grandparent directory of this script file. Each released
# model is a sub-folder of it holding config.json and model.safetensors
# (see load_model).
HERE = Path(__file__).resolve().parent.parent  # release root


def load_model(name: str):
    """Build the released model ``name`` and load its weights.

    Reads ``<release root>/<name>/config.json`` and
    ``<release root>/<name>/model.safetensors``. The ``gblsr-scalar`` release
    is constructed from a hard-coded ``ModelConfig``. Every other name is
    treated as an ASR release and built as ``GBLSRScalarASR`` from the
    config's ``encoder_cfg`` / ``decoder_cfg`` entries.

    Returns:
        ``(model, cfg)``: the model switched to eval mode with its weights
        loaded strictly, and the parsed config dict.
    """
    model_dir = HERE / name
    config = json.loads((model_dir / "config.json").read_text())
    state = load_file(str(model_dir / "model.safetensors"))

    if name != "gblsr-scalar":
        from gblsr import GBLSRScalarASR

        net = GBLSRScalarASR(encoder_cfg=config["encoder_cfg"],
                             decoder_cfg=config["decoder_cfg"])
    else:
        # NOTE(review): config.json is not consulted for this release; the
        # architecture below is hard-coded and must match the saved weights.
        from gblsr import BasisConfig, EncoderConfig, ModelConfig, build_model

        basis = BasisConfig(patch_size=32, p_max=16, s_e_range=(0.25, 2.0))
        model_cfg = ModelConfig(arm="local_spectral", image_size=256, patch_size=32,
                                basis=basis, encoder=EncoderConfig())
        net = build_model(model_cfg, bandwidth_mode="global_scalar", adapt_order=False)

    net.load_state_dict(state, strict=True)
    return net.eval(), config


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="release folder name, e.g. gblsr-scalar-asr")
    ap.add_argument("--image", default=None, help="input image (PNG/JPG); random if omitted")
    ap.add_argument("--scale", type=float, default=4.0, help="ASR upscale factor")
    ap.add_argument("--device", default=None, help="cuda / cpu (default: cuda if available)")
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model, cfg = load_model(args.model)
    model = model.to(device)
    is_asr = args.model != "gblsr-scalar"

    if args.image:
        from PIL import Image
        import numpy as np

        arr = np.asarray(Image.open(args.image).convert("RGB"), dtype="float32") / 255.0
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    else:
        x = torch.rand(1, 3, 64 if is_asr else 256, 64 if is_asr else 256)
    x = x.to(device)

    with torch.no_grad():
        if is_asr:
            H, W = int(x.shape[-2] * args.scale), int(x.shape[-1] * args.scale)
            out = model.predict_full(x, H_q=H, W_q=W)
        else:
            d = model(x)
            out = next(v for v in d.values() if torch.is_tensor(v) and v.dim() == 4 and v.shape[1] == 3)
    out = out.clamp(0, 1)
    print(f"{args.model}: input {tuple(x.shape)} -> output {tuple(out.shape)}")

    if args.out:
        from PIL import Image
        import numpy as np

        img = (out[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")
        Image.fromarray(img).save(args.out)
        print(f"wrote {args.out}")


if __name__ == "__main__":
    main()