| import os |
| import glob |
| import time |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from tqdm.notebook import tqdm |
| import matplotlib.pyplot as plt |
| from skimage.color import rgb2lab, lab2rgb |
|
|
| import torch |
| from torch import nn, optim |
| from torchvision import transforms |
| from torchvision.utils import make_grid |
| from torch.utils.data import Dataset, DataLoader |
|
|
# Select the compute device once at import time: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
def init_weights(net: nn.Module, init: str = "norm", gain: float = 0.02) -> nn.Module:
    """Re-initialize the convolution and BatchNorm2d parameters of *net* in place.

    Every sub-module of ``net`` is visited via ``net.apply``:

    * Modules whose class name contains ``"Conv"`` and that expose a
      ``weight`` attribute get their weights drawn according to ``init``:
      ``"norm"`` -> normal(0, ``gain``), ``"xavier"`` -> Xavier normal with
      ``gain``, ``"kaiming"`` -> Kaiming normal (``a=0``, ``fan_in``; ``gain``
      is not used). Any other ``init`` value leaves the weights untouched.
      Their bias, when present, is set to zero.
    * Modules whose class name contains ``"BatchNorm2d"`` get weights drawn
      from normal(1.0, ``gain``) and a zero bias. Layers built with
      ``affine=False`` have no weight/bias and are skipped.

    Args:
        net: The module to initialize (modified in place).
        init: Name of the convolution weight scheme.
        gain: Std-dev (``"norm"``, BatchNorm) or gain (``"xavier"``).

    Returns:
        The same ``net`` object, to allow call chaining.
    """

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and "Conv" in classname:
            if init == "norm":
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")

            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # With affine=False both parameters are None; dereferencing
            # `.data` on them would raise AttributeError, so guard each one.
            if getattr(m, "weight", None) is not None:
                nn.init.normal_(m.weight.data, 1.0, gain)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
|
|