Spaces:
Runtime error
Runtime error
| import torch | |
| import tops | |
| import numpy as np | |
| import io | |
| import webdataset as wds | |
| import os | |
| import json | |
| from pathlib import Path | |
| from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn | |
def kp_decoder(x):
    """Decode a serialized .npy keypoint array from raw bytes.

    Keypoints are stored normalized to [0, 1] in the webdataset. The third
    column is a visibility flag; it is forced to 0 for any keypoint whose
    x or y coordinate falls outside [0, 1].
    """
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    coords = keypoints[:, :2]
    # A keypoint is kept visible only if both coordinates lie in [0, 1].
    inside = ((coords >= 0) & (coords <= 1)).all(dim=1)
    keypoints[:, 2] = (keypoints[:, 2] > 0) & inside
    return keypoints
def vertices_decoder(x):
    """Decode a serialized .npy vertex array into an int32 tensor.

    Singleton dimensions are squeezed away, then a leading batch
    dimension of size 1 is added back.
    """
    arr = np.load(io.BytesIO(x)).astype(np.int32)
    vertices = torch.from_numpy(arr).squeeze()
    return vertices.unsqueeze(0)
class InsertNewKeypoints:
    """Webdataset map-stage that swaps in keypoints loaded from a JSON file.

    The JSON file maps each sample ``__key__`` to a keypoint list; on call,
    the sample's ``keypoints.npy`` entry is replaced with those values,
    with visibility zeroed for keypoints outside the [0, 1] range.
    """

    def __init__(self, keypoints_path: Path) -> None:
        with open(keypoints_path, "r") as fp:
            self.keypoints = json.load(fp)

    def __call__(self, sample):
        kp = torch.tensor(self.keypoints[sample["__key__"]], dtype=torch.float32)
        coords = kp[:, :2]
        # Visible only when the original flag is set AND both coords are in [0, 1].
        inside = ((coords >= 0) & (coords <= 1)).all(dim=1)
        kp[:, 2] = (kp[:, 2] > 0) & inside
        sample["keypoints.npy"] = kp
        return sample
def get_dataloader_fdh_wds(
    path,
    batch_size: int,
    num_workers: int,
    transform: torch.nn.Module,
    gpu_transform: torch.nn.Module,
    infinite: bool,
    shuffle: bool,
    partial_batches: bool,
    load_embedding: bool,
    sample_shuffle=10_000,
    tar_shuffle=100,
    read_condition=False,
    channels_last=False,
    load_new_keypoints=False,
    keypoints_split=None,
):
    """Build a WebDataset pipeline + prefetching loader for the FDH dataset.

    Args:
        path: Shard path/pattern; converted with ``str()`` before use.
        batch_size: Samples per batch (batching happens inside the wds pipeline,
            not in the WebLoader).
        num_workers: Requested worker count; mapped through ``get_num_workers``.
        transform: Per-batch CPU transform appended to the pipeline, or None.
        gpu_transform: Transform handed to ``tops.DataPrefetcher``.
        infinite: If True, resample shards and repeat the pipeline for up to
            1e6 epochs; otherwise iterate the shard list once.
        shuffle: If True, shuffle shards (buffer ``tar_shuffle``) and samples
            (buffer ``sample_shuffle``).
        partial_batches: If True, emit the final smaller-than-batch_size batch.
        load_embedding: If True, additionally decode ``vertices.npy`` and
            ``E_mask.png`` entries.
        sample_shuffle: Sample-level shuffle buffer size.
        tar_shuffle: Shard-level shuffle buffer size.
        read_condition: If True, additionally decode ``condition.png``.
        channels_last: Forwarded to ``tops.DataPrefetcher``.
        load_new_keypoints: If True, download a JSON file of replacement
            keypoints and overwrite each sample's keypoints with them.
        keypoints_split: Required when ``load_new_keypoints``; "train" or "val".

    Returns:
        A ``tops.DataPrefetcher`` wrapping the ``wds.WebLoader``.
    """
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    # Infinite (training) streams resample shards with replacement; finite
    # streams (e.g. validation) enumerate the shard list exactly once.
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        # Shard-level shuffle must precede the node/worker split below.
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # Sample-level shuffle buffer, applied per worker.
        pipeline.append(wds.shuffle(sample_shuffle))
    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]
    # Pairs of [output key, decoded-sample key] for wds.rename_keys below.
    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
        ["__key__", "__key__"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        # NOTE(review): rename source is lowercase "e_mask.png" while the
        # decoder extension is "E_mask.png" — presumably webdataset lowercases
        # sample keys; confirm against the wds version in use.
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])
    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])
    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])
    if load_new_keypoints:
        assert keypoints_split in ["train", "val"]
        # Externally hosted replacement keypoint annotations; file names embed
        # a content hash. Defaults to the "val" split unless "train" is asked for.
        keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
        file_name = "fdh_keypoints_val-050133b34d.json"
        if keypoints_split == "train":
            keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
            file_name = "fdh_keypoints_train-2cff11f69a.json"
        # Set check_hash=True if you suspect download is incorrect.
        filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
        pipeline.append(
            wds.map(InsertNewKeypoints(filepath))
        )
    # Batch with the project collate_fn first, then rename keys on the batch.
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    # batch_size=None / shuffle=False: both are handled inside the wds pipeline.
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader