File size: 3,189 Bytes
"""Minimal inference example for the GB-LSR Hugging Face release.

Loads a released model from its safetensors + config.json and runs it.

    # arbitrary-scale SR on an image, 4x
    python infer.py --model gblsr-scalar-asr --image in.png --scale 4 --out sr.png

    # native reconstruction of a 256x256 image
    python infer.py --model gblsr-scalar --image in256.png --out recon.png

With no --image, runs on a random tensor and just prints the output shape.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from safetensors.torch import load_file
# Release root: two directory levels above this file. Model folders
# (config.json + model.safetensors) are looked up relative to this path.
HERE = Path(__file__).resolve().parent.parent # release root
def load_model(name: str):
    """Build a released model and load its weights.

    Reads ``config.json`` and ``model.safetensors`` from the folder
    ``HERE / name``, constructs the matching architecture, loads the weights
    strictly, and switches the model to eval mode.

    Args:
        name: Release folder name. ``"gblsr-scalar"`` selects the model built
            through ``build_model`` with a fixed configuration; any other name
            is built as ``GBLSRScalarASR`` from the ``encoder_cfg`` and
            ``decoder_cfg`` entries of the config file.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode and the parsed
        ``config.json`` dictionary.
    """
    model_dir = HERE / name
    cfg = json.loads((model_dir / "config.json").read_text())
    state_dict = load_file(str(model_dir / "model.safetensors"))

    # gblsr is imported lazily, and only the symbols the chosen branch needs.
    if name != "gblsr-scalar":
        from gblsr import GBLSRScalarASR

        model = GBLSRScalarASR(
            encoder_cfg=cfg["encoder_cfg"],
            decoder_cfg=cfg["decoder_cfg"],
        )
    else:
        from gblsr import BasisConfig, EncoderConfig, ModelConfig, build_model

        # The architecture of this release is spelled out here rather than
        # read from cfg.
        basis_cfg = BasisConfig(patch_size=32, p_max=16, s_e_range=(0.25, 2.0))
        model_cfg = ModelConfig(
            arm="local_spectral",
            image_size=256,
            patch_size=32,
            basis=basis_cfg,
            encoder=EncoderConfig(),
        )
        model = build_model(
            model_cfg,
            bandwidth_mode="global_scalar",
            adapt_order=False,
        )

    # strict=True: any missing or unexpected key in the checkpoint is an error.
    model.load_state_dict(state_dict, strict=True)
    return model.eval(), cfg
def _query_size(height: int, width: int, scale: float) -> tuple[int, int]:
    """Return the ``(H, W)`` output grid for a ``height`` x ``width`` input at ``scale``."""
    import math

    # The epsilon absorbs binary floating-point error before flooring:
    # 100 * 1.15 evaluates to 114.99999999999999, which a bare int() would
    # truncate to 114 instead of the intended 115.
    eps = 1e-9
    out_h = math.floor(height * scale + eps)
    out_w = math.floor(width * scale + eps)
    # Never request an empty grid, even for very small scales.
    return max(1, out_h), max(1, out_w)


def _first_image_tensor(outputs):
    """Return the first 4-D tensor with 3 channels among the values of ``outputs``."""
    for value in outputs.values():
        if torch.is_tensor(value) and value.dim() == 4 and value.shape[1] == 3:
            return value
    raise RuntimeError(
        "model output has no 4-D tensor with 3 channels; "
        f"available keys: {list(outputs)}"
    )


def main():
    """Command-line entry point: load a released model and run it once.

    Parses ``--model``, ``--image``, ``--scale``, ``--device`` and ``--out``.
    The input is the given image scaled to [0, 1], or a random tensor when no
    image is given (64x64 for ASR models, 256x256 for ``gblsr-scalar``).
    Every model name other than ``gblsr-scalar`` is treated as an ASR model
    and run through ``predict_full`` at the scaled resolution; ``gblsr-scalar``
    is called directly and the first 3-channel 4-D tensor of its output dict
    is used. The result is clamped to [0, 1], its shape is printed, and it is
    written to ``--out`` when that option is given.

    Raises:
        SystemExit: via ``argparse`` when the arguments are invalid, including
            a non-positive or non-finite ``--scale`` for an ASR model.
        RuntimeError: when the ``gblsr-scalar`` output holds no image tensor.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="release folder name, e.g. gblsr-scalar-asr")
    ap.add_argument("--image", default=None, help="input image (PNG/JPG); random if omitted")
    ap.add_argument("--scale", type=float, default=4.0, help="ASR upscale factor")
    ap.add_argument("--device", default=None, help="cuda / cpu (default: cuda if available)")
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    is_asr = args.model != "gblsr-scalar"
    # --scale is only consumed on the ASR path, so only validate it there.
    # The chained comparison is False for NaN, inf, zero and negatives.
    if is_asr and not (0 < args.scale < float("inf")):
        ap.error("--scale must be a positive, finite number")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model, cfg = load_model(args.model)
    model = model.to(device)

    if args.image:
        # Imported lazily so the random-tensor mode works without PIL/numpy.
        from PIL import Image
        import numpy as np

        # The context manager closes the underlying file handle promptly.
        with Image.open(args.image) as im:
            arr = np.asarray(im.convert("RGB"), dtype="float32") / 255.0
        # HWC -> 1 x C x H x W
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    else:
        side = 64 if is_asr else 256
        x = torch.rand(1, 3, side, side)
    x = x.to(device)

    with torch.no_grad():
        if is_asr:
            H, W = _query_size(x.shape[-2], x.shape[-1], args.scale)
            out = model.predict_full(x, H_q=H, W_q=W)
        else:
            out = _first_image_tensor(model(x))
    out = out.clamp(0, 1)
    print(f"{args.model}: input {tuple(x.shape)} -> output {tuple(out.shape)}")

    if args.out:
        from PIL import Image
        import numpy as np

        # Only the first batch element is written: C x H x W -> H x W x C uint8.
        img = (out[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")
        Image.fromarray(img).save(args.out)
        print(f"wrote {args.out}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|