import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------

MODELS = {
    "HAT SRx4 - ImageNet pretrained (clean images)": "HAT_SRx4_ImageNet-pretrain.pth",
    "Real-HAT GAN Sharper (real-world photos)": "Real_HAT_GAN_sharper.pth",
}

# filename -> resolved local path. hf_hub_download already caches on disk;
# this map just skips the re-resolution round trip on repeat requests.
_weight_paths = {}


def get_weight(filename):
    """Return a local filesystem path for *filename* from the Acly/hat repo.

    Downloads on first use (via the HF hub cache) and memoizes the path.
    """
    if filename not in _weight_paths:
        _weight_paths[filename] = hf_hub_download(repo_id="Acly/hat", filename=filename)
    return _weight_paths[filename]


# ---------------------------------------------------------------------------
# Surgical HAT import — bypasses basicsr's broken __init__ chain
# ---------------------------------------------------------------------------

def _load_hat_class():
    """
    Load HAT from basicsr without triggering its broken __init__.py chain.

    Root cause
    ----------
    basicsr/__init__.py → basicsr/data/__init__.py → degradations.py →
    torchvision.transforms.functional_tensor (removed in torchvision 0.16+)

    Additionally, basicsr/archs/arch_util.py does:
        from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
    which walks back up through basicsr/__init__ and triggers the same crash,
    and also causes a circular-import error with basicvsr_arch.py.

    Fix
    ---
    1. Pre-register lightweight stub modules for every basicsr sub-namespace
       that would pull in the broken chain. The two DCN names arch_util
       imports are explicitly pre-set to None on the stub — safe because HAT
       never invokes DCN or data-pipeline code at inference time.
    2. Load arch_util.py and hat_arch.py directly from disk via
       importlib.util.spec_from_file_location, which never runs package
       __init__ files.
    """
    import sys
    import types
    import importlib.util
    import pathlib

    # -- 1. Locate basicsr on disk (without importing it) ------------------
    spec = importlib.util.find_spec("basicsr")
    if spec is None or spec.origin is None:
        raise ImportError(
            "basicsr is not installed. Add 'basicsr' to requirements.txt."
        )
    basicsr_root = pathlib.Path(spec.origin).parent  # …/site-packages/basicsr/

    # -- 2. Stub out every basicsr namespace we must NOT execute -----------
    def _stub(dotted_name):
        # Register an empty module object so that "import basicsr.x.y" never
        # executes the real (broken) package __init__ files.
        if dotted_name not in sys.modules:
            sys.modules[dotted_name] = types.ModuleType(dotted_name)
        return sys.modules[dotted_name]

    # Stub the package root and all sub-packages that would otherwise be
    # auto-imported when arch_util does "from basicsr.ops.dcn import ...".
    for ns in [
        "basicsr",
        "basicsr.ops",
        "basicsr.ops.dcn",
        "basicsr.data",
        "basicsr.models",
        "basicsr.losses",
        "basicsr.utils",
        "basicsr.archs",  # stub the package; files loaded manually below
        "basicsr.metrics",
    ]:
        _stub(ns)

    # arch_util imports exactly these two names from basicsr.ops.dcn.
    # HAT never invokes DCN at inference time, so None is fine.
    dcn = sys.modules["basicsr.ops.dcn"]
    dcn.ModulatedDeformConvPack = None
    dcn.modulated_deform_conv = None

    # -- 3. Load the two .py files directly from disk ----------------------
    def _load_file(dotted_name, rel_path):
        if dotted_name in sys.modules:
            return sys.modules[dotted_name]
        full_path = basicsr_root / rel_path
        if not full_path.exists():
            raise ImportError(f"basicsr source file not found: {full_path}")
        file_spec = importlib.util.spec_from_file_location(dotted_name, full_path)
        module = importlib.util.module_from_spec(file_spec)
        # Register before exec_module so that any intra-file self-imports resolve
        sys.modules[dotted_name] = module
        file_spec.loader.exec_module(module)
        return module

    _load_file("basicsr.archs.arch_util", "archs/arch_util.py")
    hat_module = _load_file("basicsr.archs.hat_arch", "archs/hat_arch.py")
    return hat_module.HAT


def build_hat_model(filename, device):
    """Instantiate the HAT SRx4 architecture, load *filename* weights, move to *device*.

    The hyper-parameters below match the official HAT SRx4 release configs
    (window_size=16, depth 6x6, embed_dim=180), so strict state-dict loading
    is expected to succeed for both registry checkpoints.
    """
    HAT = _load_hat_class()
    net = HAT(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    )
    weight_path = get_weight(filename)
    # weights_only=True keeps torch.load from unpickling arbitrary objects in
    # a downloaded checkpoint; these files are plain tensor state dicts.
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
    # Official releases wrap weights under "params_ema" (EMA) or "params".
    if "params_ema" in state_dict:
        state_dict = state_dict["params_ema"]
    elif "params" in state_dict:
        state_dict = state_dict["params"]
    net.load_state_dict(state_dict, strict=True)
    net.eval()
    net.to(device)
    return net


# ---------------------------------------------------------------------------
# Tile-based inference
# ---------------------------------------------------------------------------

TILE_OVERLAP = 32


def tile_process(img_tensor, model, tile, overlap, scale=4):
    """Run *model* over overlapping tiles of *img_tensor* and blend the results.

    Keeps peak VRAM bounded by the tile size instead of the full image.
    Overlapping patches are accumulated into an output canvas together with a
    per-pixel hit count, then normalized, so seams average out.

    Args:
        img_tensor: input batch, shape (B, C, H, W), values in [0, 1].
        model: super-resolution network mapping (B,C,h,w) -> (B,C,h*scale,w*scale).
        tile: tile edge length in input pixels.
        overlap: overlap between neighboring tiles in input pixels.
        scale: model upscale factor (default 4).

    Returns:
        Output tensor of shape (B, C, H*scale, W*scale).
    """
    # gr.Slider delivers floats; range() below requires integer stride.
    tile = int(tile)
    overlap = int(overlap)

    b, c, h, w = img_tensor.shape
    out_h, out_w = h * scale, w * scale
    output = torch.zeros(b, c, out_h, out_w, device=img_tensor.device)
    weight_map = torch.zeros(b, 1, out_h, out_w, device=img_tensor.device)

    stride = tile - overlap
    h_steps = list(range(0, max(h - overlap, 1), stride))
    w_steps = list(range(0, max(w - overlap, 1), stride))
    # Make sure the final tile reaches the bottom/right image edge.
    if not h_steps or h_steps[-1] + tile < h:
        h_steps.append(max(h - tile, 0))
    if not w_steps or w_steps[-1] + tile < w:
        w_steps.append(max(w - tile, 0))

    for hs in h_steps:
        for ws in w_steps:
            he = min(hs + tile, h)
            we = min(ws + tile, w)
            # Shift the window back so edge tiles are still full-sized
            # (when the image dimension allows it).
            hs_ = max(he - tile, 0)
            ws_ = max(we - tile, 0)
            patch = img_tensor[:, :, hs_:he, ws_:we]
            with torch.no_grad():
                out_patch = model(patch)
            ohs, ohe = hs_ * scale, he * scale
            ows, owe = ws_ * scale, we * scale
            output[:, :, ohs:ohe, ows:owe] += out_patch
            weight_map[:, :, ohs:ohe, ows:owe] += 1.0

    # Every output pixel is covered by at least one tile, so weight_map > 0.
    output /= weight_map
    return output


@spaces.GPU
def upscale(image, model_name, tile_size):
    """4x-upscale *image* with the selected HAT model.

    Args:
        image: input PIL image (any mode; converted to RGB), or None.
        model_name: display name — a key of MODELS.
        tile_size: tile edge length; images larger than this in either
            dimension are processed in overlapping tiles to bound VRAM.

    Returns:
        The 4x-upscaled result as a PIL image.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_hat_model(MODELS[model_name], device)

    rgb = image.convert("RGB")
    img_np = np.array(rgb, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # gr.Slider yields floats; normalize once before comparisons/tiling.
    tile_size = int(tile_size)
    h, w = img_tensor.shape[2], img_tensor.shape[3]
    with torch.no_grad():
        if h > tile_size or w > tile_size:
            out_tensor = tile_process(img_tensor, model, tile_size, TILE_OVERLAP, scale=4)
        else:
            out_tensor = model(img_tensor)

    out_np = out_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    # round() instead of bare truncation avoids darkening every pixel by up
    # to one intensity level.
    return Image.fromarray((out_np * 255.0).round().astype(np.uint8))


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# NOTE: the theme belongs on the Blocks constructor; Blocks.launch() has no
# `theme` parameter.
with gr.Blocks(title="HAT Super Resolution", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔍 HAT Super Resolution

        **Hybrid Attention Transformer** — state-of-the-art 4× image upscaling.
        Upload an image, choose a model, and hit **Upscale**.

        > **Models:** [`Acly/hat`](https://huggingface.co/Acly/hat)  |  **Paper:** [HAT (CVPR 2023)](https://arxiv.org/abs/2205.04437)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input image",
            )
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            tile_slider = gr.Slider(
                minimum=128,
                maximum=512,
                value=256,
                step=64,
                label="Tile size (lower = less VRAM, slower)",
            )
            run_btn = gr.Button("Upscale 4x", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                type="pil",
                label="Output (4x upscaled)",
                interactive=False,
            )

    with gr.Accordion("Model details", open=False):
        gr.Markdown(
            """
            | Model | Best for |
            |---|---|
            | **HAT SRx4 - ImageNet pretrained** | Clean inputs. Highest PSNR on benchmarks. |
            | **Real-HAT GAN Sharper** | Real-world photos. More perceptual sharpness. |

            Both models perform **4x upscaling** (e.g. 256×256 → 1024×1024).
            Running on **ZeroGPU (NVIDIA H200)**.
            """
        )

    run_btn.click(
        fn=upscale,
        inputs=[input_image, model_choice, tile_slider],
        outputs=output_image,
    )


if __name__ == "__main__":
    demo.launch()