Spaces:
Runtime error
Runtime error
| from src.utils.config_loader import constants | |
| from huggingface_hub import snapshot_download | |
| from zipfile import ZipFile | |
| import numpy as np | |
| import os,shutil | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import math | |
| def download_hf_dataset(repo_id,allow_patterns=None): | |
| """Used to download dataset from any public hugging face dataset""" | |
| snapshot_download(repo_id=repo_id, | |
| repo_type="dataset", | |
| local_dir=constants.RAW_DATASET_DIR, | |
| allow_patterns=allow_patterns) | |
| def download_personal_hf_dataset(name): | |
| """Used to download dataset from a specific hugging face dataset""" | |
| download_hf_dataset(repo_id="Anuj-Panthri/Image-Colorization-Datasets", | |
| allow_patterns=f"{name}/*") | |
| def unzip_file(file_path,destination_dir): | |
| """unzips file to destination_dir""" | |
| if os.path.exists(destination_dir): | |
| shutil.rmtree(destination_dir) | |
| os.makedirs(destination_dir) | |
| with ZipFile(file_path,"r") as zip: | |
| zip.extractall(destination_dir) | |
| def is_bw(img:np.ndarray): | |
| """checks if RGB image is black and white""" | |
| rg,gb,rb = img[:,:,0]-img[:,:,1] , img[:,:,1]-img[:,:,2] , img[:,:,0]-img[:,:,2] | |
| rg,gb,rb = np.abs(rg).sum(),np.abs(gb).sum(),np.abs(rb).sum() | |
| avg = np.mean([rg,gb,rb]) | |
| return avg<10 | |
| def print_title(msg:str,max_chars=105): | |
| n = (max_chars-len(msg))//2 | |
| print("="*n,msg.upper(),"="*n,sep="") | |
| def scale_L(L): | |
| return L/100 | |
| def rescale_L(L): | |
| return L*100 | |
| def scale_AB(AB): | |
| return AB/128 | |
| def rescale_AB(AB): | |
| return AB*128 | |
| def show_images_from_paths(image_paths:list[str],image_size=64,cols=4,row_size=5,col_size=5,show_BW=False,title=None): | |
| n = len(image_paths) | |
| rows = math.ceil(n/cols) | |
| fig = plt.figure(figsize=(col_size*cols,row_size*rows)) | |
| if title: | |
| plt.title(title) | |
| plt.axis("off") | |
| for i in range(n): | |
| fig.add_subplot(rows,cols,i+1) | |
| img = cv2.imread(image_paths[i])[:,:,::-1] | |
| img = cv2.resize(img,[image_size,image_size]) | |
| if show_BW: | |
| BW = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY) | |
| BW = np.tile(BW,(1,1,3)) | |
| img = np.concatenate([BW,img],axis=1) | |
| plt.imshow(img.astype("uint8")) | |
| plt.show() | |
| def see_batch(L_batch,AB_batch,show_L=False,cols=4,row_size=5,col_size=5,title=None): | |
| n = L_batch.shape[0] | |
| rows = math.ceil(n/cols) | |
| fig = plt.figure(figsize=(col_size*cols,row_size*rows)) | |
| if title: | |
| plt.title(title) | |
| plt.axis("off") | |
| for i in range(n): | |
| fig.add_subplot(rows,cols,i+1) | |
| L,AB = L_batch[i],AB_batch[i] | |
| L,AB = rescale_L(L), rescale_AB(AB) | |
| # print(L.shape,AB.shape) | |
| img = np.concatenate([L,AB],axis=-1) | |
| img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255 | |
| # print(img.min(),img.max()) | |
| if show_L: | |
| L = np.tile(L,(1,1,3))/100*255 | |
| img = np.concatenate([L,img],axis=1) | |
| plt.imshow(img.astype("uint8")) | |
| plt.show() | |