# faster-layerdivider / ldivider / ld_processor_fast.py
# author: sk-uma — commit "init" (eb106c6)
import random
from typing import List, Tuple
import cv2
import numpy as np
from skimage import color
from sklearn.cluster import MiniBatchKMeans
from sklearn.utils import shuffle
def _fix_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
SEED = 42
_fix_seed(SEED)
def _get_new_group(rgb_means: np.ndarray, threshold: int):
    """Merge colour clusters whose CIEDE2000 distance is below *threshold*.

    Parameters
    ----------
    rgb_means : np.ndarray
        (n, 3) array of per-cluster mean RGB values.
    threshold : int
        Perceptual-distance cutoff; pairs closer than this are merged.

    Returns
    -------
    merge_dict : dict[int, int]
        Maps each original cluster label to its compact new group id.
    groups : dict[int, list[int]]
        Maps each group id to the original labels it contains.
    """
    n = len(rgb_means)
    # Convert to Lab so deltaE_ciede2000 measures perceptual distance.
    lab_means = color.rgb2lab(rgb_means, channel_axis=1)

    # Union-find parent table. Following links to the root makes merging
    # transitive: the previous single-level lookup could leave clusters that
    # were linked through an intermediate cluster in different groups,
    # depending on pair order.
    parent = list(range(n))

    def _find(x: int) -> int:
        # Walk to the root, compressing the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i in range(n):
        for j in range(i + 1, n):
            if color.deltaE_ciede2000(lab_means[i], lab_means[j]) < threshold:
                ri, rj = _find(i), _find(j)
                if ri != rj:
                    # Root the union at the smaller label so every group id
                    # stays <= its members (callers rely on this when
                    # relabelling in place).
                    parent[max(ri, rj)] = min(ri, rj)

    # Renumber surviving roots compactly as 0, 1, 2, ...
    roots = sorted({_find(i) for i in range(n)})
    root_to_group = {root: gid for gid, root in enumerate(roots)}

    merge_dict = {i: root_to_group[_find(i)] for i in range(n)}
    groups = {gid: [] for gid in range(len(roots))}
    for label, gid in merge_dict.items():
        groups[gid].append(label)
    return merge_dict, groups
def _get_rgb_means(
img: np.ndarray,
labels: np.ndarray,
label_counts: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""画像の平均色を取得する"""
cls = np.arange(label_counts)
masks = np.bitwise_and(img[:, :, 3] > 127, cls.reshape(-1, 1, 1) == labels)
cls_counts = masks.sum(axis=(1, 2)) # 各クラスのピクセル数
cls_sum = (img[:, :, :3] * masks[:, :, :, None]).sum(
axis=(1, 2)
) # 各クラスのRGBの合計値
rgb_means = cls_sum / (cls_counts[:, None] + 1e-6) # 各クラスのRGBの平均値
return rgb_means, cls_counts, masks
def get_base(
    img: np.ndarray,
    loop: int,
    cls_num: int,
    threshold: int,
    size: int,
    kmeans_samples: int = -1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster the image's colours and merge perceptually-close clusters.

    Parameters
    ----------
    img : np.ndarray
        (H, W, 4) RGBA input image.
    loop : int
        Number of blur/merge iterations; must be positive.
    cls_num : int
        Initial number of k-means clusters.
    threshold : int
        CIEDE2000 distance below which clusters are merged.
    size : int
        Kernel size of the box blur applied each iteration.
    kmeans_samples : int, optional
        If positive, at most this many pixels are used to fit k-means,
        by default -1 (use all samples).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The posterised uint8 image and the (H, W) uint32 label map.

    Raises
    ------
    ValueError
        If ``loop`` is not positive.
    """
    # Validate eagerly with a raise: the original `assert loop > 0` ran after
    # the (expensive) k-means fit and vanishes entirely under `python -O`.
    if loop <= 0:
        raise ValueError(f"loop must be a positive integer, got {loop}")

    rgb_flatten = cluster_samples = img[..., :3].reshape((-1, 3))
    im_h, im_w = img.shape[:2]
    alpha_mask = np.where(img[..., 3] > 127)
    resampling = False
    if rgb_flatten.shape[0] > len(alpha_mask[0]):
        # Transparent pixels exist: fit k-means on opaque pixels only.
        cluster_samples = img[..., :3][alpha_mask].reshape((-1, 3))
        resampling = True
    if len(rgb_flatten) > kmeans_samples and kmeans_samples > 0:
        # A sample cap was requested: subsample to bound the fitting cost.
        cluster_samples = shuffle(
            cluster_samples, random_state=0, n_samples=kmeans_samples
        )
        resampling = True
    kmeans = MiniBatchKMeans(n_clusters=cls_num).fit(cluster_samples)
    if resampling:
        # Fitted on a subset, so labels must be predicted for every pixel.
        labels = kmeans.predict(rgb_flatten)
    else:
        labels = kmeans.labels_
    label_counts = kmeans.n_clusters
    labels = labels.reshape(im_h, im_w)

    img_ori = img.copy()
    for i in range(loop):
        # Blur before measuring cluster colours so merges are driven by the
        # smoothed palette rather than pixel noise.
        img = cv2.blur(img, (size, size))
        rgb_means, cls_counts, _ = _get_rgb_means(img, labels, label_counts)
        merge_dict, groups = _get_new_group(rgb_means, threshold)
        label_counts = len(groups)
        group_means = {}
        for group_id, label_ids in groups.items():
            # Pixel-count-weighted mean colour of each merged group.
            means = rgb_means[label_ids]
            cnt = cls_counts[label_ids]
            group_means[group_id] = (means * cnt[:, None]).sum(axis=0) / cnt.sum()
        for k, v in merge_dict.items():
            # In-place relabel, safe because new group ids never exceed the
            # original label they replace (v <= k).
            labels[labels == k] = v
            if i != loop - 1:
                # Skip repainting on the final pass: img is discarded below.
                img[labels == v, :3] = group_means[v]
    # Recolour from the ORIGINAL image so the output keeps true mean colours.
    img = img_ori
    rgb_means, cls_counts, masks = _get_rgb_means(img, labels, label_counts)
    for mask, rgb in zip(masks, rgb_means):
        img[mask, :3] = rgb
    img = img.clip(0, 255).astype(np.uint8)
    labels = labels.squeeze().astype(np.uint32)
    return img, labels
def _split_img_batch(
images: List[np.ndarray], labels: np.ndarray
) -> List[List[np.ndarray]]:
unique_labels = np.unique(labels) # ラベルの一意なクラスを取得
splited_images = [[] for _ in range(len(images))]
for cls_no in unique_labels:
mask = labels == cls_no # マスクを拡張してimageの次元に合わせる
for i, image in enumerate(images):
masked_img = image * mask[:, :, None]
splited_images[i].append(masked_img)
return splited_images
def get_normal_layer(
    input_image: np.ndarray, base_image: np.ndarray, label: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Split the input into base / bright / shadow layers, per label.

    A pixel is "bright" when its HSV value (V channel) exceeds the base
    image's, and "shadow" otherwise. Each variant keeps the input's RGB but
    zeroes the alpha outside its mask; all three images are then split into
    per-label layers and returned as uint8 lists.
    """
    base_image = base_image.astype(np.int32)
    input_image = input_image.astype(np.int32)

    def _value_channel(img: np.ndarray) -> np.ndarray:
        # Brightness (V) channel of the RGB->HSV conversion.
        hsv = cv2.cvtColor(img[:, :, :3].astype(np.uint8), cv2.COLOR_RGB2HSV)
        return hsv[:, :, 2]

    base_v = _value_channel(base_image)
    input_v = _value_channel(input_image)

    def _alpha_masked(mask: np.ndarray) -> np.ndarray:
        # Copy of the input whose alpha is zeroed wherever mask is False.
        layer = input_image.copy()
        layer[:, :, 3] = layer[:, :, 3] * mask
        return layer

    bright_image = _alpha_masked(base_v < input_v)
    shadow_image = _alpha_masked(base_v >= input_v)

    base_layers, bright_layers, shadow_layers = _split_img_batch(
        np.array([base_image, bright_image, shadow_image]),
        label,
    )
    return (
        [layer.astype(np.uint8) for layer in base_layers],
        [layer.astype(np.uint8) for layer in bright_layers],
        [layer.astype(np.uint8) for layer in shadow_layers],
    )
def get_composite_layer(
    input_image: np.ndarray, base_image: np.ndarray, label: np.ndarray
) -> Tuple[
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
]:
    """Decompose *input_image* into blend-mode layers against *base_image*.

    Parameters
    ----------
    input_image : np.ndarray
        (H, W, 4) RGBA source image.
    base_image : np.ndarray
        (H, W, 4) RGBA flattened base image.
    label : np.ndarray
        (H, W) per-pixel cluster labels used to split each layer.

    Returns
    -------
    Five lists (base, shadow/multiply, screen, addition, subtract), each
    holding one uint8 layer per unique label.
    """
    # Work in int32 so the differences below can go negative without wrapping.
    base_image = base_image.astype(np.int32)
    input_image = input_image.astype(np.int32)
    diff_image = base_image - input_image
    # Shadow (multiply) layer: pixels strictly darker than base on all channels.
    shadow_mask = (diff_image[:, :, :3] > 0).all(axis=2)
    shadow_image = input_image.copy()
    # The mask only zeroes alpha; the RGB below is computed for every pixel.
    shadow_image[:, :, 3] = shadow_image[:, :, 3] * shadow_mask
    # Inverts multiply blending (input = base * layer / 255).
    # NOTE(review): divides by base_image on ALL pixels, so any zero base
    # channel outside the mask divides by zero — the result there is hidden by
    # the zeroed alpha, but confirm this is intended.
    shadow_image[:, :, :3] = (shadow_image[:, :, :3] * 255) / base_image[:, :, :3]
    # Screen layer: pixels strictly brighter than base on all channels.
    screen_mask = (diff_image[:, :, :3] < 0).all(axis=2)
    screen_image = input_image.copy()
    screen_image[:, :, 3] = screen_image[:, :, 3] * screen_mask
    # Inverts screen blending (layer = (input - base) / (1 - base/255)).
    # NOTE(review): base channels equal to 255 make the denominator zero here;
    # again only hidden by alpha where the mask is False — verify upstream.
    screen_image[:, :, :3] = (screen_image[:, :, :3] - base_image[:, :, :3]) / (
        1 - base_image[:, :, :3] / 255
    )
    # Residual pixels: neither uniformly darker nor uniformly brighter.
    residuals_mask = ~shadow_mask & ~screen_mask
    residuals_image = input_image[:, :, 3].copy()
    residuals_image = residuals_image * residuals_mask
    # Addition layer: positive part of (input - base), on residual pixels only.
    addition_image = input_image.copy()
    addition_image[:, :, 3] = residuals_image
    addition_image[:, :, :3] = input_image[:, :, :3] - base_image[:, :, :3]
    addition_image[:, :, :3] = addition_image[:, :, :3].clip(0, 255)
    # Subtract layer: positive part of (base - input), on residual pixels only.
    subtract_image = input_image.copy()
    subtract_image[:, :, 3] = residuals_image
    subtract_image[:, :, :3] = base_image[:, :, :3] - input_image[:, :, :3]
    subtract_image[:, :, :3] = subtract_image[:, :, :3].clip(0, 255)
    # Split every layer into one image per label, then cast back to uint8.
    [
        base_layer_list,
        shadow_layer_list,
        screen_layer_list,
        addition_layer_list,
        subtract_layer_list,
    ] = _split_img_batch(
        np.array(
            [
                base_image,
                shadow_image,
                screen_image,
                addition_image,
                subtract_image,
            ]
        ),
        label,
    )
    return (
        [t.astype(np.uint8) for t in base_layer_list],
        [t.astype(np.uint8) for t in shadow_layer_list],
        [t.astype(np.uint8) for t in screen_layer_list],
        [t.astype(np.uint8) for t in addition_layer_list],
        [t.astype(np.uint8) for t in subtract_layer_list],
    )