Spaces:
Runtime error
Runtime error
def _parse_scale(spec):
    """Parse a ``(min, max)`` crop-scale spec into a pair of floats.

    Accepts the string form stored in the config, e.g. ``"(0.67,1.0)"`` or
    ``"0.67, 1.0"`` (surrounding parentheses/brackets and whitespace are
    optional), or an already-parsed two-element sequence such as
    ``[0.67, 1.0]``.

    Raises:
        ValueError: if *spec* does not hold exactly two numbers.
    """
    if isinstance(spec, str):
        parts = spec.strip().strip("()[]").split(",")
    else:
        parts = list(spec)
    if len(parts) != 2:
        raise ValueError(f"expected a (min, max) scale pair, got {spec!r}")
    return float(parts[0]), float(parts[1])


def build_loader_simmim(config):
    """Build the distributed multi-dataset DataLoader for SimMIM pre-training.

    One ``CachedImageFolder`` is created per entry of
    ``config.DATA.TYPE_PATH[0]`` (a mapping of model name to dataset path).
    Each dataset gets its own ``SimMIMTransform``, built from the matching
    entries of ``config.DATA.NORM[0]`` and ``config.DATA.SCALE[0]``. The
    datasets are combined into a ``MultiTaskDataset`` and batched by a
    ``DistrubutedMultiTaskBatchSampler`` that is given this process's rank
    and the world size.

    (An earlier single-dataset variant used one ImageFolder with a
    DistributedSampler; this function replaces it with the multi-dataset
    pipeline.)

    Args:
        config: experiment config. Reads ``DATA.TYPE_PATH``, ``DATA.SCALE``,
            ``DATA.NORM``, ``DATA.BATCH_SIZE`` and ``DATA.NUM_WORKERS``.

    Returns:
        A ``DataLoader`` over the combined dataset.

    Raises:
        ValueError: if ``DATA.TYPE_PATH[0]`` is empty, or a ``DATA.SCALE``
            entry is not a ``(min, max)`` pair.

    NOTE(review): calls ``dist.get_world_size()`` / ``dist.get_rank()``, so a
    distributed process group presumably must already be initialized.
    """
    model_paths = config.DATA.TYPE_PATH[0]
    if not model_paths:
        raise ValueError("config.DATA.TYPE_PATH[0] lists no datasets")
    scales = config.DATA.SCALE[0]
    norms = config.DATA.NORM[0]

    # Data augmentation: every dataset gets its own normalisation and crop scale.
    datasets = []
    for name, path in model_paths.items():
        transform = SimMIMTransform(config, norms[name], _parse_scale(scales[name]))
        datasets.append(CachedImageFolder(path, transform=transform, model=name))
    multi_task_train_dataset = MultiTaskDataset(datasets)

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(
        datasets,
        batch_size=config.DATA.BATCH_SIZE,
        num_replicas=world_size,
        rank=rank,
        mix_opt=0,
        extra_task_ratio=0,
        drop_last=True,
        shuffle=True,
    )
    # The batch sampler handles batching, shuffling and drop_last, so the
    # DataLoader must not also receive batch_size/shuffle/sampler/drop_last.
    dataloader = DataLoader(
        multi_task_train_dataset,
        batch_sampler=multi_task_batch_sampler,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    if rank == 0:
        # Report once rather than once per process.
        print(f"build_loader_simmim: {len(datasets)} datasets, {len(dataloader)} batches")
    return dataloader