Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from tops.config import LazyCall as L | |
| import torch | |
| import functools | |
| from dp2.data.datasets.fdh import get_dataloader_fdh_wds | |
| from dp2.data.utils import get_coco_flipmap | |
| from dp2.data.transforms.transforms import ( | |
| Normalize, | |
| ToFloat, | |
| CreateCondition, | |
| RandomHorizontalFlip, | |
| CreateEmbedding, | |
| ) | |
| from dp2.metrics.torch_metrics import compute_metrics_iteratively | |
| from dp2.metrics.fid_clip import compute_fid_clip | |
| from dp2.metrics.ppl import calculate_ppl | |
| from .utils import train_eval_fn | |
def final_eval_fn(*args, upsample_size=(288, 160), **kwargs):
    """Run the final evaluation suite and return all metrics in a single dict.

    Three metric functions receive the same positional and keyword arguments:
    ``compute_metrics_iteratively``, ``calculate_ppl`` (which additionally
    receives ``upsample_size``) and ``compute_fid_clip``. Their result dicts are
    merged in that order.

    Args:
        *args: Forwarded to every metric function.
        upsample_size: Forwarded to ``calculate_ppl`` only. Defaults to
            ``(288, 160)``, the value previously hard-coded here. It matches
            ``data["imsize"]`` in this config. It is keyword-only, so existing
            callers are unaffected.
        **kwargs: Forwarded to every metric function.

    Returns:
        The dict returned by ``compute_metrics_iteratively``, updated in place
        with the PPL and FID-CLIP metrics.

    Raises:
        ValueError: If a later metric function reports a key that is already
            present. Merging would otherwise silently overwrite a metric.
    """
    result = compute_metrics_iteratively(*args, **kwargs)

    def merge(extra, source):
        # Checked explicitly (not with `assert`) so the guard also runs under `python -O`.
        clashing = [key for key in extra if key in result]
        if clashing:
            raise ValueError(
                f"{source} returned metric keys that are already present: {clashing}"
            )
        result.update(extra)

    # Merge each result as soon as it is available, so a collision fails
    # before the next metric is computed.
    merge(calculate_ppl(*args, **kwargs, upsample_size=upsample_size), "calculate_ppl")
    merge(compute_fid_clip(*args, **kwargs), "compute_fid_clip")
    return result
def get_cache_directory(imsize, subset):
    """Build the cache directory path for one evaluation setup.

    The directory sits under the module-level ``metrics_cache`` root. Its name
    is *subset* followed by the first element of *imsize*. For example,
    ``subset="fdh"`` with ``imsize=(288, 160)`` gives ``<metrics_cache>/fdh288``.
    Only ``imsize[0]`` contributes to the name.
    """
    leading_dim = imsize[0]
    folder_name = f"{subset}{leading_dim}"
    return Path(metrics_cache) / folder_name
# Root locations, each overridable through an environment variable:
#   BASE_DATASET_DIR  -> parent directory of the datasets (default: "data")
#   FBA_METRICS_CACHE -> root under which get_cache_directory() places the
#                        evaluation cache directories (default: ".cache")
dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
# The FDH dataset lives in the "fdh" subdirectory of the dataset root.
data_dir = Path(dataset_base_dir, "fdh")
# FDH dataset / evaluation configuration, built from LazyCall (`L`) nodes.
# Strings of the form "${...}" are stored verbatim here. Presumably they are
# interpolation references resolved later by the config loader; confirm
# against tops' LazyConfig.
data = {
    "imsize": (288, 160),
    "im_channels": 3,
    "cse_nc": 16,
    "n_keypoints": 17,
    "train": {
        "loader": L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            # CPU-side augmentation: random horizontal flip using the COCO flip map.
            transform=L(torch.nn.Sequential)(
                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
            ),
            # GPU-side preprocessing modules, applied in this order.
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
                # mean = std = 127.5 per channel. If Normalize computes
                # (x - mean) / std, this maps [0, 255] to [-1, 1]; verify.
                L(Normalize)(
                    mean=[0.5 * 255] * 3,
                    std=[0.5 * 255] * 3,
                    inplace=True,
                ),
                L(CreateCondition)(),
            ),
            infinite=True,
            shuffle=True,
            partial_batches=False,
            load_embedding=True,
            keypoints_split="train",
            load_new_keypoints=False,
        ),
    },
    "val": {
        "loader": L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=None,
            # Reuses the training GPU pipeline and keypoint setting by reference.
            gpu_transform="${data.train.loader.gpu_transform}",
            infinite=False,
            shuffle=False,
            partial_batches=True,
            load_embedding=True,
            keypoints_split="val",
            load_new_keypoints="${data.train.loader.load_new_keypoints}",
        ),
    },
    # The evaluation used during training may trade accuracy for lower
    # compute cost (for example, running under AMP).
    "train_evaluation_fn": L(functools.partial)(
        train_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
        data_len=30_000,
    ),
    "evaluation_fn": L(functools.partial)(
        final_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
        data_len=30_000,
    ),
}