DiffQRCode / diffqrcoder /image_processor.py
sayshara's picture
added diffqrcoder_wrapper
70be616
from typing import Optional
import torch
# Full-scale pixel value assumed for images that are not normalised to [0, 1]
# (presumably the 8-bit maximum). image_binarize uses half of this value as
# its default threshold when the image maximum exceeds 1.
IMAGE_MAX_VAL = 255
def min_max_normalize(x: torch.Tensor) -> torch.Tensor:
    """Rescale ``x`` linearly so its global minimum maps to 0 and its maximum to 1.

    The extrema are taken over the whole tensor, not per sample or channel.

    Args:
        x: Tensor to normalise. It must be non-empty, because ``min``/``max``
            are reduced over all elements.

    Returns:
        ``(x - min) / (max - min)``. A constant tensor (``max == min``) is
        mapped to all zeros rather than the NaNs that dividing by a zero
        range would produce.
    """
    lowest = x.min()
    span = x.max() - lowest
    # Swap a zero range for 1 so the division stays finite. The numerator is
    # already 0 everywhere in that case, so the result is exactly 0. Using
    # torch.where (instead of a Python `if`) avoids a host sync and keeps
    # gradients free of NaN.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - lowest) / safe_span
def convert_to_gray(
    images: torch.Tensor,
    cr: float = 0.2999,
    cg: float = 0.587,
    cb: float = 0.1114,
) -> torch.Tensor:
    """Collapse 3-channel images to one channel by a weighted channel sum.

    Args:
        images: Tensor whose dimension 1 holds exactly three colour channels,
            read in the order (cr-weighted, cg-weighted, cb-weighted).
        cr: Weight applied to channel 0.
        cg: Weight applied to channel 1.
        cb: Weight applied to channel 2.

    Returns:
        ``cr * images[:, 0] + cg * images[:, 1] + cb * images[:, 2]`` with a
        singleton channel dimension re-inserted at position 1.

    Raises:
        ValueError: If dimension 1 of ``images`` does not have size 3.
    """
    # NOTE(review): the default weights are close to, but not identical to,
    # the common 0.299 / 0.587 / 0.114 luma weights and sum to 0.9983 rather
    # than 1.0 -- confirm this is intentional before changing them.
    channels = images.shape[1]
    # Raise explicitly instead of using `assert`: assertions are removed under
    # `python -O`, which would let e.g. a 4-channel input pass silently.
    if channels != 3:
        raise ValueError(
            f"The channel of color images must be 3 but get {images.shape[1]}. They are not color images."
        )
    gray_image = cr * images[:, 0] + cg * images[:, 1] + cb * images[:, 2]
    return gray_image.unsqueeze(1)
def image_binarize(
    image: torch.Tensor,
    binary_threshold: Optional[float] = None,
) -> torch.Tensor:
    """Threshold an image batch into a 0/1 tensor.

    Inputs whose dimension 1 has size 3 are first reduced to a single channel
    with ``convert_to_gray``; anything else is thresholded as given.

    Args:
        image: Image tensor; dimension 1 is inspected as the channel axis.
        binary_threshold: Cut-off value. When ``None`` it is derived from the
            data: 0.5 if the (grayscale) maximum is at most 1, otherwise
            half of ``IMAGE_MAX_VAL``.

    Returns:
        Tensor holding 1 where the value is strictly greater than the
        threshold and 0 elsewhere, in the dtype of the thresholded image.
    """
    gray = convert_to_gray(image) if image.shape[1] == 3 else image

    if binary_threshold is None:
        # Guess the value range from the data: values within [.., 1] are
        # treated as normalised, anything larger as full-scale pixel values.
        full_scale = 1 if gray.max() <= 1 else IMAGE_MAX_VAL
        binary_threshold = 0.5 * full_scale

    mask = gray > binary_threshold
    return mask.to(gray.dtype)
def crop_padding(x: torch.Tensor, padding: int) -> torch.Tensor:
    """Strip ``padding`` rows/columns from every side of dimensions 2 and 3.

    Args:
        x: Tensor with at least four dimensions; dimensions 2 and 3 are
            cropped (assumed to be height and width -- confirm with callers).
        padding: Number of elements removed from each edge. ``0`` returns the
            full, uncropped extent.

    Returns:
        A sliced view of ``x`` with dimensions 2 and 3 each shortened by
        ``2 * padding``.
    """
    # Use explicit end indices: the shorthand `padding:-padding` becomes
    # `0:-0` == `0:0` when padding is 0 and would return an empty tensor.
    row_end = x.shape[2] - padding
    col_end = x.shape[3] - padding
    return x[:, :, padding:row_end, padding:col_end]