|
|
|
|
|
|
|
|
|
|
|
|
| import glob
|
| import json
|
| import os
|
|
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
|
|
| from PIL import Image as PILImage
|
|
|
| try:
|
| from pycocotools import mask as mask_utils
|
| except:
|
| pass
|
|
|
|
|
class JSONSegmentLoader:
    """Load per-frame object masks from a JSON file of RLE annotations.

    The JSON payload is either a list with one entry per annotated frame, or
    a dict holding that list under ``"masklet"`` (falling back to ``"masks"``)
    and optionally an ``"fps"`` entry. Each per-frame entry is indexed by
    object id and holds either ``None`` or an RLE that is handed to
    ``mask_utils.decode``.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file
            ann_every: number of video frames between two consecutive
                annotations; replaced by ``frames_fps // fps`` when the JSON
                dict provides an ``"fps"`` entry
            frames_fps: frame rate of the video frames
            valid_obj_ids: optional collection of object ids to keep;
                ``None`` keeps every object
        Raises:
            NotImplementedError: the top-level JSON value is neither a list
                nor a dict
            AssertionError: ``frames_fps`` is not a multiple of the
                annotation fps
        """
        self.ann_every = ann_every
        self.valid_obj_ids = valid_obj_ids

        with open(video_json_path, "r") as f:
            data = json.load(f)

        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            masklet_field_name = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[masklet_field_name]
            if "fps" in data:
                fps = data["fps"]
                # "fps" may be stored as a list; only its first entry is used.
                annotations_fps = int(fps[0] if isinstance(fps, list) else fps)
                assert frames_fps % annotations_fps == 0, (
                    f"frames_fps ({frames_fps}) must be a multiple of the "
                    f"annotation fps ({annotations_fps})"
                )
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError(
                f"unsupported top-level JSON type: {type(data).__name__}"
            )

    def load(self, frame_id, obj_ids=None):
        """
        Decode the masks annotated on one frame.

        Args:
            frame_id: int, frame index; must be a multiple of ``ann_every``
            obj_ids: optional collection of object ids to restrict the
                result to
        Return:
            segments: dict mapping each selected object id to its decoded
                mask tensor, or to None when the object has no annotation
                on this frame
        """
        assert frame_id % self.ann_every == 0, (
            f"frame {frame_id} is not annotated (ann_every={self.ann_every})"
        )
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        # Start from every object slot of this frame, then narrow it down by
        # the loader-level filter and by the per-call filter.
        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Gather the RLEs to decode in one batch and remember, per object id,
        # its position in that batch (None when there is nothing to decode).
        id_2_idx = {}
        rle_mask_filtered = []
        for obj_id in valid_objs_ids:
            rle = rle_mask[obj_id]
            if rle is None:
                id_2_idx[obj_id] = None
            else:
                id_2_idx[obj_id] = len(rle_mask_filtered)
                rle_mask_filtered.append(rle)

        if not rle_mask_filtered:
            # Nothing to decode: do not hand an empty batch to the decoder.
            return {obj_id: None for obj_id in valid_objs_ids}

        # Move the last axis of the decoded array to the front so that
        # raw_segments[idx] is the mask of the idx-th RLE in the batch.
        raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
            2, 0, 1
        )
        return {
            obj_id: None if idx is None else raw_segments[idx]
            for obj_id, idx in id_2_idx.items()
        }

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """
        List, for each object, the frames on which it has an annotation.

        Args:
            num_frames_min: optional int; objects annotated on fewer frames
                than this are dropped from the result
        Return:
            res: dict mapping object id to the list of frame ids
                (annotation index * ``ann_every``) whose entry is not None;
                empty when there are no annotations at all
        """
        if not self.frame_annots:
            return {}

        # The number of object slots is taken from the first annotated frame.
        num_objects = len(self.frame_annots[0])
        res = {obj_id: [] for obj_id in range(num_objects)}

        for annot_idx, annot in enumerate(self.frame_annots):
            frame_id = int(annot_idx * self.ann_every)
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(frame_id)

        if num_frames_min is not None:
            res = {
                obj_id: valid_frames
                for obj_id, valid_frames in res.items()
                if len(valid_frames) >= num_frames_min
            }

        return res
|
|
|
|
|
class PalettisedPNGSegmentLoader:
    """SegmentLoader for datasets with masks stored as palettised PNGs."""

    def __init__(self, video_png_root):
        """
        Args:
            video_png_root: the folder contains all the masks stored in png
        """
        self.video_png_root = video_png_root

        # Map each integer frame id to its file name. The id is parsed from
        # the file stem, so the number of digits in the name does not matter.
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            stem, _ = os.path.splitext(filename)
            try:
                frame_id = int(stem)
            except ValueError:
                # Stray entry whose name is not a frame number: ignore it
                # instead of failing the whole video.
                continue
            self.frame_id_to_png_filename[frame_id] = filename

    def load(self, frame_id):
        """
        load the single palettised mask from the disk

        Args:
            frame_id: int, selects the mask file indexed under this id
        Return:
            binary_segments: dict mapping every non-zero palette index
                present in the mask to a boolean tensor that is True where
                the mask equals that index
        Raises:
            KeyError: no file was indexed for ``frame_id``
        """
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        with PILImage.open(mask_path) as img:
            masks = np.array(img.convert("P"))

        # Palette index 0 is excluded from the returned objects.
        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]

        binary_segments = {}
        for i in object_id:
            binary_segments[i] = torch.from_numpy(masks == i)

        return binary_segments

    def __len__(self):
        """Return the number of frames indexed from ``video_png_root``."""
        return len(self.frame_id_to_png_filename)
|
|
|
|
|
class MultiplePNGSegmentLoader:
    """SegmentLoader for masks stored as one PNG per object and per frame."""

    def __init__(self, video_png_root, single_object_mode=False):
        """
        video_png_root: the folder contains all the masks stored in png
        single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode

        # Read one mask to learn the resolution used for missing frames.
        tmp_mask_path = glob.glob(self._png_pattern())[0]
        with PILImage.open(tmp_mask_path) as img:
            tmp_mask = np.array(img)
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]

        if self.single_object_mode:
            self.obj_id = self._obj_id_from_folder(video_png_root)
        else:
            self.obj_id = None

    def _png_pattern(self):
        """Return the glob pattern matching every mask PNG of this loader."""
        if self.single_object_mode:
            return os.path.join(self.video_png_root, "*.png")
        return os.path.join(self.video_png_root, "*", "*.png")

    @staticmethod
    def _obj_id_from_folder(folder):
        """Return the object id for a folder: its integer name plus one.

        Uses basename/normpath rather than splitting on "/" so that a
        trailing separator or a platform-specific separator does not break
        the parse. Raises ValueError when the name is not an integer.
        """
        return int(os.path.basename(os.path.normpath(folder))) + 1

    def _read_binary_mask(self, mask_path):
        """Return the mask at ``mask_path`` as a boolean tensor.

        A missing file yields an all-False mask of shape (H, W).
        """
        if os.path.exists(mask_path):
            with PILImage.open(mask_path) as img:
                mask = np.array(img)
        else:
            mask = np.zeros((self.H, self.W), dtype=bool)
        return torch.from_numpy(mask > 0)

    def load(self, frame_id):
        """Load the masks of ``frame_id`` according to the configured mode."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict with the single key ``self.obj_id``
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        return {self.obj_id: self._read_binary_mask(mask_path)}

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping object id (folder name + 1) to a
                boolean mask tensor
        """
        obj_folders = []
        for obj_folder in sorted(glob.glob(os.path.join(self.video_png_root, "*"))):
            try:
                obj_id = self._obj_id_from_folder(obj_folder)
            except ValueError:
                # Entry whose name is not an object index: not an object folder.
                continue
            obj_folders.append((obj_id, obj_folder))
        assert len(obj_folders) > 0, f"no object folders in {self.video_png_root}"

        binary_segments = {}
        for obj_id, obj_folder in obj_folders:
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            binary_segments[obj_id] = self._read_binary_mask(mask_path)

        return binary_segments

    def __len__(self):
        """Return the number of distinct frame PNG file names found on disk."""
        return len({os.path.basename(path) for path in glob.glob(self._png_pattern())})
|
|
|
|
|
class LazySegments:
    """
    Only decodes segments that are actually used.

    Encoded segments are stored by key; a segment is decoded on first access
    and the decoded mask is kept in ``cache`` for later accesses.
    """

    def __init__(self):
        self.segments = {}  # key -> encoded (RLE) segment
        self.cache = {}  # key -> already decoded mask tensor

    def __setitem__(self, key, item):
        """Store the encoded segment ``item`` under ``key``."""
        self.segments[key] = item
        # Drop any mask decoded from the previous value of this key so the
        # next lookup decodes the new segment instead of returning stale data.
        self.cache.pop(key, None)

    def __getitem__(self, key):
        """Return the decoded mask for ``key``, decoding it on first access.

        Raises:
            KeyError: no segment is stored under ``key``
        """
        if key in self.cache:
            return self.cache[key]
        rle = self.segments[key]
        # A single RLE is decoded as a batch of one; the batch axis is moved
        # to the front and the only mask is selected.
        mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
        self.cache[key] = mask
        return mask

    def __contains__(self, key):
        """Return whether a segment is stored under ``key``."""
        return key in self.segments

    def __len__(self):
        """Return the number of stored segments."""
        return len(self.segments)

    def keys(self):
        """Return the keys of the stored segments."""
        return self.segments.keys()
|
|
|
|
|
class SA1BSegmentLoader:
    """Load the RLE masks of one annotation JSON, decoding them lazily."""

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON file; its ``"annotations"``
                entry lists the masks
            mask_area_frac_thresh: masks whose area divided by the image
                area is at or above this value are dropped; the filter is
                only active for values <= 1.0
            video_frame_path: path of the image, read for its size; needed
                only when the area filter is active
            uncertain_iou: annotations carrying an ``"uncertain_iou"`` value
                below this threshold are dropped
        Raises:
            ValueError: the area filter is active but ``video_frame_path``
                is None
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            if video_frame_path is None:
                raise ValueError(
                    "video_frame_path is required when mask_area_frac_thresh <= 1.0"
                )
            # Only the size is read; the context manager makes sure the
            # opened image file is closed again.
            with PILImage.open(video_frame_path) as img:
                orig_w, orig_h = img.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            # Skip masks without a positive area.
            if not frame_annot["area"] > 0:
                continue
            # Skip masks whose optional "uncertain_iou" is below the threshold.
            if ("uncertain_iou" in frame_annot) and (
                frame_annot["uncertain_iou"] < uncertain_iou
            ):
                continue
            # Skip masks covering too large a fraction of the image.
            if filter_by_area and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        # Kept masks are re-indexed from 0 in their original order.
        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the lazily decoded segments; ``frame_idx`` is not used."""
        return self.segments
|
|
|