# HAT Super Resolution — Gradio app (app.py, revision 9f79732)
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Display name (shown in the UI) -> checkpoint filename on the Hugging Face Hub.
MODELS = {
    "HAT SRx4 - ImageNet pretrained (clean images)": "HAT_SRx4_ImageNet-pretrain.pth",
    "Real-HAT GAN Sharper (real-world photos)": "Real_HAT_GAN_sharper.pth",
}

# Cache of already-resolved local checkpoint paths, keyed by filename,
# so each weight file is downloaded at most once per process.
_weight_paths = {}


def get_weight(filename):
    """Return the local path for *filename*, downloading it from the Hub on first use."""
    try:
        return _weight_paths[filename]
    except KeyError:
        _weight_paths[filename] = hf_hub_download(repo_id="Acly/hat", filename=filename)
        return _weight_paths[filename]
# ---------------------------------------------------------------------------
# Surgical HAT import β€” bypasses basicsr's broken __init__ chain
# ---------------------------------------------------------------------------
def _load_hat_class():
    """
    Load the HAT architecture class from basicsr without triggering its broken
    __init__.py chain, and return it.

    Root cause
    ----------
    basicsr/__init__.py -> basicsr/data/__init__.py -> degradations.py
    -> torchvision.transforms.functional_tensor (removed in torchvision 0.16+)

    Additionally, basicsr/archs/arch_util.py does:
        from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
    which walks back up through basicsr/__init__ and triggers the same crash,
    and also causes a circular-import error with basicvsr_arch.py.

    Fix
    ---
    1. Pre-register lightweight stub modules for every basicsr sub-namespace
       that would pull in the broken chain. The two DCN names that arch_util
       imports are explicitly set to None on the stub — safe because HAT
       never calls DCN or data-pipeline code at inference time.
    2. Load arch_util.py and hat_arch.py directly from disk via
       importlib.util.spec_from_file_location, which never runs package
       __init__ files.

    Returns
    -------
    type
        The ``HAT`` nn.Module class from basicsr/archs/hat_arch.py.

    Raises
    ------
    ImportError
        If basicsr is not installed or its source files cannot be located.
    """
    import sys
    import types
    import importlib.util
    import pathlib

    # -- 1. Locate basicsr on disk (without importing it) ------------------
    # find_spec resolves the package location but does NOT execute it.
    spec = importlib.util.find_spec("basicsr")
    if spec is None or spec.origin is None:
        raise ImportError(
            "basicsr is not installed. Add 'basicsr' to requirements.txt."
        )
    basicsr_root = pathlib.Path(spec.origin).parent  # …/site-packages/basicsr/

    # -- 2. Stub out every basicsr namespace we must NOT execute -----------
    def _stub(dotted_name):
        # Register an empty module object under this dotted name so any
        # later "import" of it resolves from sys.modules instead of disk.
        if dotted_name not in sys.modules:
            sys.modules[dotted_name] = types.ModuleType(dotted_name)
        return sys.modules[dotted_name]

    # Stub the package root and all sub-packages that would otherwise be
    # auto-imported when arch_util does "from basicsr.ops.dcn import ...".
    for ns in [
        "basicsr",
        "basicsr.ops",
        "basicsr.ops.dcn",
        "basicsr.data",
        "basicsr.models",
        "basicsr.losses",
        "basicsr.utils",
        "basicsr.archs",  # stub the package; files loaded manually below
        "basicsr.metrics",
    ]:
        _stub(ns)

    # arch_util imports exactly these two names from basicsr.ops.dcn.
    # HAT never invokes DCN at inference time, so None is fine.
    dcn = sys.modules["basicsr.ops.dcn"]
    dcn.ModulatedDeformConvPack = None
    dcn.modulated_deform_conv = None

    # -- 3. Load the two .py files directly from disk ----------------------
    def _load_file(dotted_name, rel_path):
        # Execute a single source file as the module *dotted_name*,
        # bypassing all package __init__ machinery.
        if dotted_name in sys.modules:
            return sys.modules[dotted_name]
        full_path = basicsr_root / rel_path
        if not full_path.exists():
            raise ImportError(f"basicsr source file not found: {full_path}")
        file_spec = importlib.util.spec_from_file_location(dotted_name, full_path)
        module = importlib.util.module_from_spec(file_spec)
        # Register before exec_module so that any intra-file self-imports resolve
        sys.modules[dotted_name] = module
        file_spec.loader.exec_module(module)
        return module

    # arch_util must be loaded first: hat_arch imports names from it.
    _load_file("basicsr.archs.arch_util", "archs/arch_util.py")
    hat_module = _load_file("basicsr.archs.hat_arch", "archs/hat_arch.py")
    return hat_module.HAT
def build_hat_model(filename, device):
    """Instantiate the HAT SRx4 network, load *filename* weights, move to *device*.

    The constructor arguments match the official HAT-SRx4 release
    configuration, which both checkpoints in MODELS were trained with.
    Returns the network in eval mode.
    """
    hat_cls = _load_hat_class()
    net = hat_cls(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    )

    checkpoint = torch.load(get_weight(filename), map_location="cpu")
    # Official checkpoints wrap the weights under "params_ema" or "params".
    for wrapper_key in ("params_ema", "params"):
        if wrapper_key in checkpoint:
            checkpoint = checkpoint[wrapper_key]
            break
    net.load_state_dict(checkpoint, strict=True)

    # eval() and to() both return the module itself.
    return net.eval().to(device)
# ---------------------------------------------------------------------------
# Tile-based inference
# ---------------------------------------------------------------------------
# Overlap (in input pixels) between neighbouring tiles; overlapping output
# regions are blended by averaging.
TILE_OVERLAP = 32


def tile_process(img_tensor, model, tile, overlap, scale=4):
    """Upscale *img_tensor* tile-by-tile to bound peak memory usage.

    The image is split into ``tile``-sized patches with ``overlap`` pixels of
    overlap; each patch is upscaled independently and the results are summed
    into an output canvas together with a weight map, then divided so
    overlapping regions are averaged.

    Parameters
    ----------
    img_tensor : torch.Tensor
        Input batch of shape (B, C, H, W).
    model : callable
        Network mapping (B, C, h, w) -> (B, C, h*scale, w*scale).
    tile : int
        Tile edge length in input pixels; must be greater than *overlap*.
    overlap : int
        Overlap between adjacent tiles, in input pixels.
    scale : int, optional
        Upscaling factor of *model* (default 4).

    Returns
    -------
    torch.Tensor
        Output of shape (B, C, H*scale, W*scale).

    Raises
    ------
    ValueError
        If ``overlap >= tile`` (the tiling stride would be <= 0, which would
        otherwise surface as an opaque ``range()`` error).
    """
    if overlap >= tile:
        raise ValueError(f"overlap ({overlap}) must be smaller than tile ({tile})")

    b, c, h, w = img_tensor.shape
    out_h, out_w = h * scale, w * scale
    output = torch.zeros(b, c, out_h, out_w, device=img_tensor.device)
    # Single-channel weight map; broadcasts over the channel dim on division.
    weight_map = torch.zeros(b, 1, out_h, out_w, device=img_tensor.device)

    stride = tile - overlap
    h_steps = list(range(0, max(h - overlap, 1), stride))
    w_steps = list(range(0, max(w - overlap, 1), stride))
    # Ensure the final tile reaches the bottom/right image edge.
    if not h_steps or h_steps[-1] + tile < h:
        h_steps.append(max(h - tile, 0))
    if not w_steps or w_steps[-1] + tile < w:
        w_steps.append(max(w - tile, 0))

    for hs in h_steps:
        for ws in w_steps:
            he = min(hs + tile, h)
            we = min(ws + tile, w)
            # Shift the tile start back so edge tiles keep full size
            # (clamped to 0 when the image is smaller than the tile).
            hs_ = max(he - tile, 0)
            ws_ = max(we - tile, 0)
            patch = img_tensor[:, :, hs_:he, ws_:we]
            with torch.no_grad():
                out_patch = model(patch)
            ohs, ohe = hs_ * scale, he * scale
            ows, owe = ws_ * scale, we * scale
            output[:, :, ohs:ohe, ows:owe] += out_patch
            weight_map[:, :, ohs:ohe, ows:owe] += 1.0

    # Every output pixel is covered by at least one tile, so no div-by-zero.
    output /= weight_map
    return output
@spaces.GPU
def upscale(image, model_name, tile_size):
    """Run 4x HAT super-resolution on a PIL image and return the result.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Input image from the Gradio component.
    model_name : str
        Key into MODELS selecting which checkpoint to use.
    tile_size : int or float
        Tile edge length from the Gradio slider. Images larger than this in
        either dimension are processed tile-by-tile to limit VRAM use.

    Returns
    -------
    PIL.Image.Image
        The 4x-upscaled image.

    Raises
    ------
    gr.Error
        If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Gradio sliders may deliver floats; tile_size must be an int because it
    # is used as a range() step bound inside tile_process.
    tile_size = int(tile_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_hat_model(MODELS[model_name], device)

    # PIL (H, W, C) uint8 -> float tensor (1, C, H, W) in [0, 1].
    img_np = np.array(image.convert("RGB"), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    h, w = img_tensor.shape[2], img_tensor.shape[3]
    with torch.no_grad():
        if h > tile_size or w > tile_size:
            out_tensor = tile_process(img_tensor, model, tile_size, TILE_OVERLAP, scale=4)
        else:
            out_tensor = model(img_tensor)

    out_np = out_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    return Image.fromarray((out_np * 255).astype(np.uint8))
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Gradio UI
#
# NOTE: `theme` is a gr.Blocks() constructor argument — passing it to
# demo.launch() raises TypeError (launch() has no `theme` parameter), so it
# is set here instead.
# ---------------------------------------------------------------------------
with gr.Blocks(title="HAT Super Resolution", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔍 HAT Super Resolution
        **Hybrid Attention Transformer** — state-of-the-art 4× image upscaling.
        Upload an image, choose a model, and hit **Upscale**.
        > **Models:** [`Acly/hat`](https://huggingface.co/Acly/hat) &nbsp;|&nbsp; **Paper:** [HAT (CVPR 2023)](https://arxiv.org/abs/2205.04437)
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input image",
            )
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            tile_slider = gr.Slider(
                minimum=128,
                maximum=512,
                value=256,
                step=64,
                label="Tile size (lower = less VRAM, slower)",
            )
            run_btn = gr.Button("Upscale 4x", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                type="pil",
                label="Output (4x upscaled)",
                interactive=False,
            )
    with gr.Accordion("Model details", open=False):
        gr.Markdown(
            """
            | Model | Best for |
            |---|---|
            | **HAT SRx4 - ImageNet pretrained** | Clean inputs. Highest PSNR on benchmarks. |
            | **Real-HAT GAN Sharper** | Real-world photos. More perceptual sharpness. |
            Both models perform **4x upscaling** (e.g. 256×256 → 1024×1024). Running on **ZeroGPU (NVIDIA H200)**.
            """
        )
    run_btn.click(
        fn=upscale,
        inputs=[input_image, model_choice, tile_slider],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()