"""
export_models.py
----------------
Downloads publicly available pretrained weights for SRCNN and EDSR (HResNet-style)
and exports them as ONNX files into the ./model/ directory.
Run once before starting app.py:
pip install torch torchvision huggingface_hub basicsr
python export_models.py
After this script finishes you should have:
model/SRCNN_x4.onnx
model/HResNet_x4.onnx
Then upload both files to Google Drive, copy the file IDs into DRIVE_IDS in app.py,
OR set LOCAL_ONLY = True below to skip Drive entirely and load straight from disk.
"""
import os
import torch
import torch.nn as nn
import torch.onnx
from pathlib import Path
# Output directory for downloaded checkpoints and exported ONNX files.
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)
# ---------------------------------------------------------------------------
# Set to True to skip Drive and have app.py load the ONNX files from disk
# directly. In app.py, remove the download_from_drive call for these keys
# (or just leave the placeholder Drive ID — the script already guards against
# missing files gracefully).
# ---------------------------------------------------------------------------
LOCAL_ONLY = True # flip to False once you have Drive IDs
# ===========================================================================
# 1. SRCNN ×4
#    Architecture: Dong et al. 2014 — 3 conv layers, no upsampling inside
#    the network. Input is bicubic-upscaled LR; output is the refined HR.
#    We bicubic-upsample inside a wrapper so the ONNX takes a raw LR image.
# ===========================================================================
class SRCNN(nn.Module):
    """SRCNN super-resolution network (Dong et al., 2014).

    Three convolutions applied at the target resolution: patch extraction
    (9x9), non-linear mapping (5x5), reconstruction (5x5). Attribute names
    (conv1/conv2/conv3) must stay fixed so pretrained state_dicts load.
    """

    def __init__(self, num_channels: int = 3):
        super().__init__()
        # 'same'-style padding keeps H and W unchanged through every layer.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine an already-upscaled image; output shape equals input shape."""
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        return self.conv3(mapped)
class SRCNNx4Wrapper(nn.Module):
    """Adapter that lets SRCNN consume a raw LOW-resolution image.

    SRCNN expects a bicubic-upscaled input, so this wrapper performs the
    x4 bicubic interpolation internally before delegating. The exported
    ONNX graph therefore takes a plain LR tensor, matching the interface
    expected by app.py's tile_upscale_model.
    """

    def __init__(self, srcnn: SRCNN, scale: int = 4):
        super().__init__()
        self.srcnn = srcnn
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map *x* (1, 3, H, W), float32 in [0, 1], to (1, 3, scale*H, scale*W)."""
        upscaled = torch.nn.functional.interpolate(
            x,
            scale_factor=self.scale,
            mode="bicubic",
            align_corners=False,
        )
        return self.srcnn(upscaled)
def build_srcnn_x4() -> nn.Module:
    """Build an SRCNNx4Wrapper with pretrained SRCNN weights.

    Downloads the checkpoint on first run and caches it under MODEL_DIR.
    Falls back to random initialisation (with a warning) if the download
    fails; falls back to a partial load if the keys mismatch.

    Returns:
        nn.Module taking a raw LR tensor and producing the x4 HR tensor.
    """
    srcnn = SRCNN(num_channels=3)
    wrapper = SRCNNx4Wrapper(srcnn, scale=4)
    # Original Caffe weights re-converted to PyTorch by
    # https://github.com/yjn870/SRCNN-pytorch
    SRCNN_WEIGHTS_URL = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/"
        "srcnn_x4.pth"
    )
    weights_path = MODEL_DIR / "srcnn_x4.pth"
    if not weights_path.exists():
        print("  Downloading SRCNN ×4 weights …")
        try:
            import urllib.request
            urllib.request.urlretrieve(SRCNN_WEIGHTS_URL, weights_path)
            print(f"  Saved → {weights_path}")
        except Exception as e:
            # Remove any partially written file: otherwise the next run sees
            # weights_path.exists() and torch.load()s a truncated checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f"  [WARN] Could not download SRCNN weights: {e}")
            print("         Continuing with random init (quality will be poor).")
            return wrapper
    # NOTE(security): torch.load unpickles arbitrary objects and this file
    # comes from a third-party URL; consider weights_only=True (torch >= 1.13).
    state = torch.load(weights_path, map_location="cpu")
    # The yjn870 checkpoint uses keys conv1/conv2/conv3 matching our module.
    try:
        srcnn.load_state_dict(state, strict=True)
        print("  SRCNN weights loaded ✓")
    except RuntimeError as e:
        print(f"  [WARN] Weight mismatch: {e}\n  Proceeding with partial load.")
        srcnn.load_state_dict(state, strict=False)
    return wrapper
# ===========================================================================
# 2. EDSR (HResNet-style) ×4
#    EDSR-baseline (Lim et al., 2017) is the canonical "deep residual" SR
#    network. Pretrained weights from eugenesiow/torch-sr (HuggingFace).
# ===========================================================================
class ResBlock(nn.Module):
    """EDSR residual block: conv-ReLU-conv with a scaled identity shortcut.

    The submodule is named 'body' to match the pretrained state_dict layout.
    """

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        layers = [
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
        ]
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        residual = self.body(x) * self.res_scale
        return x + residual
class Upsampler(nn.Sequential):
    """PixelShuffle-based upsampling head for x2, x3, or x4.

    x2/x4 are built from one/two (conv -> PixelShuffle(2)) stages;
    x3 uses a single (conv -> PixelShuffle(3)) stage.

    Raises:
        ValueError: if *scale* is not 2, 3, or 4. (The original code
        silently built an EMPTY Sequential — an identity — for any other
        scale, which would export a model that does not upscale at all.)
    """

    def __init__(self, scale: int, n_feats: int):
        layers = []
        if scale in (2, 4):
            # x4 is realised as two successive x2 shuffles.
            steps = {2: 1, 4: 2}[scale]
            for _ in range(steps):
                layers += [
                    nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1),
                    nn.PixelShuffle(2),
                ]
        elif scale == 3:
            layers += [
                nn.Conv2d(n_feats, 9 * n_feats, 3, padding=1),
                nn.PixelShuffle(3),
            ]
        else:
            raise ValueError(f"Unsupported upsampling scale: {scale}")
        super().__init__(*layers)
class EDSR(nn.Module):
    """EDSR-baseline: 16 residual blocks, 64 feature channels.

    Matches the publicly released weights from eugenesiow/torch-sr;
    attribute names (head/body/body_tail/tail) are part of that contract.
    """

    def __init__(self, n_resblocks: int = 16, n_feats: int = 64,
                 scale: int = 4, num_channels: int = 3):
        super().__init__()
        # Shallow feature extractor.
        self.head = nn.Conv2d(num_channels, n_feats, 3, padding=1)
        # Deep trunk of identical residual blocks ...
        blocks = [ResBlock(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*blocks)
        # ... closed by one conv before the long skip connection.
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        # Upsampling plus final reconstruction conv.
        self.tail = nn.Sequential(
            Upsampler(scale, n_feats),
            nn.Conv2d(n_feats, num_channels, 3, padding=1),
        )

    def forward(self, x):
        shallow = self.head(x)
        deep = self.body_tail(self.body(shallow))
        # Long (global) skip connection around the residual trunk.
        return self.tail(shallow + deep)
def build_edsr_x4() -> nn.Module:
    """Build an EDSR-baseline x4 model with pretrained weights.

    Downloads the checkpoint from HuggingFace on first run and caches it
    under MODEL_DIR. Falls back to random initialisation (with a warning)
    if the download fails; falls back to a partial load on key mismatch.

    Returns:
        nn.Module taking a raw LR tensor and producing the x4 HR tensor.
    """
    model = EDSR(n_resblocks=16, n_feats=64, scale=4)
    # Direct link to the EDSR-baseline x4 checkpoint (Apache-2.0 licensed).
    EDSR_WEIGHTS_URL = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/"
        "pytorch_model_4x.pt"
    )
    weights_path = MODEL_DIR / "edsr_x4.pt"
    if not weights_path.exists():
        print("  Downloading EDSR ×4 weights from HuggingFace …")
        try:
            import urllib.request
            urllib.request.urlretrieve(EDSR_WEIGHTS_URL, weights_path)
            print(f"  Saved → {weights_path}")
        except Exception as e:
            # Drop any partially written file so the next run re-downloads
            # instead of torch.load()-ing a truncated checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f"  [WARN] Could not download EDSR weights: {e}")
            print("         Continuing with random init (quality will be poor).")
            return model
    # NOTE(security): torch.load unpickles arbitrary objects from a
    # third-party download; consider weights_only=True (torch >= 1.13).
    state = torch.load(weights_path, map_location="cpu")
    # eugenesiow checkpoints may wrap the state_dict under a 'model' or
    # 'state_dict' key; unwrap only when the loaded object is actually a dict.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip any 'module.' prefix left over from DataParallel wrapping.
    state = {k.replace("module.", ""): v for k, v in state.items()}
    try:
        model.load_state_dict(state, strict=True)
        print("  EDSR weights loaded ✓")
    except RuntimeError as e:
        print(f"  [WARN] Weight mismatch ({e}). Trying strict=False …")
        model.load_state_dict(state, strict=False)
        print("  EDSR weights loaded (partial) ✓")
    return model
# ===========================================================================
# ONNX export helper
# ===========================================================================
def export_onnx(model: nn.Module, out_path: Path, tile_h: int = 128, tile_w: int = 128):
"""Export *model* to ONNX with dynamic H/W axes."""
model.eval()
dummy = torch.zeros(1, 3, tile_h, tile_w)
torch.onnx.export(
model,
dummy,
str(out_path),
opset_version=17,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch", 2: "H", 3: "W"},
"output": {0: "batch", 2: "H_out", 3: "W_out"},
},
)
size_mb = out_path.stat().st_size / 1_048_576
print(f" Exported β {out_path} ({size_mb:.1f} MB)")
# ===========================================================================
# Main
# ===========================================================================
if __name__ == "__main__":
    print("=" * 60)
    print("SpectraGAN — ONNX model exporter")
    print("=" * 60)
    # -- SRCNN ×4 ------------------------------------------------------------
    # Skip work that was already done: an existing .onnx is left untouched.
    srcnn_out = MODEL_DIR / "SRCNN_x4.onnx"
    if srcnn_out.exists():
        print(f"\n[SKIP] {srcnn_out} already exists.")
    else:
        print("\n[1/2] Building SRCNN ×4 …")
        srcnn_model = build_srcnn_x4()
        print("      Exporting to ONNX …")
        export_onnx(srcnn_model, srcnn_out, tile_h=128, tile_w=128)
    # -- EDSR (HResNet) ×4 ---------------------------------------------------
    edsr_out = MODEL_DIR / "HResNet_x4.onnx"
    if edsr_out.exists():
        print(f"\n[SKIP] {edsr_out} already exists.")
    else:
        print("\n[2/2] Building EDSR (HResNet) ×4 …")
        edsr_model = build_edsr_x4()
        print("      Exporting to ONNX …")
        export_onnx(edsr_model, edsr_out, tile_h=128, tile_w=128)
    # Final summary: list both expected outputs with an existence check.
    print("\n" + "=" * 60)
    print("Done! Files created:")
    for p in [srcnn_out, edsr_out]:
        status = "✓" if p.exists() else "✗ MISSING"
        print(f"  {status} {p}")
    print()
    if LOCAL_ONLY:
        print("LOCAL_ONLY = True:")
        print("  app.py will load these files directly from disk.")
        print("  No Google Drive upload needed.")
    else:
        print("Next step:")
        print("  Upload the .onnx files to Google Drive and paste")
        print("  the file IDs into DRIVE_IDS in app.py.")
    print("=" * 60)