# Hugging Face Space — runs on ZeroGPU hardware.
| import gradio as gr | |
| import spaces | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# UI label -> checkpoint filename in the Acly/hat Hub repo.
MODELS = {
    "HAT SRx4 - ImageNet pretrained (clean images)": "HAT_SRx4_ImageNet-pretrain.pth",
    "Real-HAT GAN Sharper (real-world photos)": "Real_HAT_GAN_sharper.pth",
}

# Local filesystem paths of checkpoints already fetched from the Hub.
_weight_paths = {}


def get_weight(filename):
    """Return the local path of *filename*, downloading it from Acly/hat once."""
    try:
        return _weight_paths[filename]
    except KeyError:
        _weight_paths[filename] = hf_hub_download(repo_id="Acly/hat", filename=filename)
        return _weight_paths[filename]
# ---------------------------------------------------------------------------
# Surgical HAT import — bypasses basicsr's broken __init__ chain
# ---------------------------------------------------------------------------
def _load_hat_class():
    """
    Load the HAT architecture class from basicsr without triggering its
    broken ``__init__.py`` chain.

    Root cause
    ----------
    basicsr/__init__.py -> basicsr/data/__init__.py -> degradations.py
    -> torchvision.transforms.functional_tensor (removed in torchvision 0.16+)

    Additionally, basicsr/archs/arch_util.py does:

        from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv

    which walks back up through basicsr/__init__ and triggers the same crash,
    and also causes a circular-import error with basicvsr_arch.py.

    Fix
    ---
    1. Pre-register lightweight stub modules for every basicsr sub-namespace
       that would pull in the broken chain.  Only the two DCN names that
       arch_util actually imports are populated (as None) — safe because HAT
       never calls DCN or data-pipeline code at inference time.
    2. Load arch_util.py and hat_arch.py directly from disk via
       importlib.util.spec_from_file_location, which never runs package
       __init__ files.

    Returns
    -------
    The ``HAT`` class defined in basicsr/archs/hat_arch.py.
    """
    import sys
    import types
    import importlib.util
    import pathlib

    # -- 1. Locate basicsr on disk (find_spec does NOT execute the package) --
    spec = importlib.util.find_spec("basicsr")
    if spec is None or spec.origin is None:
        raise ImportError(
            "basicsr is not installed. Add 'basicsr' to requirements.txt."
        )
    basicsr_root = pathlib.Path(spec.origin).parent  # .../site-packages/basicsr/

    # -- 2. Stub out every basicsr namespace we must NOT execute -----------
    def _stub(dotted_name):
        # Idempotent: never clobber a module that is already registered.
        if dotted_name not in sys.modules:
            sys.modules[dotted_name] = types.ModuleType(dotted_name)
        return sys.modules[dotted_name]

    # Stub the package root and all sub-packages that would otherwise be
    # auto-imported when arch_util does "from basicsr.ops.dcn import ...".
    for ns in [
        "basicsr",
        "basicsr.ops",
        "basicsr.ops.dcn",
        "basicsr.data",
        "basicsr.models",
        "basicsr.losses",
        "basicsr.utils",
        "basicsr.archs",  # stub the package; files loaded manually below
        "basicsr.metrics",
    ]:
        _stub(ns)

    # arch_util imports exactly these two names from basicsr.ops.dcn.
    # HAT never invokes DCN at inference time, so None is fine.
    dcn = sys.modules["basicsr.ops.dcn"]
    dcn.ModulatedDeformConvPack = None
    dcn.modulated_deform_conv = None

    # -- 3. Load the two .py files directly from disk ----------------------
    def _load_file(dotted_name, rel_path):
        # Reuse an already-loaded module, which also makes this re-entrant.
        if dotted_name in sys.modules:
            return sys.modules[dotted_name]
        full_path = basicsr_root / rel_path
        if not full_path.exists():
            raise ImportError(f"basicsr source file not found: {full_path}")
        file_spec = importlib.util.spec_from_file_location(dotted_name, full_path)
        module = importlib.util.module_from_spec(file_spec)
        # Register before exec_module so that any intra-file self-imports resolve
        sys.modules[dotted_name] = module
        file_spec.loader.exec_module(module)
        return module

    # Order matters: arch_util must be registered before hat_arch imports it.
    _load_file("basicsr.archs.arch_util", "archs/arch_util.py")
    hat_module = _load_file("basicsr.archs.hat_arch", "archs/hat_arch.py")
    return hat_module.HAT
# Networks already built and loaded, keyed by (checkpoint filename, device).
# Rebuilding the HAT graph and re-reading the checkpoint for every request
# is needlessly slow; weights are immutable, so caching is safe.
_model_cache = {}


def build_hat_model(filename, device):
    """Build a 4x HAT network, load *filename* weights, and move it to *device*.

    Parameters
    ----------
    filename : str
        Checkpoint filename inside the Acly/hat Hub repo (a value of MODELS).
    device : torch.device
        Target device for inference.

    Returns
    -------
    The HAT network in eval mode on *device*.  Results are cached per
    (filename, device) so repeated calls are cheap.
    """
    key = (filename, str(device))
    if key in _model_cache:
        return _model_cache[key]

    HAT = _load_hat_class()
    # Hyper-parameters match the official HAT SRx4 release configuration —
    # both checkpoints in MODELS were trained with this architecture.
    net = HAT(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    )

    weight_path = get_weight(filename)
    # weights_only=True: the checkpoints are plain tensor dicts, so refuse to
    # unpickle arbitrary objects (also silences the torch >= 2.4 FutureWarning).
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
    # Official BasicSR checkpoints nest weights under "params_ema" / "params".
    if "params_ema" in state_dict:
        state_dict = state_dict["params_ema"]
    elif "params" in state_dict:
        state_dict = state_dict["params"]
    net.load_state_dict(state_dict, strict=True)
    net.eval()
    net.to(device)

    _model_cache[key] = net
    return net
# ---------------------------------------------------------------------------
# Tile-based inference
# ---------------------------------------------------------------------------
TILE_OVERLAP = 32


def tile_process(img_tensor, model, tile, overlap, scale=4):
    """Run *model* over overlapping tiles of *img_tensor* and blend the results.

    Each tile is upscaled independently; where tiles overlap, the outputs are
    averaged, which hides seams at tile borders.  Tiles near the right/bottom
    edge are shifted inward so every patch fed to the model is full-size
    (when the image is at least `tile` pixels in that dimension).

    Parameters
    ----------
    img_tensor : (B, C, H, W) float tensor.
    model      : callable mapping a (B, C, h, w) tensor to (B, C, h*scale, w*scale).
    tile       : tile edge length in input pixels.
    overlap    : overlap between neighbouring tiles in input pixels.
    scale      : upscaling factor of *model* (default 4).
    """
    batch, channels, height, width = img_tensor.shape
    up_h, up_w = height * scale, width * scale
    accum = torch.zeros(batch, channels, up_h, up_w, device=img_tensor.device)
    counts = torch.zeros(batch, 1, up_h, up_w, device=img_tensor.device)

    step = tile - overlap

    def _starts(extent):
        # Tile start offsets along one axis; guarantees the final tile
        # reaches the far edge of the image.
        starts = list(range(0, max(extent - overlap, 1), step))
        if not starts or starts[-1] + tile < extent:
            starts.append(max(extent - tile, 0))
        return starts

    for y in _starts(height):
        for x in _starts(width):
            y1 = min(y + tile, height)
            x1 = min(x + tile, width)
            # Snap the patch back so it is always `tile` wide/tall if possible.
            y0 = max(y1 - tile, 0)
            x0 = max(x1 - tile, 0)
            with torch.no_grad():
                sr_patch = model(img_tensor[:, :, y0:y1, x0:x1])
            oy0, oy1 = y0 * scale, y1 * scale
            ox0, ox1 = x0 * scale, x1 * scale
            accum[:, :, oy0:oy1, ox0:ox1] += sr_patch
            counts[:, :, oy0:oy1, ox0:ox1] += 1.0

    # Every output pixel is covered by at least one tile, so counts >= 1.
    accum /= counts
    return accum
def upscale(image, model_name, tile_size):
    """Upscale a PIL *image* 4x with the selected HAT model.

    Parameters
    ----------
    image      : PIL.Image or None (None raises a user-facing gr.Error).
    model_name : key of MODELS selected in the UI.
    tile_size  : images larger than this (in either dimension) are processed
                 in overlapping tiles to bound VRAM usage.

    Returns
    -------
    The 4x-upscaled image as a PIL.Image.
    """
    # NOTE(review): `spaces` is imported at the top of the file but no function
    # is decorated with @spaces.GPU, so on a ZeroGPU Space cuda would not be
    # attached here and this falls back to CPU — confirm against the Space config.
    if image is None:
        raise gr.Error("Please upload an image first.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    filename = MODELS[model_name]
    model = build_hat_model(filename, device)

    # PIL -> float32 CHW tensor in [0, 1] with a batch dimension.
    img = image.convert("RGB")
    img_np = np.array(img, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    h, w = img_tensor.shape[2], img_tensor.shape[3]
    with torch.no_grad():
        if h > tile_size or w > tile_size:
            out_tensor = tile_process(img_tensor, model, tile_size, TILE_OVERLAP, scale=4)
        else:
            out_tensor = model(img_tensor)

    out_np = out_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    # Fix: round() before astype — a bare astype(uint8) truncates toward zero,
    # darkening every pixel by up to 1/255 per channel.
    return Image.fromarray((out_np * 255.0).round().astype(np.uint8))
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Fix: the theme belongs on gr.Blocks(); Blocks.launch() has no `theme`
# parameter and passing one raises TypeError at startup.
with gr.Blocks(title="HAT Super Resolution", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# π HAT Super Resolution
**Hybrid Attention Transformer** — state-of-the-art 4× image upscaling.
Upload an image, choose a model, and hit **Upscale**.
> **Models:** [`Acly/hat`](https://huggingface.co/Acly/hat) | **Paper:** [HAT (CVPR 2023)](https://arxiv.org/abs/2205.04437)
"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input image",
            )
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            # Step of 64 keeps tile sizes friendly to the model's window size.
            tile_slider = gr.Slider(
                minimum=128,
                maximum=512,
                value=256,
                step=64,
                label="Tile size (lower = less VRAM, slower)",
            )
            run_btn = gr.Button("Upscale 4x", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                type="pil",
                label="Output (4x upscaled)",
                interactive=False,
            )
    with gr.Accordion("Model details", open=False):
        gr.Markdown(
            """
| Model | Best for |
|---|---|
| **HAT SRx4 - ImageNet pretrained** | Clean inputs. Highest PSNR on benchmarks. |
| **Real-HAT GAN Sharper** | Real-world photos. More perceptual sharpness. |
Both models perform **4x upscaling** (e.g. 256×256 → 1024×1024). Running on **ZeroGPU (NVIDIA H200)**.
"""
        )
    run_btn.click(
        fn=upscale,
        inputs=[input_image, model_choice, tile_slider],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()