|
|
import os |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from loguru import logger |
|
|
|
|
|
from configs.train_config import TrainConfig |
|
|
from data.dataset import TrainDatasetDataLoader |
|
|
from models.model import HifiFace |
|
|
from utils.visualizer import Visualizer |
|
|
|
|
|
# Whether to train with DistributedDataParallel; read once from the config at import time
# so the module-level `if use_ddp:` guard below can conditionally define the DDP helpers.
use_ddp = TrainConfig().use_ddp
|
|
# DDP helpers are only imported/defined when distributed training is enabled,
# so single-GPU runs never touch torch.distributed.
if use_ddp:
    import torch.distributed as dist

    def setup(backend: str = "nccl"):
        """Join the default process group and return this process's rank.

        Relies on the launcher (e.g. torchrun) to provide the rendezvous
        environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE)
        — TODO confirm the launch method used by this project.

        Args:
            backend: process-group backend; defaults to "nccl" (GPU training),
                kept as the default for backward compatibility.

        Returns:
            The integer rank of this process in the group.
        """
        dist.init_process_group(backend)
        return dist.get_rank()

    def cleanup():
        """Tear down the default process group before the process exits."""
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
def train():
    """Run the (possibly distributed) HifiFace training loop.

    Loops over the dataset indefinitely, optimizing the model each step and,
    on the main process only (rank 0 under DDP), periodically:
      * displaying visual results every ``opt.visualize_interval`` iters,
      * logging/plotting losses every ``opt.plot_interval`` iters,
      * saving a checkpoint every ``opt.checkpoint_interval`` iters.

    Exits the process via ``sys.exit(0)`` once ``opt.max_iters`` is exceeded,
    saving a final checkpoint and destroying the process group first.
    """
    rank = 0
    if use_ddp:
        rank = setup()
    device = torch.device(f"cuda:{rank}")
    logger.info(f"use device {device}")

    opt = TrainConfig()
    dataloader = TrainDatasetDataLoader()
    dataset_length = len(dataloader)
    logger.info(f"Dataset length: {dataset_length}")

    model = HifiFace(
        opt.identity_extractor_config, is_training=True, device=device, load_checkpoint=opt.load_checkpoint
    )
    model.train()

    logger.info("model initialized")
    # Only the main process visualizes and writes checkpoints; other ranks
    # keep visualizer=None / ckpt=False so the periodic branches are skipped.
    visualizer = None
    ckpt = False
    if not opt.use_ddp or rank == 0:
        visualizer = Visualizer(opt)
        ckpt = True

    total_iter = 0
    epoch = 0
    while True:
        if opt.use_ddp:
            # Re-seed the sampler's shuffle so each epoch yields a new ordering per rank.
            dataloader.train_sampler.set_epoch(epoch)
        for data in dataloader:
            source_image = data["source_image"].to(device)
            target_image = data["target_image"].to(device)
            # Renamed from the original `targe_mask` typo (local-only; no interface change).
            target_mask = data["target_mask"].to(device)
            same = data["same"].to(device)
            loss_dict, visual_dict = model.optimize(source_image, target_image, target_mask, same)

            total_iter += 1

            if total_iter % opt.visualize_interval == 0 and visualizer is not None:
                visualizer.display_current_results(total_iter, visual_dict)

            if total_iter % opt.plot_interval == 0 and visualizer is not None:
                visualizer.plot_current_losses(total_iter, loss_dict)
                logger.info(f"Iter: {total_iter}")
                for k, v in loss_dict.items():
                    logger.info(f"  {k}: {v}")
                logger.info("=" * 20)

            if total_iter % opt.checkpoint_interval == 0 and ckpt:
                logger.info(f"Saving model at iter {total_iter}")
                model.save(opt.checkpoint_dir, total_iter)

            if total_iter > opt.max_iters:
                # Plain string: the original used an f-string with no placeholders.
                logger.info("Maximum iterations exceeded. Stopping training.")
                if ckpt:
                    model.save(opt.checkpoint_dir, total_iter)
                if use_ddp:
                    cleanup()
                sys.exit(0)
        epoch += 1
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    if use_ddp:
        # One OMP thread per worker avoids CPU oversubscription when several
        # DDP processes share a node. (Dropped the unused n_gpus local; both
        # branches previously made the identical train() call.)
        os.environ["OMP_NUM_THREADS"] = "1"
    train()
|
|
|