RxnIM / molscribe /augment.py
L3ul's picture
Fix corrupted augment.py — use patched version with inlined functions
ad8fab9 verified
import albumentations as A
import cv2
import math
import random
import numpy as np
# Inlined from albumentations internals (removed in albumentations>=1.4)
def safe_rotate_enlarged_img_size(angle, rows, cols):
    """Return the canvas size that holds an image rotated by ``angle`` uncropped.

    Args:
        angle: Rotation angle in degrees. The sign is ignored and any
            magnitude is accepted.
        rows: Height of the source image.
        cols: Width of the source image.

    Returns:
        Tuple ``(new_rows, new_cols)`` of ints: the height and width of the
        axis-aligned bounding box of the rotated image.
    """
    rad_angle = math.radians(abs(angle))
    # Past 90 degrees the cosine (and past 180 the sine) turns negative, but
    # the bounding box only depends on their magnitudes.
    cos_a = abs(math.cos(rad_angle))
    sin_a = abs(math.sin(rad_angle))
    # Round before taking the ceiling so float noise (cos(90 deg) evaluates to
    # ~6e-17, not 0) cannot inflate the canvas by a spurious pixel.
    new_rows = int(math.ceil(round(rows * cos_a + cols * sin_a, 9)))
    new_cols = int(math.ceil(round(cols * cos_a + rows * sin_a, 9)))
    return new_rows, new_cols
def _maybe_process_in_chunks(process_fn, **kwargs):
"""Wrap a processing function to handle multi-channel images by processing in chunks."""
def wrapper(img):
if img.ndim == 2 or (img.ndim == 3 and img.shape[2] <= 4):
return process_fn(img, **kwargs)
# Process in chunks for images with many channels
result = np.empty_like(img)
for i in range(0, img.shape[2], 4):
end = min(i + 4, img.shape[2])
result[:, :, i:end] = process_fn(img[:, :, i:end], **kwargs)
return result
return wrapper
def keypoint_rotate(keypoint, angle, rows, cols):
    """Rotate a keypoint's position about the point ``(cols / 2, rows / 2)``.

    Args:
        keypoint: Sequence of exactly four values ``(x, y, angle, scale)``.
        angle: Rotation in degrees, as passed to ``cv2.getRotationMatrix2D``.
        rows: Image height.
        cols: Image width.

    Returns:
        Tuple ``(x, y, angle, scale)`` with transformed coordinates. The
        keypoint's own angle and scale are passed through unchanged.
    """
    px, py, kp_angle, kp_scale = keypoint
    # 2x3 affine matrix: a 2x2 rotation part plus a translation column.
    affine = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1.0)
    (m00, m01, shift_x), (m10, m11, shift_y) = affine
    moved_x = m00 * px + m01 * py + shift_x
    moved_y = m10 * px + m11 * py + shift_y
    return (moved_x, moved_y, kp_angle, kp_scale)
def safe_rotate(
    img: np.ndarray,
    angle: int = 0,
    interpolation: int = cv2.INTER_LINEAR,
    value: int = None,
    border_mode: int = cv2.BORDER_REFLECT_101,
):
    """Rotate ``img`` onto a canvas enlarged to hold the whole rotated image.

    Args:
        img: Image array; its first two dimensions are read as (rows, cols).
        angle: Rotation angle in degrees.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.warpAffine``.
        value: Border fill value forwarded as ``borderValue``.
        border_mode: OpenCV border mode forwarded as ``borderMode``.

    Returns:
        The rotated image, sized as reported by ``safe_rotate_enlarged_img_size``.
    """
    src_h, src_w = img.shape[:2]
    # Size of the uncropped, rotated result.
    dst_h, dst_w = safe_rotate_enlarged_img_size(angle=angle, rows=src_h, cols=src_w)

    # OpenCV wants the center as (x, y), i.e. (width, height): the reverse of shape.
    center_x, center_y = src_w / 2, src_h / 2
    transform = cv2.getRotationMatrix2D((center_x, center_y), angle, 1.0)

    # Extend the translation so the source center lands on the center of the
    # larger canvas; the extra room becomes padding.
    transform[0, 2] += dst_w / 2 - center_x
    transform[1, 2] += dst_h / 2 - center_y

    warp = _maybe_process_in_chunks(
        cv2.warpAffine,
        M=transform,
        dsize=(dst_w, dst_h),
        flags=interpolation,
        borderMode=border_mode,
        borderValue=value,
    )
    return warp(img)
def keypoint_safe_rotate(keypoint, angle, rows, cols):
    """Map a keypoint through the same enlarge-then-rotate step as ``safe_rotate``.

    Args:
        keypoint: Indexable whose first four entries are ``(x, y, angle, scale)``.
        angle: Rotation angle in degrees.
        rows: Height of the image before rotation.
        cols: Width of the image before rotation.

    Returns:
        The 4-tuple produced by ``keypoint_rotate`` for the shifted keypoint on
        the enlarged canvas.
    """
    # Dimensions of the uncropped, rotated canvas.
    canvas_rows, canvas_cols = safe_rotate_enlarged_img_size(angle=angle, rows=rows, cols=cols)
    offset_x = (canvas_cols - cols) / 2
    offset_y = (canvas_rows - rows) / 2
    # Move the point into the padded canvas; int() truncates the shifted
    # coordinates toward zero before the rotation is applied.
    shifted = (
        int(keypoint[0] + offset_x),
        int(keypoint[1] + offset_y),
        keypoint[2],
        keypoint[3],
    )
    return keypoint_rotate(shifted, angle, rows=canvas_rows, cols=canvas_cols)
class SafeRotate(A.SafeRotate):
    """``A.SafeRotate`` variant that routes images and keypoints through this
    module's ``safe_rotate`` and ``keypoint_safe_rotate`` helpers."""

    def __init__(
        self,
        limit=90,
        interpolation=cv2.INTER_LINEAR,
        border_mode=cv2.BORDER_REFLECT_101,
        value=None,
        mask_value=None,
        always_apply=False,
        p=0.5,
    ):
        """Forward every argument unchanged to the parent constructor."""
        super().__init__(
            limit=limit,
            interpolation=interpolation,
            border_mode=border_mode,
            value=value,
            mask_value=mask_value,
            always_apply=always_apply,
            p=p,
        )

    def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params):
        """Rotate ``img`` by ``angle`` degrees on an enlarged canvas.

        Reads ``self.value`` and ``self.border_mode`` (presumably set by the
        parent constructor; confirm against the installed albumentations).
        """
        return safe_rotate(
            img=img,
            value=self.value,
            angle=angle,
            interpolation=interpolation,
            border_mode=self.border_mode,
        )

    def apply_to_keypoint(self, keypoint, angle=0, **params):
        """Transform one keypoint; requires ``rows`` and ``cols`` in ``params``."""
        height = params["rows"]
        width = params["cols"]
        return keypoint_safe_rotate(keypoint, angle=angle, rows=height, cols=width)
class CropWhite(A.DualTransform):
    """Trim border rows/columns made entirely of ``value``, then pad with ``value``.

    A pixel counts as content when at least one of its channels differs from
    ``value``. After trimming, ``pad`` pixels of ``value`` are added on every side.
    """

    def __init__(self, value=(255, 255, 255), pad=0, p=1.0):
        super().__init__(p=p)
        self.value = value
        self.pad = pad
        assert pad >= 0

    def update_params(self, params, **kwargs):
        """Add ``crop_top/bottom/left/right`` margins for the current image to ``params``.

        ``params`` is returned without crop keys when no pixel differs from ``value``.
        """
        super().update_params(params, **kwargs)
        assert "image" in kwargs
        img = kwargs["image"]
        height, width, _ = img.shape
        # Per-pixel count of channels that differ from the background value.
        diff = (img != self.value).sum(axis=2)
        if diff.sum() == 0:
            return params
        # At least one content pixel exists here, so both index arrays are non-empty.
        content_rows = np.flatnonzero(diff.sum(axis=1))
        content_cols = np.flatnonzero(diff.sum(axis=0))
        top = int(content_rows[0])
        bottom = int(content_rows[-1]) + 1
        left = int(content_cols[0])
        right = int(content_cols[-1]) + 1
        params.update({
            "crop_top": top,
            "crop_bottom": height - bottom,
            "crop_left": left,
            "crop_right": width - right,
        })
        return params

    def apply(self, img, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
        """Slice off the given margins and surround the result with a constant border."""
        height, width, _ = img.shape
        cropped = img[crop_top:height - crop_bottom, crop_left:width - crop_right]
        border = self.pad
        return cv2.copyMakeBorder(
            cropped, border, border, border, border,
            borderType=cv2.BORDER_CONSTANT, value=self.value)

    def apply_to_keypoint(self, keypoint, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
        """Shift a keypoint by the removed top/left margin and the added padding."""
        x, y, angle, scale = keypoint[:4]
        new_x = x - crop_left + self.pad
        new_y = y - crop_top + self.pad
        return new_x, new_y, angle, scale

    def get_transform_init_args_names(self):
        return ('value', 'pad')
class PadWhite(A.DualTransform):
    """Add a random-width border of ``value`` to one randomly chosen side of the image."""

    def __init__(self, pad_ratio=0.2, p=0.5, value=(255, 255, 255)):
        super().__init__(p=p)
        self.pad_ratio = pad_ratio
        self.value = value

    def update_params(self, params, **kwargs):
        """Pick one side and store its pad width in ``params``.

        The width is drawn as ``int(extent * pad_ratio * random.random())``,
        where ``extent`` is the image height for top/bottom and the width for
        left/right.
        """
        super().update_params(params, **kwargs)
        assert "image" in kwargs
        img = kwargs["image"]
        height, width, _ = img.shape
        # Indexed by the drawn side: which parameter to set and which image
        # dimension scales it.
        choices = (
            ('pad_top', height),
            ('pad_bottom', height),
            ('pad_left', width),
            ('pad_right', width),
        )
        key, extent = choices[random.randrange(4)]
        params[key] = int(extent * self.pad_ratio * random.random())
        return params

    def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
        """Return ``img`` surrounded by the requested constant-colour borders."""
        # Unpacking three values rejects inputs that are not 3-dimensional.
        height, width, _ = img.shape
        return cv2.copyMakeBorder(
            img, pad_top, pad_bottom, pad_left, pad_right,
            borderType=cv2.BORDER_CONSTANT, value=self.value)

    def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
        """Shift a keypoint by the padding added to the left and top."""
        x, y, angle, scale = keypoint[:4]
        shifted_x = x + pad_left
        shifted_y = y + pad_top
        return shifted_x, shifted_y, angle, scale

    def get_transform_init_args_names(self):
        return ('value', 'pad_ratio')
class SaltAndPepperNoise(A.DualTransform):
    """Overwrite a random number of randomly placed pixels with ``value``.

    Args:
        num_dots: Upper bound (inclusive) on the number of pixels written per call.
        value: Pixel value assigned to each chosen location.
        p: Probability of applying the transform.
    """

    def __init__(self, num_dots, value=(0, 0, 0), p=0.5):
        # Pass p by keyword, like the other transforms in this module. Passed
        # positionally it binds to whatever the base class declares first,
        # which is `always_apply` rather than `p` on some albumentations
        # versions, so the requested probability would be ignored.
        super().__init__(p=p)
        self.num_dots = num_dots
        self.value = value

    def apply(self, img, **params):
        """Write between 0 and ``num_dots`` pixels of ``value`` into ``img``.

        The array is modified in place and the same array is returned. The
        same location may be drawn more than once.
        """
        height, width, _ = img.shape
        num_dots = random.randrange(self.num_dots + 1)
        for _ in range(num_dots):
            row = random.randrange(height)
            col = random.randrange(width)
            img[row, col] = self.value
        return img

    def apply_to_keypoint(self, keypoint, **params):
        """Keypoints are unaffected by the noise and returned unchanged."""
        return keypoint

    def get_transform_init_args_names(self):
        return ('value', 'num_dots')
class ResizePad(A.DualTransform):
    """Fit an image into a ``height`` x ``width`` canvas using cv2 directly.

    Each dimension that exceeds its target is shrunk to the target
    independently, so the aspect ratio is not preserved. Smaller dimensions
    are left as they are. The result is then centred on the canvas and the
    remainder filled with ``value``.
    """

    def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, value=(255, 255, 255)):
        super().__init__(always_apply=True)
        self.height = height
        self.width = width
        self.interpolation = interpolation
        self.value = value

    def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
        """Resize ``img`` down to the target bounds and pad it to the exact target size."""
        src_h, src_w, _ = img.shape
        # cv2.resize takes the destination size as (width, height).
        bounded = (min(src_w, self.width), min(src_h, self.height))
        img = cv2.resize(img, bounded, interpolation=interpolation)
        cur_h, cur_w, _ = img.shape
        # Split the leftover space evenly; an odd pixel goes to the bottom/right.
        spare_h = self.height - cur_h
        spare_w = self.width - cur_w
        top = spare_h // 2
        left = spare_w // 2
        return cv2.copyMakeBorder(
            img, top, spare_h - top, left, spare_w - left,
            borderType=cv2.BORDER_CONSTANT, value=self.value)
# NormalizedGridDistortion removed — only used during training, not inference