Spaces:
Running
Running
Update the calling method of the HAT Module.
Browse files
app.py
CHANGED
|
@@ -324,7 +324,7 @@ class Upscale:
|
|
| 324 |
self.img_name = os.path.basename(str(img))
|
| 325 |
self.basename, self.extension = os.path.splitext(self.img_name)
|
| 326 |
|
| 327 |
-
img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED)
|
| 328 |
|
| 329 |
self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
|
| 330 |
if len(img.shape) == 2: # for gray inputs
|
|
@@ -381,51 +381,7 @@ class Upscale:
|
|
| 381 |
# print(f"{param}: {value}")
|
| 382 |
elif upscale_type == "HAT":
|
| 383 |
half = False
|
| 384 |
-
import torch.nn.functional as F
|
| 385 |
from basicsr.archs.hat_arch import HAT
|
| 386 |
-
class HATWithAutoPadding(HAT):
|
| 387 |
-
def pad_to_multiple(self, img, multiple):
|
| 388 |
-
"""
|
| 389 |
-
Fill the image to multiples of both width and height as integers.
|
| 390 |
-
"""
|
| 391 |
-
_, _, h, w = img.shape
|
| 392 |
-
pad_h = (multiple - h % multiple) % multiple
|
| 393 |
-
pad_w = (multiple - w % multiple) % multiple
|
| 394 |
-
|
| 395 |
-
# Padding on the top, bottom, left, and right.
|
| 396 |
-
pad_top = pad_h // 2
|
| 397 |
-
pad_bottom = pad_h - pad_top
|
| 398 |
-
pad_left = pad_w // 2
|
| 399 |
-
pad_right = pad_w - pad_left
|
| 400 |
-
|
| 401 |
-
img_padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
|
| 402 |
-
return img_padded, (pad_top, pad_bottom, pad_left, pad_right)
|
| 403 |
-
|
| 404 |
-
def remove_padding(self, img, pad_info):
|
| 405 |
-
"""
|
| 406 |
-
Remove padding and restore to the original size, considering upscaling.
|
| 407 |
-
"""
|
| 408 |
-
pad_top, pad_bottom, pad_left, pad_right = pad_info
|
| 409 |
-
|
| 410 |
-
# Adjust padding based on upscaling factor
|
| 411 |
-
pad_top = int(pad_top * self.upscale)
|
| 412 |
-
pad_bottom = int(pad_bottom * self.upscale)
|
| 413 |
-
pad_left = int(pad_left * self.upscale)
|
| 414 |
-
pad_right = int(pad_right * self.upscale)
|
| 415 |
-
|
| 416 |
-
return img[:, :, pad_top:-pad_bottom if pad_bottom > 0 else None, pad_left:-pad_right if pad_right > 0 else None]
|
| 417 |
-
|
| 418 |
-
def forward(self, x):
|
| 419 |
-
# Step 1: Auto padding
|
| 420 |
-
x_padded, pad_info = self.pad_to_multiple(x, self.window_size)
|
| 421 |
-
|
| 422 |
-
# Step 2: Normal model processing
|
| 423 |
-
x_processed = super().forward(x_padded)
|
| 424 |
-
|
| 425 |
-
# Step 3: Remove padding
|
| 426 |
-
x_cropped = self.remove_padding(x_processed, pad_info)
|
| 427 |
-
return x_cropped
|
| 428 |
-
|
| 429 |
# The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
|
| 430 |
# https://github.com/XPixelGroup/HAT/tree/main/options/test
|
| 431 |
if "hat-l" in upscale_model.lower():
|
|
@@ -446,7 +402,7 @@ class Upscale:
|
|
| 446 |
num_heads = [6, 6, 6, 6, 6, 6]
|
| 447 |
mlp_ratio = 2
|
| 448 |
upsampler = "pixelshuffle"
|
| 449 |
-
model =
|
| 450 |
squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
|
| 451 |
elif "RealPLKSR" in upscale_type:
|
| 452 |
from basicsr.archs.realplksr_arch import realplksr
|
|
@@ -493,18 +449,19 @@ class Upscale:
|
|
| 493 |
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
|
| 494 |
return new_image
|
| 495 |
|
| 496 |
-
def enhance(
|
| 497 |
# img: numpy
|
| 498 |
h_input, w_input = img.shape[0:2]
|
| 499 |
pil_img = self.cv2pil(img)
|
| 500 |
-
pil_img =
|
| 501 |
cv_image = self.pil2cv(pil_img)
|
| 502 |
if outscale is not None and outscale != float(self.netscale):
|
|
|
|
| 503 |
cv_image = cv2.resize(
|
| 504 |
cv_image, (
|
| 505 |
int(w_input * outscale),
|
| 506 |
int(h_input * outscale),
|
| 507 |
-
), interpolation=
|
| 508 |
return cv_image, None
|
| 509 |
|
| 510 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 324 |
self.img_name = os.path.basename(str(img))
|
| 325 |
self.basename, self.extension = os.path.splitext(self.img_name)
|
| 326 |
|
| 327 |
+
img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
|
| 328 |
|
| 329 |
self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
|
| 330 |
if len(img.shape) == 2: # for gray inputs
|
|
|
|
| 381 |
# print(f"{param}: {value}")
|
| 382 |
elif upscale_type == "HAT":
|
| 383 |
half = False
|
|
|
|
| 384 |
from basicsr.archs.hat_arch import HAT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
# The parameters are derived from the XPixelGroup project files: HAT-L_SRx4_ImageNet-pretrain.yml and HAT-S_SRx4.yml.
|
| 386 |
# https://github.com/XPixelGroup/HAT/tree/main/options/test
|
| 387 |
if "hat-l" in upscale_model.lower():
|
|
|
|
| 402 |
num_heads = [6, 6, 6, 6, 6, 6]
|
| 403 |
mlp_ratio = 2
|
| 404 |
upsampler = "pixelshuffle"
|
| 405 |
+
model = HAT(img_size=64, patch_size=1, in_chans=3, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
|
| 406 |
squeeze_factor=squeeze_factor, conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=mlp_ratio, upsampler=upsampler, upscale=self.netscale,)
|
| 407 |
elif "RealPLKSR" in upscale_type:
|
| 408 |
from basicsr.archs.realplksr_arch import realplksr
|
|
|
|
| 449 |
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
|
| 450 |
return new_image
|
| 451 |
|
| 452 |
+
def enhance(self_, img, outscale=None):
|
| 453 |
# img: numpy
|
| 454 |
h_input, w_input = img.shape[0:2]
|
| 455 |
pil_img = self.cv2pil(img)
|
| 456 |
+
pil_img = self_.__call__(pil_img)
|
| 457 |
cv_image = self.pil2cv(pil_img)
|
| 458 |
if outscale is not None and outscale != float(self.netscale):
|
| 459 |
+
interpolation = cv2.INTER_AREA if outscale < float(self.netscale) else cv2.INTER_LANCZOS4
|
| 460 |
cv_image = cv2.resize(
|
| 461 |
cv_image, (
|
| 462 |
int(w_input * outscale),
|
| 463 |
int(h_input * outscale),
|
| 464 |
+
), interpolation=interpolation)
|
| 465 |
return cv_image, None
|
| 466 |
|
| 467 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|