JS6969 commited on
Commit
70eed7c
·
verified ·
1 Parent(s): ac631ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py CHANGED
@@ -12,7 +12,43 @@
12
  # ────────────────────────────────────────────────────────
13
  # Standard imports
14
  # ────────────────────────────────────────────────────────
 
15
  import sys, types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import os
17
  import re
18
  import cv2
 
12
  # ────────────────────────────────────────────────────────
13
  # Standard imports
14
  # ────────────────────────────────────────────────────────
15
+ # ---- TorchVision shim so basicsr can import without torchvision installed ----
16
  import sys, types
17
+ try:
18
+ # If torchvision is present, great — use it.
19
+ import torchvision.transforms.functional_tensor as _ft # noqa: F401
20
+ except Exception:
21
+ # Create a minimal module that provides rgb_to_grayscale with Torch ops.
22
+ import torch
23
+
24
+ _mod = types.ModuleType("torchvision.transforms.functional_tensor")
25
+
26
+ def rgb_to_grayscale(img: "torch.Tensor", num_output_channels: int = 1) -> "torch.Tensor":
27
+ """
28
+ Minimal replacement for torchvision's rgb_to_grayscale.
29
+ Expects a Tensor with channel-last-three: (..., 3, H, W) and returns
30
+ (..., 1, H, W) or (..., 3, H, W) if num_output_channels == 3.
31
+ """
32
+ if not torch.is_tensor(img):
33
+ raise TypeError("rgb_to_grayscale expects a torch.Tensor")
34
+ if img.ndim < 3 or img.shape[-3] != 3:
35
+ raise ValueError(f"expected tensor with C=3 as the third-from-last dim, got shape {tuple(img.shape)}")
36
+
37
+ r = img[..., -3, :, :]
38
+ g = img[..., -2, :, :]
39
+ b = img[..., -1, :, :]
40
+
41
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b # same weights as TV
42
+ if num_output_channels == 3:
43
+ out = torch.stack([gray, gray, gray], dim=-3)
44
+ else:
45
+ out = gray.unsqueeze(-3)
46
+ return out
47
+
48
+ _mod.rgb_to_grayscale = rgb_to_grayscale
49
+ sys.modules["torchvision.transforms.functional_tensor"] = _mod
50
+ # ---------------------------------------------------------------------------
51
+
52
  import os
53
  import re
54
  import cv2