Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,43 @@
|
|
| 12 |
# ────────────────────────────────────────────────────────
|
| 13 |
# Standard imports
|
| 14 |
# ────────────────────────────────────────────────────────
|
|
|
|
| 15 |
import sys, types
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import os
|
| 17 |
import re
|
| 18 |
import cv2
|
|
|
|
| 12 |
# ────────────────────────────────────────────────────────
|
| 13 |
# Standard imports
|
| 14 |
# ────────────────────────────────────────────────────────
|
| 15 |
+
# ---- TorchVision shim so basicsr can import without torchvision installed ----
|
| 16 |
import sys, types
|
| 17 |
+
try:
|
| 18 |
+
# If torchvision is present, great — use it.
|
| 19 |
+
import torchvision.transforms.functional_tensor as _ft # noqa: F401
|
| 20 |
+
except Exception:
|
| 21 |
+
# Create a minimal module that provides rgb_to_grayscale with Torch ops.
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
_mod = types.ModuleType("torchvision.transforms.functional_tensor")
|
| 25 |
+
|
| 26 |
+
def rgb_to_grayscale(img: "torch.Tensor", num_output_channels: int = 1) -> "torch.Tensor":
|
| 27 |
+
"""
|
| 28 |
+
Minimal replacement for torchvision's rgb_to_grayscale.
|
| 29 |
+
Expects a Tensor with channel-last-three: (..., 3, H, W) and returns
|
| 30 |
+
(..., 1, H, W) or (..., 3, H, W) if num_output_channels == 3.
|
| 31 |
+
"""
|
| 32 |
+
if not torch.is_tensor(img):
|
| 33 |
+
raise TypeError("rgb_to_grayscale expects a torch.Tensor")
|
| 34 |
+
if img.ndim < 3 or img.shape[-3] != 3:
|
| 35 |
+
raise ValueError(f"expected tensor with C=3 as the third-from-last dim, got shape {tuple(img.shape)}")
|
| 36 |
+
|
| 37 |
+
r = img[..., -3, :, :]
|
| 38 |
+
g = img[..., -2, :, :]
|
| 39 |
+
b = img[..., -1, :, :]
|
| 40 |
+
|
| 41 |
+
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b # same weights as TV
|
| 42 |
+
if num_output_channels == 3:
|
| 43 |
+
out = torch.stack([gray, gray, gray], dim=-3)
|
| 44 |
+
else:
|
| 45 |
+
out = gray.unsqueeze(-3)
|
| 46 |
+
return out
|
| 47 |
+
|
| 48 |
+
_mod.rgb_to_grayscale = rgb_to_grayscale
|
| 49 |
+
sys.modules["torchvision.transforms.functional_tensor"] = _mod
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
import os
|
| 53 |
import re
|
| 54 |
import cv2
|