| """SHIFT result writer.""" |
|
|
| from __future__ import annotations |
|
|
| import io |
| import itertools |
| import json |
| import os |
| from collections import defaultdict |
|
|
| import numpy as np |
| from PIL import Image |
|
|
| from vis4d.common.array import array_to_numpy |
| from vis4d.common.imports import SCALABEL_AVAILABLE |
| from vis4d.common.typing import ( |
| ArrayLike, |
| GenericFunc, |
| MetricLogs, |
| NDArrayNumber, |
| ) |
| from vis4d.data.datasets.shift import shift_det_map |
| from vis4d.data.io import DataBackend, ZipBackend |
| from vis4d.eval.base import Evaluator |
|
|
| if SCALABEL_AVAILABLE: |
| from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d |
| from scalabel.label.typing import Dataset, Frame, Label |
| else: |
| raise ImportError("scalabel is not installed.") |
|
|
|
|
class SHIFTMultitaskWriter(Evaluator):
    """SHIFT result writer for online evaluation.

    Nothing is evaluated locally. Predictions are written into a zip archive
    located at ``os.path.join(output_dir, submission_file)``:

    - semantic masks as 8-bit grayscale PNGs under ``semseg/<video_name>/``,
    - depth maps as 8-bit grayscale PNGs under ``depth/<video_name>/``,
    - 2D detections / tracks as a single Scalabel ``det_2d.json``, which is
      written when :meth:`save` is called.

    Optical flow is not supported yet (:meth:`_write_flow` raises
    ``NotImplementedError``).
    """

    # Inverse of ``shift_det_map``: predicted class index -> category string
    # stored in the Scalabel labels.
    inverse_cat_map = {v: k for k, v in shift_det_map.items()}

    def __init__(
        self,
        output_dir: str,
        submission_file: str = "submission.zip",
    ) -> None:
        """Creates a new writer.

        Args:
            output_dir (str): Output directory.
            submission_file (str): Submission file name, must end with
                ".zip". Defaults to "submission.zip".
        """
        super().__init__()
        assert submission_file.endswith(
            ".zip"
        ), "Submission file must be a zip file."
        self.backend: DataBackend = ZipBackend()
        self.output_path = os.path.join(output_dir, submission_file)
        self.frames_det_2d: list[Frame] = []
        # NOTE(review): never populated by this class; kept for callers that
        # may access the attribute.
        self.frames_det_3d: list[Frame] = []
        # Number of processed samples per task ("semseg", "depth", "flow",
        # "det_2d"); all tasks that were used must agree in ``save``.
        self.sample_counts: defaultdict[str, int] = defaultdict(int)

    def _write_png(self, array: NDArrayNumber, key: str) -> None:
        """Encode a 2D array as a grayscale PNG and store it in the archive.

        Args:
            array (NDArrayNumber): Image data of shape (H, W). Values are cast
                to uint8 as-is (no rescaling).
            key (str): File path inside the submission zip.
        """
        image = Image.fromarray(array.astype("uint8"), mode="L")
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        self.backend.set(
            f"{self.output_path}/{key}",
            image_bytes.getvalue(),
            mode="w",
        )

    @staticmethod
    def _output_name(sample_name: str, task: str) -> str:
        """Derive the PNG file name for ``task`` from an input image name.

        ".jpg" is replaced by ".png" and every "img" substring by ``task``.
        """
        return sample_name.replace(".jpg", ".png").replace("img", task)

    def _write_sem_mask(
        self, sem_mask: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write semantic mask.

        Args:
            sem_mask (NDArrayNumber): Predicted semantic mask of class
                indices, shape (H, W).
            sample_name (str): Output file name.
            video_name (str): Video name.
        """
        self._write_png(sem_mask, f"semseg/{video_name}/{sample_name}")

    def _write_depth(
        self, depth_map: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write depth map.

        Args:
            depth_map (NDArrayNumber): Predicted depth map in meters, shape
                (H, W).
            sample_name (str): Output file name.
            video_name (str): Video name.
        """
        # Linear 8-bit quantization: 0 m -> 0, 80 m and beyond -> 255. The
        # uint8 cast in ``_write_png`` truncates the fractional part.
        depth_map = np.clip(depth_map / 80.0 * 255.0, 0, 255)
        self._write_png(depth_map, f"depth/{video_name}/{sample_name}")

    def _write_flow(
        self, flow: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write optical flow (not implemented yet).

        Args:
            flow (NDArrayNumber): Predicted optical flow, shape (H, W, 2).
            sample_name (str): Sample name.
            video_name (str): Video name.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _build_labels_2d(
        self,
        boxes: ArrayLike,
        scores: ArrayLike,
        class_ids: ArrayLike,
        track_ids: ArrayLike | None = None,
        instance_masks: ArrayLike | None = None,
    ) -> list[Label]:
        """Convert the 2D predictions of one frame into Scalabel labels.

        All inputs share the leading box axis N; box ``j`` is paired with
        score ``j``, class ``j``, track id ``j`` and instance mask ``j``.

        Args:
            boxes (ArrayLike): Boxes in (x1, y1, x2, y2) format, shape (N, 4).
            scores (ArrayLike): Box scores, shape (N,).
            class_ids (ArrayLike): Box class indices, shape (N,).
            track_ids (ArrayLike, optional): Track IDs, shape (N,).
            instance_masks (ArrayLike, optional): Instance masks, shape
                (N, H, W); pixels > 0 are foreground.

        Returns:
            list[Label]: One label per box, in input order.
        """
        boxes_np = array_to_numpy(boxes, n_dims=None, dtype=np.float32)
        scores_np = array_to_numpy(scores, n_dims=None, dtype=np.float32)
        classes_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
        tracks_np = (
            array_to_numpy(track_ids, n_dims=None, dtype=np.int64)
            if track_ids is not None
            else None
        )
        masks_np = (
            array_to_numpy(instance_masks, n_dims=None, dtype=np.float32)
            if instance_masks is not None
            else None
        )

        labels = []
        for j, (box, score, class_id) in enumerate(
            zip(boxes_np, scores_np, classes_np)
        ):
            # Index masks / track ids by the box index, not by class id.
            rle = (
                mask_to_rle((masks_np[j] > 0.0).astype(np.uint8))
                if masks_np is not None
                else None
            )
            track_id = (
                str(int(tracks_np[j])) if tracks_np is not None else None
            )
            category = (
                self.inverse_cat_map[int(class_id)]
                if self.inverse_cat_map
                else str(int(class_id))
            )
            labels.append(
                Label(
                    box2d=xyxy_to_box2d(*box.tolist()),
                    category=category,
                    score=float(score),
                    rle=rle,
                    id=track_id,
                )
            )
        return labels

    def process_batch(
        self,
        frame_ids: list[int],
        sample_names: list[str],
        sequence_names: list[str],
        pred_sem_mask: list[ArrayLike] | None = None,
        pred_depth: list[ArrayLike] | None = None,
        pred_flow: list[ArrayLike] | None = None,
        pred_boxes2d: list[ArrayLike] | None = None,
        pred_boxes2d_classes: list[ArrayLike] | None = None,
        pred_boxes2d_scores: list[ArrayLike] | None = None,
        pred_boxes2d_track_ids: list[ArrayLike] | None = None,
        pred_instance_masks: list[ArrayLike] | None = None,
    ) -> None:
        """Process SHIFT results.

        You can omit some of the predictions if they are not used. Dense
        predictions are written to the archive immediately; 2D detections are
        buffered and written in :meth:`save`.

        Args:
            frame_ids (list[int]): Frame IDs.
            sample_names (list[str]): Sample names.
            sequence_names (list[str]): Sequence names.
            pred_sem_mask (list[ArrayLike], optional): Predicted semantic
                masks, each in shape (C, H, W) or (H, W). Defaults to None.
            pred_depth (list[ArrayLike], optional): Predicted depth maps,
                each in shape (H, W), with meter unit. Defaults to None.
            pred_flow (list[ArrayLike], optional): Predicted optical flows,
                each in shape (H, W, 2). Defaults to None.
            pred_boxes2d (list[ArrayLike], optional): Predicted 2D boxes,
                each in shape (N, 4). Defaults to None.
            pred_boxes2d_classes (list[ArrayLike], optional): Predicted
                2D box classes, each in shape (N,). Defaults to None.
            pred_boxes2d_scores (list[ArrayLike], optional): Predicted
                2D box scores, each in shape (N,). Defaults to None.
            pred_boxes2d_track_ids (list[ArrayLike], optional): Predicted
                2D box track IDs, each in shape (N,). Defaults to None.
            pred_instance_masks (list[ArrayLike], optional): Predicted
                instance masks, each in shape (N, H, W). Defaults to None.
        """
        for i, (frame_id, sample_name, sequence_name) in enumerate(
            zip(frame_ids, sample_names, sequence_names)
        ):
            if pred_sem_mask is not None:
                sem_mask = array_to_numpy(
                    pred_sem_mask[i], n_dims=None, dtype=np.float32
                )
                if sem_mask.ndim == 3:
                    # (C, H, W) per-class scores -> (H, W) class indices.
                    sem_mask = sem_mask.argmax(axis=0)
                self._write_sem_mask(
                    sem_mask,
                    self._output_name(sample_name, "semseg"),
                    sequence_name,
                )
                self.sample_counts["semseg"] += 1

            if pred_depth is not None:
                depth = array_to_numpy(
                    pred_depth[i], n_dims=None, dtype=np.float32
                )
                self._write_depth(
                    depth,
                    self._output_name(sample_name, "depth"),
                    sequence_name,
                )
                self.sample_counts["depth"] += 1

            if pred_flow is not None:
                flow = array_to_numpy(
                    pred_flow[i], n_dims=None, dtype=np.float32
                )
                self._write_flow(flow, sample_name, sequence_name)
                self.sample_counts["flow"] += 1

            if (
                pred_boxes2d is not None
                and pred_boxes2d_classes is not None
                and pred_boxes2d_scores is not None
            ):
                labels = self._build_labels_2d(
                    pred_boxes2d[i],
                    pred_boxes2d_scores[i],
                    pred_boxes2d_classes[i],
                    track_ids=(
                        pred_boxes2d_track_ids[i]
                        if pred_boxes2d_track_ids is not None
                        else None
                    ),
                    instance_masks=(
                        pred_instance_masks[i]
                        if pred_instance_masks is not None
                        else None
                    ),
                )
                self.frames_det_2d.append(
                    Frame(
                        name=sample_name,
                        videoName=sequence_name,
                        frameIndex=frame_id,
                        labels=labels,
                    )
                )
                self.sample_counts["det_2d"] += 1

    def gather(self, gather_func: GenericFunc) -> None:
        """Gather variables in case of distributed setting (if needed).

        Only the buffered 2D detection frames are gathered; dense predictions
        were already written to the backend in :meth:`process_batch`.

        Args:
            gather_func (Callable[[Any], Any]): Gather function.
        """
        all_preds = gather_func(self.frames_det_2d)
        if all_preds is not None:
            self.frames_det_2d = list(itertools.chain(*all_preds))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """No evaluation locally."""
        return {}, "No evaluation locally."

    def save(self, metric: str, output_dir: str) -> None:
        """Save scalabel output to zip file and close the archive.

        Args:
            metric (str): Unused.
            output_dir (str): Unused; the archive path is fixed in __init__.

        Raises:
            ValueError: If the number of samples in each category is not the
                same.
        """
        # Compare the per-task counters with each other. Comparing them with
        # ``len(self.frames_det_2d)`` is wrong: that list is empty when no
        # boxes are predicted, and may contain frames from other processes
        # after ``gather`` while the counters stay local.
        if len(set(self.sample_counts.values())) > 1:
            raise ValueError(
                "The number of samples in each category is not the same."
            )

        if self.frames_det_2d:
            ds = Dataset(frames=self.frames_det_2d, groups=None, config=None)
            ds_bytes = json.dumps(ds.dict()).encode("utf-8")
            self.backend.set(
                f"{self.output_path}/det_2d.json", ds_bytes, mode="w"
            )

        self.backend.close()
        print(f"Saved the submission file at {self.output_path}.")
|
|