Upload 6 files
- basicsr/VERSION +1 -0
- basicsr/__init__.py +12 -0
- basicsr/setup.py +166 -0
- basicsr/test.py +45 -0
- basicsr/train.py +215 -0
- basicsr/version.py +5 -0
basicsr/VERSION
ADDED
@@ -0,0 +1 @@
1.3.2
basicsr/__init__.py
ADDED
@@ -0,0 +1,12 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .test import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
basicsr/setup.py
ADDED
@@ -0,0 +1,166 @@
#!/usr/bin/env python

from setuptools import find_packages, setup

import os
import subprocess
import sys
import time
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
from utils.misc import gpu_is_available

version_file = './basicsr/version.py'


def readme():
    with open('README.md', encoding='utf-8') as f:
        content = f.read()
    return content


def get_git_hash():

    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {}
        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
            v = os.environ.get(k)
            if v is not None:
                env[k] = v
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
        return out

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except OSError:
        sha = 'unknown'

    return sha


def get_hash():
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            from basicsr.version import __version__
            sha = __version__.split('+')[-1]
        except ImportError:
            raise ImportError('Unable to get git version')
    else:
        sha = 'unknown'

    return sha


def write_version_py():
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('./basicsr/VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])

    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)


def get_version():
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']


def make_cuda_ext(name, module, sources, sources_cuda=None):
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}

    # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
    # note: gpu_is_available must be called; the bare function name is always truthy
    if gpu_is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        sources += sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)


def get_requirements(filename='requirements.txt'):
    with open(os.path.join('.', filename), 'r') as f:
        requires = [line.replace('\n', '') for line in f.readlines()]
    return requires


if __name__ == '__main__':
    if '--cuda_ext' in sys.argv:
        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            make_cuda_ext(
                name='fused_act_ext',
                module='ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]
        sys.argv.remove('--cuda_ext')
    else:
        ext_modules = []

    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
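
The version plumbing in setup.py is terse, so here is a minimal, self-contained sketch (not part of the uploaded files) of what write_version_py() derives from basicsr/VERSION; the input string '1.3.2' is taken from the VERSION file shown above, and the actual file writing is omitted.

# Minimal sketch of write_version_py()'s version_info computation (file writing omitted).
SHORT_VERSION = '1.3.2'  # contents of basicsr/VERSION, stripped
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
print(VERSION_INFO)  # -> 1, 3, 2   (written into version.py as: version_info = (1, 3, 2))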
basicsr/test.py
ADDED
@@ -0,0 +1,45 @@
import logging
import torch
from os import path as osp

from basicsr.data import build_dataloader, build_dataset
from basicsr.models import build_model
from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
from basicsr.utils.options import dict2str, parse_options


def test_pipeline(root_path):
    # parse options, set distributed setting, set random seed
    opt, _ = parse_options(root_path, is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = build_dataset(dataset_opt)
        test_loader = build_dataloader(
            test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
        logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = build_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])


if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    test_pipeline(root_path)
basicsr/train.py
ADDED
@@ -0,0 +1,215 @@
import datetime
import logging
import math
import time
import torch
from os import path as osp

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options


def init_tb_loggers(opt):
    # initialize wandb logger before tensorboard logger to allow proper sync
    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
                                                     is not None) and ('debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
    return tb_logger


def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    train_loader, val_loaders = None, []
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            train_loader = build_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info('Training statistics:'
                        f'\n\tNumber of train images: {len(train_set)}'
                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
        elif phase.split('_')[0] == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(
                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
            val_loaders.append(val_loader)
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loaders, total_epochs, total_iters


def load_resume_state(opt):
    resume_state_path = None
    if opt['auto_resume']:
        state_path = osp.join('experiments', opt['name'], 'training_states')
        if osp.isdir(state_path):
            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
            if len(states) != 0:
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state'] = resume_state_path
    else:
        if opt['path'].get('resume_state'):
            resume_state_path = opt['path']['resume_state']

    if resume_state_path is None:
        resume_state = None
    else:
        device_id = torch.cuda.current_device()
        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
        check_resume(opt, resume_state['iter'])
    return resume_state


def train_pipeline(root_path):
    # parse options, set distributed setting, set random seed
    opt, args = parse_options(root_path, is_train=True)
    opt['root_path'] = root_path

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # load resume states if necessary
    resume_state = load_resume_state(opt)
    # mkdir for experiments and logger
    if resume_state is None:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))

    # copy the yml file to the experiment root
    copy_opt_file(args.opt, opt['path']['experiments_root'])

    # WARNING: should not use get_root_logger in the above codes, including the called functions
    # Otherwise the logger will not be properly initialized
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))
    # initialize wandb and tb loggers
    tb_logger = init_tb_loggers(opt)

    # create train and validation dataloaders
    result = create_train_val_dataloader(opt, logger)
    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result

    # create model
    model = build_model(opt)
    if resume_state:  # resume training
        model.resume_training(resume_state)  # handle optimizers and schedulers
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']
    else:
        start_epoch = 0
        current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # dataloader prefetcher
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu':
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        prefetcher = CUDAPrefetcher(train_loader, opt)
        logger.info(f'Use {prefetch_mode} prefetch dataloader')
        if opt['datasets']['train'].get('pin_memory') is not True:
            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
    else:
        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")

    # training
    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    data_timer, iter_timer = AvgTimer(), AvgTimer()
    start_time = time.time()

    for epoch in range(start_epoch, total_epochs + 1):
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_timer.record()

            current_iter += 1
            if current_iter > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_iter)
            iter_timer.record()
            if current_iter == 1:
                # reset start time in msg_logger for more accurate eta_time
                # does not work in resume mode
                msg_logger.reset_start_time()
            # log
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            # save models and training states
            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(epoch, current_iter)

            # validation
            if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
                if len(val_loaders) > 1:
                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
                for val_loader in val_loaders:
                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

            data_timer.start()
            iter_timer.start()
            train_data = prefetcher.next()
        # end of iter

    # end of epoch

    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'End of training. Time consumed: {consumed_time}')
    logger.info('Save the latest model.')
    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
    if opt.get('val') is not None:
        for val_loader in val_loaders:
            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
    if tb_logger:
        tb_logger.close()


if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
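
One detail of load_resume_state() worth spelling out: with auto_resume enabled, it lists the training_states folder and resumes from the state file with the largest iteration number. A small sketch of just that selection step, using hypothetical file names:

# Sketch of the auto-resume selection in load_resume_state(); the file names below are hypothetical.
states = ['5000.state', '10000.state', '2500.state']  # e.g. a listing of training_states/
iters = [float(v.split('.state')[0]) for v in states]
print(f'{max(iters):.0f}.state')  # -> 10000.state (the state that gets resumed)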
basicsr/version.py
ADDED
@@ -0,0 +1,5 @@
# GENERATED VERSION FILE
# TIME: Sun Jul 2 18:58:12 2023
__version__ = '1.3.2'
__gitsha__ = '4724c90'
version_info = (1, 3, 2)
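
For context, basicsr/version.py is the artifact that write_version_py() in setup.py generates (version string from basicsr/VERSION plus the short git SHA), and basicsr/__init__.py re-exports __version__ and __gitsha__. A usage sketch, assuming the package has been installed and is importable:

# Usage sketch: read the generated version metadata (assumes basicsr is importable).
from basicsr.version import __gitsha__, __version__, version_info
print(__version__, __gitsha__, version_info)  # e.g. 1.3.2 4724c90 (1, 3, 2)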