Spaces:
Running
Running
Upload 8 files
Browse files- utils/__init__.py +21 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/common.cpython-39.pyc +0 -0
- utils/__pycache__/image_processing.cpython-39.pyc +0 -0
- utils/common.py +165 -0
- utils/fast_numpyio.py +43 -0
- utils/image_processing.py +135 -0
- utils/logger.py +24 -0
utils/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .common import *
|
| 2 |
+
from .image_processing import *
|
| 3 |
+
|
| 4 |
+
class DefaultArgs:
    """Default training configuration for the AnimeGAN trainer.

    Plays the role of parsed CLI arguments so training can be driven
    without argparse (e.g. from a notebook on Colab).
    """

    # Dataset / run setup
    dataset = 'Hayao'
    data_dir = '/content'
    epochs = 10
    batch_size = 1

    # Output locations and reporting
    checkpoint_dir = '/content/checkpoints'
    save_image_dir = '/content/images'
    display_image = True
    save_interval = 2
    debug_samples = 0

    # Learning rates (generator / discriminator)
    lr_g = 0.001
    lr_d = 0.002

    # Loss-term weights (w* — exact semantics defined by the trainer)
    wadvg = 300.0
    wadvd = 300.0
    wcon = 1.5
    wgra = 3
    wcol = 10

    # Spectral normalization toggle
    use_sn = False
|
utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (739 Bytes). View file
|
|
|
utils/__pycache__/common.cpython-39.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
utils/__pycache__/image_processing.cpython-39.pyc
ADDED
|
Binary file (2.83 kB). View file
|
|
|
utils/common.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import urllib.request
|
| 6 |
+
import cv2
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# Path prefixes that mark a remote/data-URI source rather than a local file.
HTTP_PREFIXES = ['http', 'data:image/jpeg']


# Pretrained generator weights published as GitHub release assets.
# Maps "<style>[:<version>]" -> (architecture version, download URL).
# A bare style name (no ":vN" suffix) resolves to v1.
RELEASED_WEIGHTS = {
    "hayao:v2": (
        # v2 used the Google Landmark v2 micro set as its real-photo data
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt",
    ),
    "hayao:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth",
    ),
    "hayao": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth",
    ),
    "shinkai:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth",
    ),
    "shinkai": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth",
    ),
}
|
| 38 |
+
|
| 39 |
+
def is_image_file(path):
    """Return True if *path* ends with a known image extension (case-insensitive)."""
    return os.path.splitext(path)[1].lower() in (".png", ".jpg", ".jpeg")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def read_image(path):
    """
    Read an image from a local path, an http(s) URL, or a data URI.

    Remote sources (any HTTP_PREFIXES match) are first downloaded to a
    local "temp.jpg" file.  cv2 loads BGR, so the channel axis is
    reversed to return an RGB array.

    Raises:
        ValueError: if cv2 cannot decode the file.
    """
    if any(path.startswith(p) for p in HTTP_PREFIXES):
        urllib.request.urlretrieve(path, "temp.jpg")
        path = "temp.jpg"

    image = cv2.imread(path)
    if image is None:
        # cv2.imread returns None for missing/corrupt files; fail loudly
        # instead of raising an opaque TypeError on the slice below.
        raise ValueError(f"Cannot read image: {path}")
    return image[:, :, ::-1]  # BGR -> RGB
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def save_checkpoint(model, path, optimizer=None, epoch=None):
    """Serialize *model* (plus optional *optimizer* state and *epoch*) to *path*.

    The layout matches what ``load_checkpoint`` expects:
    ``model_state_dict``, ``epoch``, and ``optimizer_state_dict`` when
    an optimizer is supplied.
    """
    state = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
    }
    if optimizer is not None:
        state['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(state, path)
|
| 65 |
+
|
| 66 |
+
def maybe_remove_module(state_dict):
    """Strip the 'module.' prefix DDP/DataParallel adds to state-dict keys.

    Keys without the prefix pass through unchanged.
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """
    Load a checkpoint produced by ``save_checkpoint`` into *model*.

    Args:
        model: module to receive the (de-DDP-prefixed) weights, strict match.
        path: checkpoint file path or a RELEASED_WEIGHTS key.
        optimizer: if given and the checkpoint holds optimizer state, restore it.
        strip_optimizer: if True, delete optimizer state and re-save the
            checkpoint in place (shrinks the file for inference-only use).
        map_location: forwarded to torch.load; auto cuda/cpu when None.

    Returns:
        The epoch stored in the checkpoint, or 0 when absent/None.
    """
    state_dict = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(
        model_state_dict,
        strict=True
    )
    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if strip_optimizer:
            del state_dict["optimizer_state_dict"]
            torch.save(state_dict, path)
            print(f"Optimizer stripped and saved to {path}")

    # save_checkpoint stores epoch=None by default; coerce to 0 so the
    # annotated int return type actually holds for callers.
    epoch = state_dict.get('epoch')
    return 0 if epoch is None else epoch
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def load_state_dict(weight, map_location) -> dict:
    """Load a torch state dict from a local path or a released-weight name.

    When *weight* (lower-cased) is a RELEASED_WEIGHTS key, the file is
    downloaded to the local cache first.  A None *map_location* selects
    'cuda' when available, otherwise 'cpu'.
    """
    key = weight.lower()
    if key in RELEASED_WEIGHTS:
        weight = _download_weight(key)

    if map_location is None:
        # auto select the device
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'

    return torch.load(weight, map_location=map_location)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def initialize_weights(net):
    """Initialize *net* in place: Xavier-uniform conv/linear weights, unit
    batch-norm scale, zero biases.

    Conv2d / ConvTranspose2d / Linear weights get ``xavier_uniform_`` and
    their biases (when present) are zeroed; BatchNorm2d gets weight=1,
    bias=0.  Other module types are left untouched.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            # The original wrapped this in a blanket try/except to survive
            # layers built with bias=False (m.bias is None); an explicit
            # guard handles that case without swallowing real errors.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def set_lr(optimizer, lr):
    """Set the learning rate of every param group of *optimizer* to *lr*."""
    for group in optimizer.param_groups:
        group['lr'] = lr
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class DownloadProgressBar(tqdm):
    '''
    tqdm progress bar adapted to urllib.request.urlretrieve's reporthook
    callback signature.

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''
    def update_to(self, b=1, bsize=1, tsize=None):
        # b: blocks transferred so far; bsize: size of each block in bytes;
        # tsize: total size in bytes (None until the server reports it).
        if tsize is not None:
            self.total = tsize
        # self.n already holds the bytes counted so far, so pass the delta.
        self.update(b * bsize - self.n)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _download_weight(weight):
    '''
    Download the released weight file for *weight* into the local .cache
    directory and return its path.  A previously downloaded copy is reused.
    '''
    os.makedirs('.cache', exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    # Cache under the URL's basename so each released weight gets its own
    # file (the computed `filename` was previously unused and the path had
    # no placeholder, collapsing every weight onto one cache entry).
    save_path = f'.cache/{filename}'

    if os.path.isfile(save_path):
        return save_path

    desc = f'Downloading {url} to {save_path}'
    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
        urllib.request.urlretrieve(url, save_path, reporthook=t.update_to)

    return save_path
|
| 165 |
+
|
utils/fast_numpyio.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code from https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
import numpy.lib.format
|
| 6 |
+
import struct
|
| 7 |
+
|
| 8 |
+
def save(file, array):
    """Write *array* to *file* in .npy v1 format (fast, minimal validation).

    *file* may be a path string (opened and closed here) or an open binary
    file object (left open for the caller).  Based on
    https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py

    NOTE(review): like the upstream code, this writes ``array.data``
    directly, so only C-contiguous arrays serialize correctly.
    """
    # 0x76 0x00 little-endian = 118-byte header after the 10-byte magic,
    # giving the fixed 128-byte total that load() relies on.
    magic_string = b"\x93NUMPY\x01\x00v\x00"
    header = bytes(
        (
            "{'descr': '" + array.dtype.descr[0][1]
            + "', 'fortran_order': False, 'shape': " + str(array.shape) + ", }"
        ).ljust(127 - len(magic_string)) + "\n",
        'utf-8',
    )

    def _write(f):
        # Emit magic + fixed-size header + raw buffer.
        f.write(magic_string)
        f.write(header)
        f.write(array.data)

    if isinstance(file, str):
        # Fix: the original opened the path but never closed the handle,
        # leaking it and potentially leaving the data unflushed.
        with open(file, "wb") as f:
            _write(f)
    else:
        _write(file)
|
| 16 |
+
|
| 17 |
+
def pack(array):
    """Serialize *array* into the compact bytes blob understood by ``unpack``.

    Layout: byteorder char + dtype-kind char, one itemsize byte, one ndim
    byte, ndim little-endian uint32 dimensions, then the raw array buffer.
    """
    ndim = len(array.shape)
    # '=' (native) is rewritten to an explicit '<' or '>' marker.
    order = array.dtype.byteorder.replace(
        '=', '<' if sys.byteorder == 'little' else '>'
    )
    dtype_tag = bytes(order + array.dtype.kind, 'utf-8')
    itemsize_byte = array.dtype.itemsize.to_bytes(1, byteorder='little')
    dims = struct.pack(f'<B{ndim}I', ndim, *array.shape)
    return dtype_tag + itemsize_byte + dims + array.data
|
| 20 |
+
|
| 21 |
+
def load(file):
    """Read a .npy file written by ``save`` (fixed 128-byte header layout).

    *file* may be a path string (opened and closed here) or an open binary
    file object (left open for the caller).  Returns None when the file is
    empty.  Field offsets are hard-coded to match ``save``'s header, so
    arbitrary .npy files with longer headers are not supported.
    """
    if isinstance(file, str):
        # Fix: the original opened the path but never closed the handle.
        with open(file, "rb") as f:
            return _read_npy(f)
    return _read_npy(file)


def _read_npy(file):
    # Parse the fixed-layout 128-byte header and the raw payload that follows.
    header = file.read(128)
    if not header:
        return None
    # dtype descriptor sits at fixed offsets 19..24 (e.g. " '<f4'").
    descr = str(header[19:25], 'utf-8').replace("'", "").replace(" ", "")
    # Shape tuple text, e.g. "(2, 3), }" plus padding, sits in 60..119.
    shape = tuple(
        int(num)
        for num in str(header[60:120], 'utf-8')
        .replace(', }', '').replace('(', '').replace(')', '').split(',')
    )
    datasize = numpy.lib.format.descr_to_dtype(descr).itemsize
    for dimension in shape:
        datasize *= dimension
    return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))
|
| 33 |
+
|
| 34 |
+
def unpack(data):
    """Inverse of ``pack``: rebuild an ndarray from the packed bytes blob."""
    # Bytes 0-1: byteorder + kind chars; byte 2: itemsize; byte 3: ndim.
    dtype = str(data[:2], 'utf-8') + str(data[2])
    ndim = data[3]
    shape = struct.unpack_from(f'<{ndim}I', data, 4)

    # Payload length = itemsize * product(shape).
    nbytes = data[2]
    for dim in shape:
        nbytes *= dim

    payload_start = 4 + ndim * 4
    return np.ndarray(shape, dtype=dtype,
                      buffer=data[payload_start:payload_start + nbytes])
|
| 43 |
+
|
utils/image_processing.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def gram(input):
    """
    Compute the normalized Gram matrix of a (b, c, w, h) feature map.

    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss

    The product is clamped to +/-64990 as a mixed-precision workaround —
    torch.mm can otherwise produce inf under autocast:
    https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
    """
    b, c, w, h = input.size()

    # Flatten each channel's spatial map into one row: (b*c, w*h).
    features = input.contiguous().view(b * c, w * h)

    G = features.mm(features.t())
    G = G.clamp(-64990.0, 64990.0)

    # Normalize by the total number of elements.
    return G / (b * c * w * h)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def divisible(dim):
    '''Round a (width, height) pair down to the nearest multiples of 32.'''
    w, h = dim
    return (w // 32) * 32, (h // 32) * 32
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize *image* (H, W, ...) with output dims snapped down to /32.

    - both width and height given: resize to exactly that size
      (no aspect-ratio preservation);
    - neither given: keep the current size (only snap to multiples of 32);
    - one given: scale the other dimension to preserve aspect ratio.
    """
    h, w = image.shape[:2]

    if width and height:
        return cv2.resize(image, divisible((width, height)), interpolation=inter)

    if width is None and height is None:
        return cv2.resize(image, divisible((w, h)), interpolation=inter)

    if width is None:
        scale = height / float(h)
        target = (int(w * scale), height)
    else:
        scale = width / float(w)
        target = (width, int(h * scale))

    return cv2.resize(image, divisible(target), interpolation=inter)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def normalize_input(images):
    '''Map pixel values from [0, 255] to [-1, 1].'''
    return images / 127.5 - 1.0
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def denormalize_input(images, dtype=None):
    '''Map values from [-1, 1] back to [0, 255], optionally casting to *dtype*.

    Handles both torch tensors (cast via ``.type``) and numpy arrays
    (cast via ``.astype``).
    '''
    result = images * 127.5 + 127.5

    if dtype is None:
        return result
    if isinstance(result, torch.Tensor):
        return result.type(dtype)
    # numpy.ndarray
    return result.astype(dtype)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def preprocess_images(images):
    '''
    Preprocess image for inference

    @Arguments:
        - images: np.ndarray, HWC or NHWC, pixel range [0, 255]

    @Returns
        - images: torch.tensor, NCHW, range [-1, 1]
    '''
    images = images.astype(np.float32)

    # Normalize [0, 255] -> [-1, 1] and move into torch.
    tensor = torch.from_numpy(images / 127.5 - 1.0)

    # Add a batch dim when given a single HWC image.
    if tensor.dim() == 3:
        tensor = tensor.unsqueeze(0)

    # NHWC -> NCHW (channel first)
    return tensor.permute(0, 3, 1, 2)
|
| 110 |
+
|
| 111 |
+
def compute_data_mean(data_folder):
    '''Compute the per-channel centering offset for images in *data_folder*.

    Averages every file's per-channel mean (cv2 order), then returns
    ``overall_mean - reversed_channel_means`` — the 3-element offset the
    training pipeline subtracts from its data.

    Raises:
        FileNotFoundError: when *data_folder* does not exist.
        ValueError: when the folder is empty (the original code died with
            ZeroDivisionError here).
    '''
    if not os.path.exists(data_folder):
        # Typo fixed in the message: "exits" -> "exist".
        raise FileNotFoundError(f'Folder {data_folder} does not exist')

    image_files = os.listdir(data_folder)
    if not image_files:
        raise ValueError(f'Folder {data_folder} contains no files')

    total = np.zeros(3)

    print(f"Compute mean (R, G, B) from {len(image_files)} images")

    for img_file in tqdm(image_files):
        path = os.path.join(data_folder, img_file)
        # NOTE(review): assumes every listed file is cv2-readable; a
        # non-image file makes imread return None and this line crash.
        image = cv2.imread(path)
        total += image.mean(axis=(0, 1))

    channel_mean = total / len(image_files)
    mean = np.mean(channel_mean)

    return mean - channel_mean[..., ::-1]  # Convert to BGR for training
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == '__main__':
    # Smoke test: gram matrix of a random feature map under CPU autocast.
    sample = torch.rand(2, 14, 32, 32)
    with torch.autocast("cpu"):
        print(gram(sample))
|
utils/logger.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_logger(path, *args, **kwargs):
    """Configure root logging to write to *path* and to the console at DEBUG.

    Returns the ``logging`` module itself, so the return value can be used
    like a logger object (``get_logger(p).info(...)``).

    Extra *args/**kwargs are accepted for backward compatibility with
    existing callers but are ignored.

    NOTE: ``basicConfig`` is a no-op if the root logger was already
    configured elsewhere in the process.
    """
    # (A dozen lines of commented-out getLogger/handler setup were removed
    # here; basicConfig below is the live implementation.)
    logging.basicConfig(
        format='%(asctime)s %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
        handlers=[
            logging.FileHandler(path),
            logging.StreamHandler(),
        ],
        level=logging.DEBUG,
    )
    return logging
|