| import os | |
| import torch | |
def main(run, cfg):
    """Build the distributed train/validation loaders and run training.

    The dataset class is chosen from the mutually exclusive ``cfg`` flags,
    checked in this order: ``reproduce_dire``, ``only_eps``, ``only_img``;
    when none is set the distillation dataset is used.

    Args:
        run: Experiment-tracking handle forwarded unchanged to ``Trainer``
            (the script entry point passes a wandb run on local rank 0 and
            ``None`` on every other rank).
        cfg: Configuration object. The attributes read here are
            ``reproduce_dire``, ``only_eps``, ``only_img``, ``dataset_root``,
            ``dataset_test_root``, ``batch_size``, ``kd`` and
            ``pretrained_weights``; the whole object is also handed to
            ``Trainer``.

    Raises:
        KeyError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is missing from the
            environment (both are expected to be set by the distributed
            launcher).
    """
    # Imported at function scope so that ``main`` does not rely on names that
    # are only bound when this file is executed as a script.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from dataset import (
        TMDireDataset,
        TMDistilDireDataset,
        TMEPSOnlyDataset,
        TMIMGOnlyDataset,
    )
    from utils.trainer import Trainer

    # Same environment variables the script entry point reads; fetching them
    # here keeps the function usable when the module is imported.
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if cfg.reproduce_dire:
        dataset = TMDireDataset(cfg.dataset_root)
        val_dataset = TMDireDataset(cfg.dataset_test_root)
    elif cfg.only_eps:
        dataset = TMEPSOnlyDataset(cfg.dataset_root)
        # Validate on the held-out root like every other branch; this branch
        # previously reused cfg.dataset_root, i.e. validated on training data.
        val_dataset = TMEPSOnlyDataset(cfg.dataset_test_root)
    elif cfg.only_img:
        dataset = TMIMGOnlyDataset(cfg.dataset_root, istrain=True)
        val_dataset = TMIMGOnlyDataset(cfg.dataset_test_root, istrain=False)
    else:
        dataset = TMDistilDireDataset(cfg.dataset_root)
        val_dataset = TMDistilDireDataset(cfg.dataset_test_root)

    # Each process gets its own shard of both datasets.
    # NOTE(review): DistributedSampler shuffles by default, so the validation
    # order is shuffled too — confirm that is intended.
    sampler = DistributedSampler(dataset)
    val_sampler = DistributedSampler(val_dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        sampler=sampler,
        num_workers=2,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        sampler=val_sampler,
        num_workers=2,
    )

    # The literal ``True`` is a positional flag whose meaning is defined by
    # Trainer (presumably "distributed" — TODO confirm against utils.trainer).
    trainer = Trainer(cfg, dataloader, val_loader, run, local_rank, True, world_size, cfg.kd)
    if cfg.pretrained_weights:
        trainer.load_networks(cfg.pretrained_weights)
    trainer.train()
if __name__ == "__main__":
    # Script entry point for a distributed launch. LOCAL_RANK and WORLD_SIZE
    # are read from the environment and the process group is initialised via
    # the "env://" method.
    import torch.distributed as dist
    import wandb
    # These imports bind module-level names; keep them here because other code
    # in this module may reference them as globals.
    from torch.utils.data import DataLoader
    from dataset import (
        TMDistilDireDataset,
        TMDireDataset,
        TMEPSOnlyDataset,
        TMIMGOnlyDataset,
    )

    dist.init_process_group(backend="nccl", init_method="env://")

    # ``os`` comes from the module-level import at the top of the file.
    # ``local_rank`` / ``world_size`` are deliberately module-level names.
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Pin this process to its GPU, then wait for every rank to get this far.
    torch.cuda.set_device(local_rank)
    dist.barrier()

    # The configuration is imported only after the barrier, as in the
    # original statement order.
    from utils.config import cfg

    # Only local rank 0 opens a wandb run; every other rank passes None.
    run = (
        wandb.init(project="dire-distill-truemedia", config=cfg, dir=cfg.exp_dir)
        if local_rank == 0
        else None
    )

    main(run, cfg)