|
|
""" |
|
|
# First update `train_config.py` to set paths to your dataset locations. |
|
|
|
|
|
# You may want to change `--num-workers` according to your machine's memory. |
|
|
# The default num-workers=8 may cause the dataloader to exit unexpectedly when
|
|
# the machine is out of memory.
|
|
|
|
|
# Stage 1 |
|
|
python train.py \ |
|
|
--model-variant mobilenetv3 \ |
|
|
--dataset videomatte \ |
|
|
--resolution-lr 512 \ |
|
|
--seq-length-lr 15 \ |
|
|
--learning-rate-backbone 0.0001 \ |
|
|
--learning-rate-aspp 0.0002 \ |
|
|
--learning-rate-decoder 0.0002 \ |
|
|
--learning-rate-refiner 0 \ |
|
|
--checkpoint-dir checkpoint/stage1 \ |
|
|
--log-dir log/stage1 \ |
|
|
--epoch-start 0 \ |
|
|
--epoch-end 20 |
|
|
|
|
|
# Stage 2 |
|
|
python train.py \ |
|
|
--model-variant mobilenetv3 \ |
|
|
--dataset videomatte \ |
|
|
--resolution-lr 512 \ |
|
|
--seq-length-lr 50 \ |
|
|
--learning-rate-backbone 0.00005 \ |
|
|
--learning-rate-aspp 0.0001 \ |
|
|
--learning-rate-decoder 0.0001 \ |
|
|
--learning-rate-refiner 0 \ |
|
|
--checkpoint checkpoint/stage1/epoch-19.pth \ |
|
|
--checkpoint-dir checkpoint/stage2 \ |
|
|
--log-dir log/stage2 \ |
|
|
--epoch-start 20 \ |
|
|
--epoch-end 22 |
|
|
|
|
|
# Stage 3 |
|
|
python train.py \ |
|
|
--model-variant mobilenetv3 \ |
|
|
--dataset videomatte \ |
|
|
--train-hr \ |
|
|
--resolution-lr 512 \ |
|
|
--resolution-hr 2048 \ |
|
|
--seq-length-lr 40 \ |
|
|
--seq-length-hr 6 \ |
|
|
--learning-rate-backbone 0.00001 \ |
|
|
--learning-rate-aspp 0.00001 \ |
|
|
--learning-rate-decoder 0.00001 \ |
|
|
--learning-rate-refiner 0.0002 \ |
|
|
--checkpoint checkpoint/stage2/epoch-21.pth \ |
|
|
--checkpoint-dir checkpoint/stage3 \ |
|
|
--log-dir log/stage3 \ |
|
|
--epoch-start 22 \ |
|
|
--epoch-end 23 |
|
|
|
|
|
# Stage 4 |
|
|
python train.py \ |
|
|
--model-variant mobilenetv3 \ |
|
|
--dataset imagematte \ |
|
|
--train-hr \ |
|
|
--resolution-lr 512 \ |
|
|
--resolution-hr 2048 \ |
|
|
--seq-length-lr 40 \ |
|
|
--seq-length-hr 6 \ |
|
|
--learning-rate-backbone 0.00001 \ |
|
|
--learning-rate-aspp 0.00001 \ |
|
|
--learning-rate-decoder 0.00005 \ |
|
|
--learning-rate-refiner 0.0002 \ |
|
|
--checkpoint checkpoint/stage3/epoch-22.pth \ |
|
|
--checkpoint-dir checkpoint/stage4 \ |
|
|
--log-dir log/stage4 \ |
|
|
--epoch-start 23 \ |
|
|
--epoch-end 28 |
|
|
""" |
|
|
|
|
|
|
|
|
import argparse |
|
|
import torch |
|
|
import random |
|
|
import os |
|
|
from torch import nn |
|
|
from torch import distributed as dist |
|
|
from torch import multiprocessing as mp |
|
|
from torch.nn import functional as F |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.optim import Adam |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from torch.utils.data import DataLoader, ConcatDataset |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torchvision.utils import make_grid |
|
|
from torchvision.transforms.functional import center_crop |
|
|
from tqdm import tqdm |
|
|
|
|
|
from dataset.videomatte import ( |
|
|
VideoMatteDataset, |
|
|
VideoMatteTrainAugmentation, |
|
|
VideoMatteValidAugmentation, |
|
|
) |
|
|
from dataset.imagematte import ( |
|
|
ImageMatteDataset, |
|
|
ImageMatteAugmentation |
|
|
) |
|
|
from dataset.coco import ( |
|
|
CocoPanopticDataset, |
|
|
CocoPanopticTrainAugmentation, |
|
|
) |
|
|
from dataset.spd import ( |
|
|
SuperviselyPersonDataset |
|
|
) |
|
|
from dataset.youtubevis import ( |
|
|
YouTubeVISDataset, |
|
|
YouTubeVISAugmentation |
|
|
) |
|
|
from dataset.augmentation import ( |
|
|
TrainFrameSampler, |
|
|
ValidFrameSampler |
|
|
) |
|
|
from model import MattingNetwork |
|
|
from train_config import DATA_PATHS |
|
|
from train_loss import matting_loss, segmentation_loss |
|
|
|
|
|
|
|
|
class Trainer:
    """Per-GPU training worker for the matting network.

    One ``Trainer`` is constructed in each process spawned by
    ``torch.multiprocessing`` (see the ``__main__`` guard). Construction runs
    the entire lifecycle in order: argument parsing, distributed init,
    dataset / model / writer setup, the training loop, and cleanup.
    """

    def __init__(self, rank, world_size):
        # Constructing the object runs the whole training session.
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        """Parse command-line arguments into ``self.args``."""
        parser = argparse.ArgumentParser()

        # Model and matting dataset selection.
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])

        # Per-module learning rates (a module can be frozen by passing 0).
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)

        # Resolution / sequence-length / batching settings.
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)

        # TensorBoard logging.
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)

        # Checkpoint restore / save.
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)

        # Distributed rendezvous address.
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')

        # Feature toggles.
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

    def init_distributed(self, rank, world_size):
        """Join the NCCL process group; ``rank`` doubles as the CUDA device index."""
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def init_datasets(self):
        """Build the matting (LR/HR/valid) and segmentation (image/video) data pipelines."""
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets. HR training and validation reuse the HR settings
        # only when --train-hr is given; otherwise everything stays at LR.
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders. Each rank gets its own shard via DistributedSampler.
        self.datasampler_lr_train = DistributedSampler(
            dataset=self.dataset_lr_train,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_lr_train,
            pin_memory=True)
        if self.args.train_hr:
            self.datasampler_hr_train = DistributedSampler(
                dataset=self.dataset_hr_train,
                rank=self.rank,
                num_replicas=self.world_size,
                shuffle=True)
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                sampler=self.datasampler_hr_train,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Image segmentation dataset: COCO-Panoptic + Supervisely Person.
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        self.datasampler_seg_image = DistributedSampler(
            dataset=self.dataset_seg_image,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        # Still images are batched as batch * seq_length so that, once
        # unsqueezed to a length-1 sequence, the sample count matches video batches.
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_image,
            pin_memory=True)

        # Video segmentation dataset: YouTubeVIS.
        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video = DistributedSampler(
            dataset=self.dataset_seg_video,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_video,
            pin_memory=True)

    def init_model(self):
        """Create the model, optionally restore a checkpoint, and wrap it in DDP."""
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            # load_state_dict returns missing/unexpected keys; log them for visibility.
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))

        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        # find_unused_parameters is required because the refiner is skipped on
        # low-resolution passes and the seg/mat projection heads alternate.
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
        # Keep the scaler in lockstep with autocast: when mixed precision is
        # disabled the scaler becomes a no-op pass-through.
        self.scaler = GradScaler(enabled=not self.args.disable_mixed_precision)

    def init_writer(self):
        """Create the TensorBoard writer on rank 0 only."""
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

    def train(self):
        """Main loop: per epoch, validate then interleave matting and segmentation steps."""
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
            self.step = epoch * len(self.dataloader_lr_train)

            # Without set_epoch, DistributedSampler produces the identical
            # shuffle every epoch.
            self.datasampler_lr_train.set_epoch(epoch)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass (no downsampling inside the model).
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass.
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass, alternating video / still-image sources.
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    # Promote stills to length-1 sequences for the recurrent model.
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1

    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        """Run one matting optimization step and log losses/images on rank 0.

        Args:
            true_fgr, true_pha, true_bgr: ground-truth foreground, alpha and
                background batches of shape (B, T, C, H, W).
            downsample_ratio: internal downsample factor passed to the model
                (1 for LR training, < 1 for HR training).
            tag: 'lr' or 'hr'; used as the TensorBoard name prefix.
        """
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        # Composite the training input from fgr/pha/bgr.
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)

    def train_seg(self, true_img, true_seg, log_label):
        """Run one segmentation optimization step and log on rank 0.

        Args:
            true_img: input image sequences (B, T, C, H, W).
            true_seg: binary segmentation targets (B, T, 1, H, W).
            log_label: TensorBoard name prefix ('seg_video' or 'seg_image').
        """
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)

        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        # Seg steps run on alternating parities, so round the step down to the
        # even boundary before interval checks to keep logging cadence stable.
        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)

    def _load_next_sample(self, iterator_attr, datasampler, dataloader):
        """Pull the next sample from a persistent iterator stored on ``self``.

        Restarts the iterator — with a bumped sampler epoch so the reshuffle
        differs — when it does not exist yet (first call) or is exhausted.
        """
        try:
            sample = next(getattr(self, iterator_attr))
        except (AttributeError, StopIteration):
            # AttributeError: iterator not created yet (first call).
            # StopIteration: the loader's current pass is exhausted.
            # Anything else (e.g. a worker crash) propagates instead of being
            # silently swallowed by a bare except.
            datasampler.set_epoch(datasampler.epoch + 1)
            setattr(self, iterator_attr, iter(dataloader))
            sample = next(getattr(self, iterator_attr))
        return sample

    def load_next_mat_hr_sample(self):
        """Next high-resolution matting sample (fgr, pha, bgr)."""
        return self._load_next_sample(
            'dataiterator_mat_hr', self.datasampler_hr_train, self.dataloader_hr_train)

    def load_next_seg_video_sample(self):
        """Next video segmentation sample (img, seg)."""
        return self._load_next_sample(
            'dataiterator_seg_video', self.datasampler_seg_video, self.dataloader_seg_video)

    def load_next_seg_image_sample(self):
        """Next still-image segmentation sample (img, seg)."""
        return self._load_next_sample(
            'dataiterator_seg_image', self.datasampler_seg_image, self.dataloader_seg_image)

    def validate(self):
        """Evaluate average matting loss on the validation set (rank 0 only)."""
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        # Use the bare (non-DDP) module for inference.
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
            self.log(f'Validation set average loss: {avg_loss}')
            self.writer.add_scalar('valid_loss', avg_loss, self.step)
            self.model_ddp.train()
        # All ranks wait for rank 0 to finish before training resumes.
        dist.barrier()

    def random_crop(self, *imgs):
        """Apply one shared random resize+crop to every tensor in ``imgs``.

        A single (h, w) is drawn so foreground/alpha/background stay aligned.
        Tensors are (B, T, C, H, W); the crop is applied framewise.
        """
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results

    def save(self):
        """Save the (unwrapped) model weights on rank 0; all ranks synchronize."""
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        """Tear down the distributed process group."""
        dist.destroy_process_group()

    def log(self, msg):
        """Print a message prefixed with this process's GPU rank."""
        print(f'[GPU{self.rank}] {msg}')
|
|
|
|
|
if __name__ == '__main__':
    # Launch one Trainer process per visible CUDA device; each process
    # receives (rank, world_size) and drives one GPU.
    num_gpus = torch.cuda.device_count()
    mp.spawn(Trainer, args=(num_gpus,), nprocs=num_gpus, join=True)
|
|
|