# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import json
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage

try:
    # pycocotools is optional; it is only needed by the loaders that decode
    # COCO RLE masks (JSONSegmentLoader, LazySegments).
    from pycocotools import mask as mask_utils
except ImportError:
    # Narrowed from a bare `except:` — a bare except also swallows
    # KeyboardInterrupt/SystemExit. Loaders that need mask_utils will raise
    # NameError at use time if pycocotools is absent, same as before.
    pass
class JSONSegmentLoader:
    """Loads per-frame RLE mask annotations stored in a single JSON file.

    The JSON payload is either a bare list of per-frame annotations, or a
    dict carrying the annotations under a "masklet"/"masks" key (and,
    optionally, an "fps" field used to derive the annotation stride).
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        # Annotations are provided every `ann_every`-th video frame.
        self.ann_every = ann_every
        # Object ids that may be sampled from this video (None = all).
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as fp:
            payload = json.load(fp)
        if isinstance(payload, list):
            self.frame_annots = payload
        elif isinstance(payload, dict):
            key = "masklet" if "masklet" in payload else "masks"
            self.frame_annots = payload[key]
            if "fps" in payload:
                fps_field = payload["fps"]
                annotations_fps = int(
                    fps_field[0] if isinstance(fps_field, list) else fps_field
                )
                # The video frame rate must be a multiple of the annotation rate.
                assert frames_fps % annotations_fps == 0
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Return {obj_id: bool mask tensor or None} for the given frame.

        Args:
            frame_id: frame index in the video; must land on an annotated frame.
            obj_ids: optional iterable restricting which objects to decode.
        """
        assert frame_id % self.ann_every == 0
        frame_rles = self.frame_annots[frame_id // self.ann_every]
        keep_ids = set(range(len(frame_rles)))
        if self.valid_obj_ids is not None:
            # Drop the masklets that have been filtered out for this video.
            keep_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Keep only the objects that have been sampled.
            keep_ids &= set(obj_ids)
        keep_ids = sorted(keep_ids)
        # Gather only the RLEs we actually need to decode; remember each
        # object's slot in the decoded batch (None = no mask this frame).
        id_to_slot = {}
        rles_to_decode = []
        for oid in keep_ids:
            rle = frame_rles[oid]
            if rle is None:
                id_to_slot[oid] = None
            else:
                id_to_slot[oid] = len(rles_to_decode)
                rles_to_decode.append(rle)
        # Decode all kept masks in one call; result is (num_obj, h, w).
        decoded = torch.from_numpy(mask_utils.decode(rles_to_decode)).permute(2, 0, 1)
        segments = {}
        for oid in keep_ids:
            slot = id_to_slot[oid]
            segments[oid] = None if slot is None else decoded[slot]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """Map each object id to the frame ids carrying a non-None mask.

        Objects with fewer than `num_frames_min` valid masks are dropped
        from the result when that threshold is given.
        """
        num_objects = len(self.frame_annots[0])
        res = {obj_id: [] for obj_id in range(num_objects)}
        for annot_idx, annot in enumerate(self.frame_annots):
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(int(annot_idx * self.ann_every))
        if num_frames_min is not None:
            # Remove masklets that have too few valid masks.
            res = {
                obj_id: frames
                for obj_id, frames in res.items()
                if len(frames) >= num_frames_min
            }
        return res
class PalettisedPNGSegmentLoader:
    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.

        Args:
            video_png_root: folder containing all the masks stored as PNGs,
                one file per frame, named by its (zero-padded) frame id.
        """
        self.video_png_root = video_png_root
        # Build a mapping from frame id to its PNG mask filename. Note that
        # in some datasets the PNG names may use more than 5 digits
        # (e.g. "00000000.png" instead of "00000.png"), so we key on the
        # parsed integer rather than on the padded string.
        png_filenames = os.listdir(self.video_png_root)
        self.frame_id_to_png_filename = {}
        for filename in png_filenames:
            frame_id, _ = os.path.splitext(filename)
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """
        Load the single palettised mask for `frame_id` from disk.

        Args:
            frame_id: int, identifies which mask file to read
        Return:
            binary_segments: dict mapping object id -> bool mask tensor;
                palette value 0 is treated as background and skipped.
        Raises:
            KeyError: if no PNG exists for `frame_id`.
        """
        # Resolve the path via the filename map built in __init__.
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )
        # Load the palettised mask; each pixel value is an object id.
        masks = PILImage.open(mask_path).convert("P")
        masks = np.array(masks)
        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]  # remove background (0)
        # Convert into N binary segmentation masks.
        binary_segments = {}
        for i in object_id:
            bs = masks == i
            binary_segments[i] = torch.from_numpy(bs)
        return binary_segments

    def __len__(self):
        # Number of frames that have a mask on disk.
        # Fix: previously this was a bare `return` (None), which made
        # len(loader) raise TypeError.
        return len(self.frame_id_to_png_filename)
class MultiplePNGSegmentLoader:
    def __init__(self, video_png_root, single_object_mode=False):
        """
        SegmentLoader for masks stored as one PNG per frame (and, in
        multi-object mode, one sub-folder per object).

        Args:
            video_png_root: the folder containing all the masks stored in png.
                In single-object mode the PNGs live directly in this folder
                and the folder name is the object id; otherwise it holds one
                sub-folder per object id.
            single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # Read one mask to learn the video resolution; used to synthesize
        # empty masks for frames whose PNG is missing.
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        tmp_mask = np.array(PILImage.open(tmp_mask_path))
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]
        if self.single_object_mode:
            # The folder name is the object id; offset by 1 as bg is 0.
            # Fix: use basename/normpath instead of split("/") so this also
            # works with Windows separators and trailing slashes.
            self.obj_id = int(os.path.basename(os.path.normpath(video_png_root))) + 1
        else:
            self.obj_id = None

    def load(self, frame_id):
        """Dispatch to the single- or multi-object loader for `frame_id`."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        Load a single png mask from disk (path: f'{self.obj_id}/{frame_id:05d}.png').

        Args:
            frame_id: int, defines the mask path
        Return:
            binary_segments: dict mapping self.obj_id -> bool mask tensor (H, W)
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        binary_segments = {}
        if os.path.exists(mask_path):
            mask = np.array(PILImage.open(mask_path))
        else:
            # If the png doesn't exist, use an empty (all-background) mask.
            mask = np.zeros((self.H, self.W), dtype=bool)
        binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def _load_multiple_pngs(self, frame_id):
        """
        Load one png mask per object from disk (path: f'{obj_id}/{frame_id:05d}.png').

        Args:
            frame_id: int, defines the mask path
        Return:
            binary_segments: dict mapping obj_id -> bool mask tensor (H, W)
        """
        # Each object's masks live in {video_png_root}/{obj_id}/.
        all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
        num_objects = len(all_objects)
        assert num_objects > 0
        # Load the masks.
        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}; the folder name is the id.
            # Fix: basename/normpath instead of split("/") for portability.
            obj_id = int(os.path.basename(os.path.normpath(obj_folder)))
            obj_id = obj_id + 1  # offset 1 as bg is 0
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            if os.path.exists(mask_path):
                mask = np.array(PILImage.open(mask_path))
            else:
                mask = np.zeros((self.H, self.W), dtype=bool)
            binary_segments[obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def __len__(self):
        # NOTE(review): bare `return` yields None, so len(loader) raises
        # TypeError. No frame count is stored on this class, so there is no
        # obviously-correct value to return — left unchanged; confirm intent.
        return
class LazySegments:
    """
    Dict-like container of RLE segments that decodes a mask only when it is
    first accessed, caching the decoded tensor for subsequent lookups.
    """

    def __init__(self):
        self.segments = {}  # key -> raw RLE
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item

    def __getitem__(self, key):
        # Decode on first access only; later lookups hit the cache.
        if key not in self.cache:
            rle = self.segments[key]
            decoded = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)
            self.cache[key] = decoded[0]
        return self.cache[key]

    def __contains__(self, key):
        return key in self.segments

    def __len__(self):
        return len(self.segments)

    def keys(self):
        return self.segments.keys()
class SA1BSegmentLoader:
    """Loads SA-1B style mask annotations for one image, filtering them by
    area, stability score, and (optionally) area fraction of the image."""

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        with open(video_mask_path, "r") as fp:
            self.frame_annots = json.load(fp)

        if mask_area_frac_thresh <= 1.0:
            # Only read the frame (lazily) when area-fraction filtering is on.
            orig_w, orig_h = PILImage.open(video_frame_path).size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        def _keep(annot):
            # Drop degenerate (non-positive area) masks.
            if not annot["area"] > 0:
                return False
            # "uncertain_iou" is a stability score; drop unstable masks.
            if "uncertain_iou" in annot and annot["uncertain_iou"] < uncertain_iou:
                return False
            # Drop masks covering too large a fraction of the image
            # (checked only when area filtering is enabled).
            if (
                mask_area_frac_thresh <= 1.0
                and annot["area"] / area >= mask_area_frac_thresh
            ):
                return False
            return True

        # Keep the surviving RLEs, indexed densely from 0, decoded lazily.
        self.segments = LazySegments()
        for idx, annot in enumerate(filter(_keep, self.frame_annots)):
            self.segments[idx] = annot["segmentation"]

    def load(self, frame_idx):
        # All annotations belong to a single image, so frame_idx is ignored.
        return self.segments