Spaces:
Runtime error
Runtime error
| from src.utils.config_loader import constants | |
| from huggingface_hub import snapshot_download | |
| from zipfile import ZipFile | |
| import numpy as np | |
| import os, shutil | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import math | |
def download_hf_dataset(repo_id, allow_patterns=None):
    """Download a public Hugging Face dataset repository snapshot.

    Files are written into ``constants.RAW_DATASET_DIR``.

    Args:
        repo_id: Hugging Face repository id of the dataset.
        allow_patterns: Optional glob pattern(s) forwarded to
            ``snapshot_download`` to restrict which files are fetched.
    """
    download_kwargs = dict(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=constants.RAW_DATASET_DIR,
        allow_patterns=allow_patterns,
    )
    snapshot_download(**download_kwargs)
def download_personal_hf_dataset(name):
    """Download one folder of ``Anuj-Panthri/Image-Colorization-Datasets``.

    Args:
        name: Top-level folder inside the dataset repository; only files
            matching ``"<name>/*"`` are downloaded (via
            ``download_hf_dataset``).
    """
    folder_pattern = "{}/*".format(name)
    download_hf_dataset(
        allow_patterns=folder_pattern,
        repo_id="Anuj-Panthri/Image-Colorization-Datasets",
    )
def unzip_file(file_path, destination_dir):
    """Extract a zip archive into a freshly created ``destination_dir``.

    Whatever already exists at ``destination_dir`` is deleted first, so after
    the call the directory contains exactly the archive's members.

    Args:
        file_path: Path to the ``.zip`` archive.
        destination_dir: Directory to extract into; it is (re)created.
    """
    if os.path.isdir(destination_dir) and not os.path.islink(destination_dir):
        shutil.rmtree(destination_dir)
    elif os.path.lexists(destination_dir):
        # A plain file or symlink here would make rmtree/makedirs fail.
        os.remove(destination_dir)
    os.makedirs(destination_dir)
    # Named ``archive`` rather than ``zip`` to avoid shadowing the builtin.
    with ZipFile(file_path, "r") as archive:
        archive.extractall(destination_dir)
def is_bw(img: np.ndarray, threshold: float = 10):
    """Check whether an RGB image is (near) black-and-white.

    The absolute differences between each pair of the first three channels
    are summed over the whole image. The image counts as black-and-white when
    the mean of those three sums is below ``threshold``. The sums are totals,
    not per-pixel averages, so the tolerance does not scale with image size.

    Args:
        img: Array indexed as ``img[:, :, channel]`` with at least 3 channels.
        threshold: Exclusive upper bound on the mean channel-difference sum.

    Returns:
        A NumPy boolean.
    """
    # Compute in float64: with uint8 input, ``a - b`` wraps around
    # (e.g. 100 - 101 == 255), which made near-gray images look colorful.
    pixels = np.asarray(img, dtype=np.float64)
    red, green, blue = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
    rg = np.abs(red - green).sum()
    gb = np.abs(green - blue).sum()
    rb = np.abs(red - blue).sum()
    avg = np.mean([rg, gb, rb])
    return avg < threshold
def print_title(msg: str, max_chars=105):
    """Print ``msg`` upper-cased and centred in a banner of ``=`` signs.

    Each side gets ``(max_chars - len(msg)) // 2`` padding characters, so the
    banner is ``max_chars`` (or one fewer) characters wide when ``msg`` fits.
    No padding is printed when ``msg`` is ``max_chars`` or longer.
    """
    pad = "=" * ((max_chars - len(msg)) // 2)
    print(f"{pad}{msg.upper()}{pad}")
def scale_L(L):
    """Divide lightness by 100, e.g. mapping Lab L in [0, 100] to [0, 1]."""
    l_max = 100
    return L / l_max
def rescale_L(L):
    """Multiply by 100, undoing ``scale_L`` (e.g. [0, 1] back to [0, 100])."""
    l_max = 100
    return L * l_max
def scale_AB(AB):
    """Divide a/b chroma by 128, e.g. mapping [-128, 128] to [-1, 1]."""
    ab_max = 128
    return AB / ab_max
def rescale_AB(AB):
    """Multiply by 128, undoing ``scale_AB`` (e.g. [-1, 1] to [-128, 128])."""
    ab_max = 128
    return AB * ab_max
def show_images_from_paths(
    image_paths: list[str],
    image_size=64,
    cols=4,
    row_size=5,
    col_size=5,
    show_BW=False,
    title=None,
    save=False,
    label="",
):
    """Plot images read from disk in a grid, optionally beside a gray copy.

    Args:
        image_paths: Paths of the images to display, one per grid cell.
        image_size: Each image is resized to ``image_size x image_size``.
        cols: Number of grid columns; the row count is derived from
            ``len(image_paths)``.
        row_size: Figure height (inches) per grid row.
        col_size: Figure width (inches) per grid column.
        show_BW: If true, each cell shows the grayscale version to the left
            of the color image.
        title: Optional figure title.
        save: If true, also save the figure as ``<label>_image.png`` in
            ``constants.ARTIFACT_DATASET_VISUALIZATION_DIR``.
        label: Filename prefix used when ``save`` is true.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the images.
    """
    n = len(image_paths)
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")

    for i, path in enumerate(image_paths):
        fig.add_subplot(rows, cols, i + 1)
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, [image_size, image_size])
        if show_BW:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # gray is 2-D (H, W); expand it to 3 identical channels so it can
            # sit beside the (H, W, 3) color image. np.tile(gray, (1, 1, 3))
            # would instead give shape (1, H, 3W) and break the concatenation.
            gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            img = np.concatenate([gray_rgb, img], axis=1)
        plt.imshow(img.astype("uint8"))

    if save:
        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()
def see_batch(
    L_batch,
    AB_batch,
    show_L=False,
    cols=4,
    row_size=5,
    col_size=5,
    title=None,
    save=False,
    label="",
):
    """Plot a batch of scaled Lab images as RGB in a grid.

    Each sample is un-scaled with ``rescale_L`` / ``rescale_AB``, joined
    along the last axis into a Lab image and converted to RGB with OpenCV.

    Args:
        L_batch: Batch of scaled lightness channels, indexed ``L_batch[i]``.
            Presumably shaped (N, H, W, 1) so that joining with AB yields
            3 channels — TODO confirm against the data pipeline.
        AB_batch: Batch of scaled a/b channels, presumably (N, H, W, 2).
        show_L: If true, show the lightness channel in gray to the left of
            each RGB image.
        cols: Number of grid columns; the row count is derived from N.
        row_size: Figure height (inches) per grid row.
        col_size: Figure width (inches) per grid column.
        title: Optional figure title.
        save: If true, also save the figure as ``<label>_image.png`` in
            ``constants.ARTIFACT_RESULT_VISUALIZATION_DIR``.
        label: Filename prefix used when ``save`` is true.
    """
    n = L_batch.shape[0]
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")

    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        L = rescale_L(L_batch[i])
        AB = rescale_AB(AB_batch[i])
        # cv2.cvtColor's Lab->RGB path accepts 8-bit or float32 input only,
        # so float64 arrays (NumPy's default) are converted first.
        lab = np.concatenate([L, AB], axis=-1).astype(np.float32)
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) * 255
        if show_L:
            gray = np.tile(L, (1, 1, 3)) / 100 * 255
            img = np.concatenate([gray, img], axis=1)
        # Clip before casting: out-of-range values (e.g. from model
        # predictions) would otherwise wrap around instead of saturating.
        plt.imshow(np.clip(img, 0, 255).astype("uint8"))

    if save:
        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()