"""Minimal inference example for the GB-LSR Hugging Face release. Loads a released model from its safetensors + config.json and runs it. # arbitrary-scale SR on an image, 4x python infer.py --model gblsr-scalar-asr --image in.png --scale 4 --out sr.png # native reconstruction of a 256x256 image python infer.py --model gblsr-scalar --image in256.png --out recon.png With no --image, runs on a random tensor and just prints the output shape. """ from __future__ import annotations import argparse import json from pathlib import Path import torch from safetensors.torch import load_file HERE = Path(__file__).resolve().parent.parent # release root def load_model(name: str): cfg = json.loads((HERE / name / "config.json").read_text()) weights = load_file(str(HERE / name / "model.safetensors")) if name == "gblsr-scalar": from gblsr import BasisConfig, EncoderConfig, ModelConfig, build_model model = build_model( ModelConfig(arm="local_spectral", image_size=256, patch_size=32, basis=BasisConfig(patch_size=32, p_max=16, s_e_range=(0.25, 2.0)), encoder=EncoderConfig()), bandwidth_mode="global_scalar", adapt_order=False) else: from gblsr import GBLSRScalarASR model = GBLSRScalarASR(encoder_cfg=cfg["encoder_cfg"], decoder_cfg=cfg["decoder_cfg"]) model.load_state_dict(weights, strict=True) return model.eval(), cfg def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True, help="release folder name, e.g. gblsr-scalar-asr") ap.add_argument("--image", default=None, help="input image (PNG/JPG); random if omitted") ap.add_argument("--scale", type=float, default=4.0, help="ASR upscale factor") ap.add_argument("--device", default=None, help="cuda / cpu (default: cuda if available)") ap.add_argument("--out", default=None) args = ap.parse_args() device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") model, cfg = load_model(args.model) model = model.to(device) is_asr = args.model != "gblsr-scalar" if args.image: from PIL import Image import numpy as np arr = np.asarray(Image.open(args.image).convert("RGB"), dtype="float32") / 255.0 x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) else: x = torch.rand(1, 3, 64 if is_asr else 256, 64 if is_asr else 256) x = x.to(device) with torch.no_grad(): if is_asr: H, W = int(x.shape[-2] * args.scale), int(x.shape[-1] * args.scale) out = model.predict_full(x, H_q=H, W_q=W) else: d = model(x) out = next(v for v in d.values() if torch.is_tensor(v) and v.dim() == 4 and v.shape[1] == 3) out = out.clamp(0, 1) print(f"{args.model}: input {tuple(x.shape)} -> output {tuple(out.shape)}") if args.out: from PIL import Image import numpy as np img = (out[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8") Image.fromarray(img).save(args.out) print(f"wrote {args.out}") if __name__ == "__main__": main()