| """Minimal inference example for the GB-LSR Hugging Face release. |
| |
| Loads a released model from its safetensors + config.json and runs it. |
| |
| # arbitrary-scale SR on an image, 4x |
| python infer.py --model gblsr-scalar-asr --image in.png --scale 4 --out sr.png |
| |
| # native reconstruction of a 256x256 image |
| python infer.py --model gblsr-scalar --image in256.png --out recon.png |
| |
| With no --image, runs on a random tensor and just prints the output shape. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
# Base directory under which release folders are looked up (load_model reads
# HERE / <name> / "config.json" and HERE / <name> / "model.safetensors").
# NOTE: despite the name, this is the PARENT of the directory containing this
# script, not the script's own directory.
HERE = Path(__file__).resolve().parent.parent
|
|
|
|
def load_model(name: str):
    """Build a released model and load its weights from its release folder.

    Reads ``HERE / name / "config.json"`` and
    ``HERE / name / "model.safetensors"``.

    Args:
        name: Release folder name. ``"gblsr-scalar"`` selects the model built
            through ``build_model``; every other name is treated as a
            ``GBLSRScalarASR`` release.

    Returns:
        A ``(model, cfg)`` tuple: the value returned by ``model.eval()`` after
        the weights were loaded, and the parsed ``config.json`` contents.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        KeyError: On the ASR path, if ``cfg`` lacks ``"encoder_cfg"`` or
            ``"decoder_cfg"``.
    """
    model_dir = HERE / name

    # JSON is UTF-8; pass the encoding explicitly so parsing does not depend
    # on the platform's locale encoding (the default for read_text()).
    cfg = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    weights = load_file(str(model_dir / "model.safetensors"))

    if name == "gblsr-scalar":
        from gblsr import BasisConfig, EncoderConfig, ModelConfig, build_model

        # NOTE: the hyper-parameters below are hard-coded here; ``cfg`` is not
        # consulted on this path and is only returned to the caller.
        model = build_model(
            ModelConfig(
                arm="local_spectral",
                image_size=256,
                patch_size=32,
                basis=BasisConfig(patch_size=32, p_max=16, s_e_range=(0.25, 2.0)),
                encoder=EncoderConfig(),
            ),
            bandwidth_mode="global_scalar",
            adapt_order=False,
        )
    else:
        from gblsr import GBLSRScalarASR

        model = GBLSRScalarASR(
            encoder_cfg=cfg["encoder_cfg"],
            decoder_cfg=cfg["decoder_cfg"],
        )

    # strict=True: the checkpoint keys are required to match the model exactly
    # (presumably raising on any missing/unexpected key — standard torch
    # behaviour if ``model`` is an nn.Module).
    model.load_state_dict(weights, strict=True)
    return model.eval(), cfg
|
|
|
|
def _read_image(path):
    """Load an image file as a float32 ``(1, 3, H, W)`` tensor scaled to [0, 1]."""
    from PIL import Image
    import numpy as np

    # The context manager closes the underlying file handle deterministically;
    # the pixel data is copied into ``arr`` before the block exits.
    with Image.open(path) as im:
        arr = np.asarray(im.convert("RGB"), dtype="float32") / 255.0
    # HWC -> CHW, then add the batch axis.
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)


def _pick_rgb_output(outputs):
    """Return the first 4-D tensor with 3 channels among ``outputs.values()``.

    Assumes the model returns a mapping (the original code called
    ``.values()`` on it). Raises ``RuntimeError`` with the available keys when
    no such tensor exists, instead of leaking a bare ``StopIteration``.
    """
    for value in outputs.values():
        if torch.is_tensor(value) and value.dim() == 4 and value.shape[1] == 3:
            return value
    raise RuntimeError(
        f"model output contains no (N, 3, H, W) tensor; keys: {list(outputs)}"
    )


def _write_image(out, path):
    """Save the first item of an ``(N, 3, H, W)`` batch in [0, 1] as 8-bit RGB."""
    from PIL import Image

    img = (out[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")
    Image.fromarray(img).save(path)


def main():
    """Command-line entry point: load a released model and run it once.

    Flow:
      1. Parse ``--model``, ``--image``, ``--scale``, ``--device``, ``--out``.
      2. Load the model via ``load_model`` and move it to the chosen device.
      3. Build the input: the given image, or a random tensor (64x64 for ASR
         models, 256x256 for ``gblsr-scalar``).
      4. ASR models: call ``model.predict_full`` with the query size
         ``int(input_size * scale)``. ``gblsr-scalar``: call the model and
         take the first ``(N, 3, H, W)`` tensor from its output mapping.
      5. Clamp to [0, 1], print the shapes, and save to ``--out`` if given.

    Exits through ``argparse`` (status 2) when ``--scale`` is not a positive
    finite number or yields an empty output for an ASR model.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="release folder name, e.g. gblsr-scalar-asr")
    ap.add_argument("--image", default=None, help="input image (PNG/JPG); random if omitted")
    ap.add_argument("--scale", type=float, default=4.0, help="ASR upscale factor")
    ap.add_argument("--device", default=None, help="cuda / cpu (default: cuda if available)")
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    # Every model other than the reconstruction release is run as ASR.
    is_asr = args.model != "gblsr-scalar"

    # Fail fast, before loading weights. The chained comparison is False for
    # NaN as well as for non-positive and infinite values. --scale is only
    # used on the ASR path, so it is only validated there.
    if is_asr and not (0 < args.scale < float("inf")):
        ap.error(f"--scale must be a positive finite number, got {args.scale}")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model, _cfg = load_model(args.model)
    model = model.to(device)

    if args.image:
        x = _read_image(args.image)
    else:
        side = 64 if is_asr else 256
        x = torch.rand(1, 3, side, side)
    x = x.to(device)

    with torch.no_grad():
        if is_asr:
            # int() truncates toward zero, matching the original behaviour.
            out_h = int(x.shape[-2] * args.scale)
            out_w = int(x.shape[-1] * args.scale)
            if out_h < 1 or out_w < 1:
                ap.error(
                    f"--scale {args.scale} gives an empty {out_h}x{out_w} output "
                    f"for a {x.shape[-2]}x{x.shape[-1]} input"
                )
            out = model.predict_full(x, H_q=out_h, W_q=out_w)
        else:
            out = _pick_rgb_output(model(x))
        out = out.clamp(0, 1)
    print(f"{args.model}: input {tuple(x.shape)} -> output {tuple(out.shape)}")

    if args.out:
        _write_image(out, args.out)
        print(f"wrote {args.out}")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|