Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm.notebook import tqdm | |
| import matplotlib.pyplot as plt | |
| from skimage.color import rgb2lab, lab2rgb | |
| import torch | |
| from torch import nn, optim | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from torch.utils.data import Dataset, DataLoader | |
# Prefer the (default) CUDA device whenever PyTorch can see one; fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def init_weights(net, init="norm", gain=0.02):
    """Re-initialise the conv and BatchNorm2d parameters of ``net`` in place.

    Every submodule whose class name contains ``"Conv"`` (so ``Conv2d``,
    ``ConvTranspose2d``, ...) gets its weight re-drawn according to ``init``
    and its bias, if present, set to zero. Every submodule whose class name
    contains ``"BatchNorm2d"`` gets its weight drawn from N(1, gain) and its
    bias set to zero; a BatchNorm2d built with ``affine=False`` has no
    weight/bias, so those are skipped. All other modules are left untouched.

    Args:
        net: The ``nn.Module`` to initialise; it is modified in place.
        init: Conv weight scheme: ``"norm"`` (N(0, gain)), ``"xavier"``
            (Xavier-normal scaled by ``gain``) or ``"kaiming"``
            (Kaiming-normal, ``a=0``, ``mode="fan_in"``; ``gain`` is unused).
        gain: Std-dev for ``"norm"`` and for the BatchNorm2d weights, and the
            scaling factor for ``"xavier"``.

    Returns:
        ``net`` itself, so the call can be chained.

    Raises:
        ValueError: If ``init`` is not a supported scheme. This is checked
            before any parameter is modified.
    """
    # Map each scheme to a conv-weight initialiser; validating against this
    # table up front means a typo fails loudly instead of silently leaving
    # PyTorch's default conv init in place.
    conv_inits = {
        "norm": lambda w: nn.init.normal_(w, mean=0.0, std=gain),
        "xavier": lambda w: nn.init.xavier_normal_(w, gain=gain),
        "kaiming": lambda w: nn.init.kaiming_normal_(w, a=0, mode="fan_in"),
    }
    if init not in conv_inits:
        raise ValueError(
            f"unknown init {init!r}; expected one of {sorted(conv_inits)}"
        )
    init_conv_weight = conv_inits[init]

    def init_func(m):
        # Dispatch on class name so every Conv* / BatchNorm2d variant matches.
        classname = m.__class__.__name__
        if hasattr(m, "weight") and "Conv" in classname:
            init_conv_weight(m.weight)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif "BatchNorm2d" in classname:
            # affine=False BatchNorm sets weight/bias to None.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    # Module.apply visits every submodule (and net itself).
    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net