Spaces:
Sleeping
Sleeping
| import random | |
| from typing import List, Tuple | |
| import cv2 | |
| import numpy as np | |
| from skimage import color | |
| from sklearn.cluster import MiniBatchKMeans | |
| from sklearn.utils import shuffle | |
def _fix_seed(seed: int) -> None:
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
# Seed Python's and NumPy's global RNGs once, when the module is imported.
SEED = 42
_fix_seed(seed=SEED)
def _get_new_group(rgb_means: np.ndarray, threshold: int):
    """Group clusters whose mean colours are closer than ``threshold``.

    Parameters
    ----------
    rgb_means : np.ndarray
        (K, 3) array of per-cluster mean colours (channel axis 1).
    threshold : int
        Cluster pairs whose CIEDE2000 distance (on the ``rgb2lab`` conversion
        of ``rgb_means``) is strictly below this value are merged.

    Returns
    -------
    merge_dict : dict
        Old cluster id -> new group id; group ids are consecutive from 0.
    groups : dict
        New group id -> list of old cluster ids belonging to it.

    NOTE(review): in this file ``rgb_means`` appears to hold 0-255 pixel
    averages, while skimage treats float RGB input as [0, 1], so these are
    probably not true CIEDE2000 distances. Confirm before retuning
    ``threshold``.
    NOTE(review): merging is not transitive. Each close pair (i, j) points j
    at i's current target and overwrites any earlier assignment of j, so the
    result depends on the pair order below.
    """
    n_clusters = len(rgb_means)
    lab_means = color.rgb2lab(rgb_means, channel_axis=1)
    # Close pairs in (i ascending, then j ascending) order; the merge below
    # depends on this order.
    close_pairs = [
        (i, j)
        for i in range(n_clusters)
        for j in range(i + 1, n_clusters)
        if color.deltaE_ciede2000(lab_means[i], lab_means[j]) < threshold
    ]
    # Every cluster starts as its own target.
    target = {k: k for k in range(len(lab_means))}
    for src, dst in close_pairs:
        target[dst] = target[src]
    # Renumber the surviving targets consecutively (in set iteration order)
    # and collect the members of each group.
    compact = {root: new_id for new_id, root in enumerate(set(target.values()))}
    groups = {new_id: [] for new_id in compact.values()}
    for old_id in target:
        target[old_id] = compact[target[old_id]]
        groups[target[old_id]].append(old_id)
    return target, groups
def _get_rgb_means(
    img: np.ndarray,
    labels: np.ndarray,
    label_counts: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the mean RGB colour of each label over the opaque pixels.

    Parameters
    ----------
    img : np.ndarray
        Image indexed as (H, W, C); channels 0-2 are averaged as RGB, and a
        pixel counts as opaque when channel 3 (alpha) is > 127.
    labels : np.ndarray
        Label map broadcastable against (H, W), presumably (H, W) integers.
    label_counts : int
        Number of labels K; labels 0..K-1 are evaluated.

    Returns
    -------
    rgb_means : np.ndarray
        (K, 3) float mean colour per label; 0 for a label with no opaque
        pixel.
    cls_counts : np.ndarray
        (K,) number of opaque pixels per label.
    masks : np.ndarray
        (K, H, W) boolean masks; ``masks[k]`` marks opaque pixels of label k.
    """
    cls = np.arange(label_counts)
    masks = np.logical_and(img[:, :, 3] > 127, cls.reshape(-1, 1, 1) == labels)
    cls_counts = masks.sum(axis=(1, 2))  # opaque pixel count per label
    # Per-label sums of the RGB channels over the masked pixels.
    cls_sum = (img[:, :, :3] * masks[:, :, :, None]).sum(axis=(1, 2))
    # The epsilon keeps labels without opaque pixels at 0 instead of 0/0.
    rgb_means = cls_sum / (cls_counts[:, None] + 1e-6)
    return rgb_means, cls_counts, masks
def _kmeans_labels(
    img: np.ndarray, cls_num: int, kmeans_samples: int
) -> Tuple[np.ndarray, int]:
    """Cluster the RGB pixels of ``img``; return ``(labels (H, W), n_clusters)``.

    When the image has transparent pixels (alpha <= 127), the model is fitted
    on the opaque pixels only. The fitting set is optionally subsampled to
    ``kmeans_samples`` points. Every pixel is then labelled.
    """
    im_h, im_w = img.shape[:2]
    rgb_flatten = cluster_samples = img[..., :3].reshape((-1, 3))
    alpha_mask = np.where(img[..., 3] > 127)
    resampling = False
    if rgb_flatten.shape[0] > len(alpha_mask[0]):
        # Some pixels are transparent: fit on the opaque ones only.
        cluster_samples = img[..., :3][alpha_mask].reshape((-1, 3))
        resampling = True
    if 0 < kmeans_samples < len(cluster_samples):
        # Compare with the samples actually available, not the whole image:
        # shuffle() cannot draw more samples than it is given.
        cluster_samples = shuffle(
            cluster_samples, random_state=0, n_samples=kmeans_samples
        )
        resampling = True
    kmeans = MiniBatchKMeans(n_clusters=cls_num).fit(cluster_samples)
    # A model fitted on a subset must label the full image via predict().
    labels = kmeans.predict(rgb_flatten) if resampling else kmeans.labels_
    return labels.reshape(im_h, im_w), kmeans.n_clusters


def get_base(
    img: np.ndarray,
    loop: int,
    cls_num: int,
    threshold: int,
    size: int,
    kmeans_samples: int = -1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster an image, then repeatedly blur and merge clusters of similar colour.

    Parameters
    ----------
    img : np.ndarray
        Input image; channels 0-2 are used as RGB and channel 3 as alpha
        (alpha > 127 counts as opaque). The caller's array is not modified.
    loop : int
        Number of blur-and-merge iterations; must be positive.
    cls_num : int
        Number of k-means clusters to start from.
    threshold : int
        Colour-distance threshold passed to ``_get_new_group``; clusters
        closer than this are merged.
    size : int
        Kernel size of the box blur (``cv2.blur``).
    kmeans_samples : int, optional
        If positive and smaller than the number of candidate pixels, fit
        k-means on this many randomly drawn pixels. Default -1 uses all.

    Returns
    -------
    img : np.ndarray
        uint8 copy of the input in which every opaque pixel is replaced by
        the mean colour of its final group, taken over the original
        (unblurred) image.
    labels : np.ndarray
        uint32 group id per pixel, with ids in ``[0, n_groups)``.

    Raises
    ------
    ValueError
        If ``loop`` is not positive.
    """
    if loop <= 0:
        raise ValueError(f"loop must be positive, got {loop}")
    labels, label_counts = _kmeans_labels(img, cls_num, kmeans_samples)
    img_ori = img.copy()
    for i in range(loop):
        img = cv2.blur(img, (size, size))
        rgb_means, cls_counts, _ = _get_rgb_means(img, labels, label_counts)
        merge_dict, groups = _get_new_group(rgb_means, threshold)
        label_counts = len(groups)
        # Relabel every pixel at once through an old-id -> group-id table.
        relabel = np.array(
            [merge_dict[k] for k in range(len(merge_dict))], dtype=np.intp
        )
        labels = relabel[labels]
        if i != loop - 1:
            # Paint each group with the opaque-pixel-weighted mean of its
            # members. A group without opaque pixels gets 0 instead of 0/0.
            group_means = np.zeros((label_counts, 3))
            for group_id, label_ids in groups.items():
                cnt = cls_counts[label_ids]
                weighted = (rgb_means[label_ids] * cnt[:, None]).sum(axis=0)
                group_means[group_id] = weighted / max(cnt.sum(), 1)
            img[..., :3] = group_means[labels]
    img = img_ori
    rgb_means, _, masks = _get_rgb_means(img, labels, label_counts)
    for mask, rgb in zip(masks, rgb_means):
        img[mask, :3] = rgb
    img = img.clip(0, 255).astype(np.uint8)
    labels = labels.squeeze().astype(np.uint32)
    return img, labels
def _split_img_batch(
    images: List[np.ndarray], labels: np.ndarray
) -> List[List[np.ndarray]]:
    """Split every image into one masked copy per distinct label value.

    Returns a list parallel to ``images``. Entry ``i`` holds one array per
    value of ``np.unique(labels)``, in ascending order. Each array is
    ``images[i]`` multiplied by that label's mask, so pixels with any other
    label become 0.
    """
    # One (H, W, 1) mask per label; the trailing axis broadcasts over channels.
    label_masks = [(labels == value)[:, :, None] for value in np.unique(labels)]
    return [[image * mask for mask in label_masks] for image in images]
def get_normal_layer(
    input_image: np.ndarray, base_image: np.ndarray, label: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Split ``input_image`` into base, bright and shadow layers per label.

    A pixel belongs to the bright layer when its HSV value (V channel) is
    higher in ``input_image`` than in ``base_image``; otherwise it belongs to
    the shadow layer. Both layers carry the input's colours. The input alpha
    is kept on the layer's own pixels and set to 0 elsewhere. Channels 0-2 are
    treated as RGB and channel 3 as alpha.

    Returns
    -------
    Tuple of three lists (base, bright, shadow). Each list holds one uint8
    image per distinct ``label`` value, as produced by ``_split_img_batch``.
    """
    base_image = base_image.astype(np.int32)
    input_image = input_image.astype(np.int32)

    def hsv_value(rgba):
        # V channel of the RGB->HSV conversion of the 8-bit RGB channels.
        rgb_u8 = rgba[:, :, :3].astype(np.uint8)
        return cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2HSV)[:, :, 2]

    base_v = hsv_value(base_image)
    input_v = hsv_value(input_image)

    def keep_alpha_where(mask):
        # Copy of the input whose alpha survives only where mask is True.
        layer = input_image.copy()
        layer[:, :, 3] = layer[:, :, 3] * mask
        return layer

    bright_image = keep_alpha_where(base_v < input_v)
    shadow_image = keep_alpha_where(base_v >= input_v)

    per_label = _split_img_batch(
        np.array([base_image, bright_image, shadow_image]), label
    )
    base_layers, bright_layers, shadow_layers = per_label
    return tuple(
        [layer.astype(np.uint8) for layer in layers]
        for layers in (base_layers, bright_layers, shadow_layers)
    )
def _divide_or_zero(numerator, denominator) -> np.ndarray:
    """Return ``numerator / denominator`` as float64, with 0.0 where the denominator is 0.

    A plain division yields inf/NaN there (with RuntimeWarnings). Casting
    those values into an integer layer gives platform-dependent results.
    """
    numerator = np.asarray(numerator)
    denominator = np.asarray(denominator)
    out = np.zeros(np.broadcast(numerator, denominator).shape, dtype=np.float64)
    np.divide(numerator, denominator, out=out, where=denominator != 0)
    return out


def get_composite_layer(
    input_image: np.ndarray, base_image: np.ndarray, label: np.ndarray
) -> Tuple[
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
]:
    """Decompose ``input_image`` against ``base_image`` into blend layers per label.

    Each pixel is classified by comparing input and base on every RGB channel
    (channels 0-2; channel 3 is alpha):

    * shadow: all channels darker than base. The colour is
      ``input * 255 / base``, the value ``c`` with ``base * c / 255 == input``.
    * screen: all channels brighter than base. The colour is
      ``(input - base) / (1 - base / 255)``, which inverts
      ``base + c * (1 - base / 255)``.
    * residual: every other pixel. It is stored both as an addition layer
      (``input - base`` clipped to [0, 255]) and as a subtract layer
      (``base - input`` clipped).

    Each layer keeps the input alpha only on its own pixels (0 elsewhere).
    Float results are truncated when stored into the int32 working arrays.

    Returns
    -------
    Five lists (base, shadow, screen, addition, subtract). Each list holds one
    uint8 image per distinct ``label`` value, as produced by
    ``_split_img_batch``.
    """
    base_image = base_image.astype(np.int32)
    input_image = input_image.astype(np.int32)
    base_rgb = base_image[:, :, :3]
    input_rgb = input_image[:, :, :3]
    diff_image = base_image - input_image

    # Shadow: zero-base pixels can only lie outside the mask (alpha 0), so
    # they get 0 instead of inf/NaN.
    shadow_mask = (diff_image[:, :, :3] > 0).all(axis=2)
    shadow_image = input_image.copy()
    shadow_image[:, :, 3] = shadow_image[:, :, 3] * shadow_mask
    shadow_image[:, :, :3] = _divide_or_zero(shadow_image[:, :, :3] * 255, base_rgb)

    # Screen: base == 255 makes the denominator 0; those pixels get 0.
    screen_mask = (diff_image[:, :, :3] < 0).all(axis=2)
    screen_image = input_image.copy()
    screen_image[:, :, 3] = screen_image[:, :, 3] * screen_mask
    screen_image[:, :, :3] = _divide_or_zero(
        screen_image[:, :, :3] - base_rgb, 1 - base_rgb / 255
    )

    # Residuals: pixels that are neither uniformly darker nor uniformly brighter.
    residuals_mask = ~shadow_mask & ~screen_mask
    residuals_alpha = input_image[:, :, 3].copy() * residuals_mask

    # Addition
    addition_image = input_image.copy()
    addition_image[:, :, 3] = residuals_alpha
    addition_image[:, :, :3] = np.clip(input_rgb - base_rgb, 0, 255)

    # Subtract
    subtract_image = input_image.copy()
    subtract_image[:, :, 3] = residuals_alpha
    subtract_image[:, :, :3] = np.clip(base_rgb - input_rgb, 0, 255)

    (
        base_layer_list,
        shadow_layer_list,
        screen_layer_list,
        addition_layer_list,
        subtract_layer_list,
    ) = _split_img_batch(
        np.array(
            [
                base_image,
                shadow_image,
                screen_image,
                addition_image,
                subtract_image,
            ]
        ),
        label,
    )
    return (
        [t.astype(np.uint8) for t in base_layer_list],
        [t.astype(np.uint8) for t in shadow_layer_list],
        [t.astype(np.uint8) for t in screen_layer_list],
        [t.astype(np.uint8) for t in addition_layer_list],
        [t.astype(np.uint8) for t in subtract_layer_list],
    )