| | |
| | |
| |
|
| | |
| | |
| |
|
| | import glob |
| | import logging |
| | import os |
| | from dataclasses import dataclass |
| |
|
| | from typing import List, Optional |
| |
|
| | import pandas as pd |
| |
|
| | import torch |
| |
|
| | from iopath.common.file_io import g_pathmgr |
| |
|
| | from omegaconf.listconfig import ListConfig |
| |
|
| | from training.dataset.vos_segment_loader import ( |
| | JSONSegmentLoader, |
| | MultiplePNGSegmentLoader, |
| | PalettisedPNGSegmentLoader, |
| | SA1BSegmentLoader, |
| | ) |
| |
|
| |
|
@dataclass
class VOSFrame:
    """One frame of a video-object-segmentation (VOS) clip.

    Attributes:
        frame_idx: integer index of the frame within its video.
        image_path: path of the frame image.
        data: optional tensor attached to the frame; defaults to ``None``.
        is_conditioning_only: boolean flag stored with the frame; defaults to ``False``.
    """

    frame_idx: int
    image_path: str
    data: Optional[torch.Tensor] = None
    is_conditioning_only: Optional[bool] = False
| |
|
| |
|
@dataclass
class VOSVideo:
    """A named video holding a list of ``VOSFrame`` objects.

    Attributes:
        video_name: name of the video.
        video_id: integer identifier of the video.
        frames: the frames that make up the video.
    """

    video_name: str
    video_id: int
    frames: List[VOSFrame]

    def __len__(self):
        """Number of frames held by this video."""
        frame_count = len(self.frames)
        return frame_count
| |
|
| |
|
class VOSRawDataset:
    """Base class for raw VOS datasets.

    Subclasses override ``get_video`` to build the video at a given index.
    """

    def __init__(self):
        """The base class holds no state."""

    def get_video(self, idx):
        """Return the video at position ``idx``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
| |
|
| |
|
class PNGRawDataset(VOSRawDataset):
    """
    VOS dataset stored as per-video folders of JPEG frames plus PNG masks.

    Frames are read from ``img_folder/<video>/*.jpg`` and masks from
    ``gt_folder/<video>`` (palettised PNGs or one PNG per object, see
    ``is_palette``).
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """
        Args:
            img_folder: root folder with one sub-folder of ``*.jpg`` frames per video.
            gt_folder: root folder with the mask annotations of each video.
            file_list_txt: optional text file listing the videos to use, one per
                line (any extension is stripped). Defaults to every entry of
                ``img_folder``.
            excluded_videos_list_txt: optional text file listing videos to drop.
            sample_rate: keep every ``sample_rate``-th frame in ``get_video``.
            is_palette: use ``PalettisedPNGSegmentLoader`` if True, otherwise
                ``MultiplePNGSegmentLoader``.
            single_object_mode: expand each video into one entry
                ``<video>/<obj>`` per entry of ``gt_folder/<video>``.
            truncate_video: if > 0, keep only the first ``truncate_video`` frames.
            frames_sampling_mult: repeat each video name once per entry in its
                frame folder (presumably to weight sampling by video length).
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the list of videos to use.
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Videos to exclude; a set keeps the membership test below O(1).
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

        if self.single_object_mode:
            # One entry per object, named "<video>/<obj>".
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                # Count entries in the same folder get_video reads frames from.
                # In single-object mode the name is "<video>/<obj>", and the
                # frames live in img_folder/<video>, not img_folder/<video>/<obj>.
                num_frames = len(os.listdir(self._frame_root(video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def _frame_root(self, video_name):
        """Return the folder that holds the JPEG frames of ``video_name``."""
        if self.single_object_mode:
            return os.path.join(self.img_folder, os.path.dirname(video_name))
        return os.path.join(self.img_folder, video_name)

    def get_video(self, idx):
        """
        Build the video at position ``idx`` of ``self.video_names``.

        Returns:
            A ``(VOSVideo, segment_loader)`` tuple. The video lists the sampled
            ``*.jpg`` frames, and the loader reads masks from
            ``gt_folder/<video_name>``.
        """
        video_name = self.video_names[idx]
        video_frame_root = self._frame_root(video_name)
        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(
                video_mask_root, self.single_object_mode
            )

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        # Frame files are expected to be named "<integer frame id>.jpg".
        frames = [
            VOSFrame(int(os.path.basename(fpath).split(".")[0]), image_path=fpath)
            for fpath in all_frames[:: self.sample_rate]
        ]
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries in ``self.video_names``, including any repeats."""
        return len(self.video_names)
| |
|
| |
|
class SA1BRawDataset(VOSRawDataset):
    """
    SA-1B style dataset: each entry is one ``<name>.jpg`` image in ``img_folder``
    with a ``<name>.json`` annotation in ``gt_folder``. The image is repeated as
    every frame of a ``num_frames``-long pseudo-video.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,
        uncertain_iou=-1,
    ):
        """
        Args:
            img_folder: folder containing the ``*.jpg`` images.
            gt_folder: folder containing the ``*.json`` annotations.
            file_list_txt: optional text file listing the entries to use, one per
                line (any extension is stripped). Defaults to every ``*.jpg`` in
                ``img_folder``.
            excluded_videos_list_txt: optional text file listing entries to drop.
            num_frames: number of frames in each pseudo-video.
            mask_area_frac_thresh: passed through to ``SA1BSegmentLoader``.
            uncertain_iou: passed through to ``SA1BSegmentLoader``.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou

        # Read the list of entries to use.
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            # os.listdir order is arbitrary; sort so that index -> entry is
            # reproducible. splitext (unlike split(".")) keeps names that
            # contain dots intact, matching the file_list_txt branch.
            subset = sorted(
                os.path.splitext(path)[0]
                for path in os.listdir(self.img_folder)
                if path.endswith(".jpg")
            )

        # Entries to exclude; a set keeps the membership test below O(1).
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        self.video_names = [
            video_name for video_name in subset if video_name not in excluded_files
        ]

    def get_video(self, idx):
        """
        Build the pseudo-video at position ``idx`` of ``self.video_names``.

        Returns:
            A ``(VOSVideo, SA1BSegmentLoader)`` tuple. The video repeats the same
            image path ``num_frames`` times.

        Raises:
            ValueError: if the part of the name after the last ``_`` is not an
                integer.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = [
            VOSFrame(frame_idx, image_path=video_frame_path)
            for frame_idx in range(self.num_frames)
        ]
        # The suffix after the last "_" is used as both the video name and its
        # integer id (int() requires it to be numeric).
        video_id = video_name.split("_")[-1]
        video = VOSVideo(video_id, int(video_id), frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries in the dataset."""
        return len(self.video_names)
| |
|
| |
|
class JSONRawDataset(VOSRawDataset):
    """
    Dataset whose annotations are SA-V format JSON files.

    Frames are read from ``img_folder/<video>/`` and annotations from
    ``gt_folder/<video>_manual.json``.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        """
        Args:
            img_folder: root folder with one sub-folder of frames per video.
            gt_folder: folder with the ``<video>_manual.json`` annotation files.
            file_list_txt: optional text file listing the videos to use, one per
                line (any extension is stripped). Defaults to every entry of
                ``img_folder``.
            excluded_videos_list_txt: a path, or a list/tuple/ListConfig of
                paths, to text files listing videos to drop.
            sample_rate: keep every ``sample_rate``-th frame in ``get_video``.
            rm_unannotated: drop frames without a complete annotation.
            ann_every: passed through to ``JSONSegmentLoader``.
            frames_fps: passed through to ``JSONSegmentLoader``.

        Raises:
            NotImplementedError: if ``excluded_videos_list_txt`` has an
                unsupported type.
        """
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Collect the videos to exclude from one or several list files.
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (list, tuple, ListConfig)):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError(
                    "Unsupported type for excluded_videos_list_txt: "
                    f"{type(excluded_videos_list_txt).__name__}"
                )

            for excluded_list_path in excluded_videos_lists:
                # Opened through g_pathmgr, like file_list_txt below.
                with g_pathmgr.open(excluded_list_path, "r") as f:
                    excluded_files.update(
                        os.path.splitext(line.strip())[0] for line in f
                    )

        # Read the list of videos to use.
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

    def get_video(self, video_idx):
        """
        Build the video at position ``video_idx`` of ``self.video_names``.

        Returns:
            A ``(VOSVideo, JSONSegmentLoader)`` tuple. The video lists the sampled
            frames and, if ``rm_unannotated`` is set, only those with a
            complete annotation.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        # Frame files are named "<integer frame id>.<ext>". Sort numerically so
        # that non-zero-padded names are still ordered correctly.
        frame_ids = sorted(
            int(os.path.splitext(frame_name)[0])
            for frame_name in os.listdir(os.path.join(self.img_folder, video_name))
        )

        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(
                    self.img_folder, video_name, f"{frame_id:05d}.jpg"
                ),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Annotation i covers frame i * ann_every. Keep only frames whose
            # annotation exists and contains no None entry.
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of videos in the dataset."""
        return len(self.video_names)
| |
|