# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Note: This file has been borrowed from the facebookresearch/slowfast repo. It is used
# to add bounding boxes and predictions to a frame.
# TODO: Migrate this into the core PyTorchVideo library.

from __future__ import annotations

import itertools

# import logging
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from detectron2.utils.visualizer import Visualizer

# logger = logging.getLogger(__name__)


def _create_text_labels(
    classes: List[int],
    scores: List[float],
    class_names: Dict[int, str],
    ground_truth: bool = False,
) -> Optional[List[str]]:
    """
    Create text labels.
    Args:
        classes (list[int]): a list of class ids for each example.
        scores (list[float] or None): a list of scores for each example.
        class_names (dict[int, str]): a dict mapping class id to class name.
        ground_truth (bool): whether the labels are ground truth.
    Returns:
        labels (list[str]): formatted text labels.
    """
    try:
        labels = [class_names.get(c, "n/a") for c in classes]
    except IndexError:
        # logger.error("Class indices get out of range: {}".format(classes))
        return None

    if ground_truth:
        labels = ["[{}] {}".format("GT", label) for label in labels]
    elif scores is not None:
        assert len(classes) == len(scores)
        labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
    return labels
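
# Illustrative example (comment only): with class_names = {0: "walk", 1: "run"},
# classes = [1, 0] and scores = [0.91, 0.42], _create_text_labels returns
# ["[0.91] run", "[0.42] walk"]; with ground_truth=True the scores are ignored and
# each label is prefixed with "[GT]" instead.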


class ImgVisualizer(Visualizer):
    def __init__(
        self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
    ) -> None:
        """
        See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
        for more details.
        Args:
            img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            meta (MetadataCatalog): image metadata.
                See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
        """
        super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs)

    def draw_text(
        self,
        text: str,
        position: List[int],
        *,
        font_size: Optional[int] = None,
        color: str = "w",
        horizontal_alignment: str = "center",
        vertical_alignment: str = "bottom",
        box_facecolor: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw text at the specified position.
        Args:
            text (str): the text to draw on the image.
            position (list of 2 ints): the (x, y) coordinate to place the text at.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for the full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            vertical_alignment (str): see `matplotlib.text.Text`.
            box_facecolor (str): color of the box wrapped around the text. Refer to
                `matplotlib.colors` for the full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not font_size:
            font_size = self._default_font_size
        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="monospace",
            bbox={
                "facecolor": box_facecolor,
                "alpha": alpha,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment=vertical_alignment,
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
        )

    def draw_multiple_text(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        top_corner: bool = True,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
                Else, draw labels at (x_left, y_bottom).
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for the full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the boxes wrapped around the
                text. Refer to `matplotlib.colors` for the full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."

        if not font_size:
            font_size = self._default_font_size
        text_box_width = font_size + font_size // 2
        # If the text labels do not all fit above (or below) the chosen corner,
        # we split them and draw the remainder in the other direction.
        if top_corner:
            num_text_split = self._align_y_top(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 1
        else:
            num_text_split = len(text_ls) - self._align_y_bottom(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 3

        text_color_sorted = sorted(
            zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True
        )
        if len(text_color_sorted) != 0:
            text_ls, box_facecolors = zip(*text_color_sorted)
        else:
            text_ls, box_facecolors = [], []
        text_ls, box_facecolors = list(text_ls), list(box_facecolors)

        self.draw_multiple_text_upward(
            text_ls[:num_text_split][::-1],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[:num_text_split][::-1],
            alpha=alpha,
        )
        self.draw_multiple_text_downward(
            text_ls[num_text_split:],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[num_text_split:],
            alpha=alpha,
        )

    def draw_multiple_text_upward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in the upward
        direction. The next text label will be on top of the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicates the index of the
                y-coordinate of the box to draw labels around.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for the full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the boxes wrapped around the
                text. Refer to `matplotlib.colors` for the full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="bottom",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            y -= font_size + font_size // 2

    def draw_multiple_text_downward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in the downward
        direction. The next text label will be below the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicates the index of the
                y-coordinate of the box to draw labels around.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for the full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the boxes wrapped around the
                text. Refer to `matplotlib.colors` for the full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="top",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            y += font_size + font_size // 2

    def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
        """
        Choose an x-coordinate from the box to make sure the text label
        does not go out of frame. By default, the left x-coordinate is
        chosen and the text is aligned left. If the box is too close to the
        right side of the image, the right x-coordinate is chosen instead
        and the text is aligned right.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
        Returns:
            x_coordinate (float): the chosen x-coordinate.
            alignment (str): whether to align left or right.
        """
        # If the x-coordinate is greater than 5/6 of the image width,
        # then we align the text to the right of the box. This threshold
        # is chosen heuristically.
        if box_coordinate[0] > (self.output.width * 5) // 6:
            return box_coordinate[2], "right"

        return box_coordinate[0], "left"

    def _align_y_top(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels that can be plotted on top of the box
        without going out of frame.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the height of the box wrapped around each text label.
        """
        dist_to_top = box_coordinate[1]
        num_text_top = dist_to_top // textbox_width

        if isinstance(num_text_top, torch.Tensor):
            num_text_top = int(num_text_top.item())

        return min(num_text, num_text_top)

    def _align_y_bottom(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels that can be plotted at the bottom of the box
        without going out of frame.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the height of the box wrapped around each text label.
        """
        dist_to_bottom = self.output.height - box_coordinate[3]
        num_text_bottom = dist_to_bottom // textbox_width

        if isinstance(num_text_bottom, torch.Tensor):
            num_text_bottom = int(num_text_bottom.item())

        return min(num_text, num_text_bottom)
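

# Minimal usage sketch (comment only; `img` is an assumed (H, W, 3) uint8 RGB array):
#     vis = ImgVisualizer(img, meta=None)
#     vis.draw_text("hello", (10, 20), font_size=8)
#     annotated = vis.output.get_image()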


class VideoVisualizer:
    def __init__(
        self,
        num_classes: int,
        class_names: Dict,
        top_k: int = 1,
        colormap: str = "rainbow",
        thres: float = 0.7,
        lower_thres: float = 0.3,
        common_class_names: Optional[List[str]] = None,
        mode: str = "top-k",
    ) -> None:
        """
        Args:
            num_classes (int): total number of classes.
            class_names (dict): a dict mapping class id to class name.
            top_k (int): number of top predicted classes to plot.
            colormap (str): the colormap from which class-label colors are drawn.
                See https://matplotlib.org/tutorials/colors/colormaps.html
            thres (float): threshold for picking predicted classes to visualize.
            lower_thres (Optional[float]): If `common_class_names` is given,
                `lower_thres` is applied to uncommon classes and `thres` is
                applied to classes in `common_class_names`.
            common_class_names (Optional[list of str]): list of common class names
                to apply `thres` to. Class names not included in `common_class_names`
                use `lower_thres` as their threshold. If None, all classes use `thres`
                as their threshold. This is helpful for models trained on highly
                imbalanced datasets.
            mode (str): Supported modes are {"top-k", "thres"}.
                This is used for choosing predictions for visualization.
        """
        assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
        self.mode = mode
        self.num_classes = num_classes
        self.class_names = class_names
        self.top_k = top_k
        self.thres = thres
        self.lower_thres = lower_thres

        if mode == "thres":
            self._get_thres_array(common_class_names=common_class_names)

        self.color_map = plt.get_cmap(colormap)

    def _get_color(self, class_id: int) -> List[float]:
        """
        Get color for a class id.
        Args:
            class_id (int): class id.
        """
        return self.color_map(class_id / self.num_classes)[:3]

    def draw_one_frame(
        self,
        frame: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        alpha: float = 0.5,
        text_alpha: float = 0.7,
        ground_truth: bool = False,
    ) -> np.ndarray:
        """
        Draw labels and bounding boxes for one image. By default, predicted
        labels are drawn in the top left corner of the image or of the corresponding
        bounding boxes. For ground truth labels (ground_truth flag set to True),
        labels are drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C),
                where H and W correspond to the height and width of the image
                respectively. C is the number of color channels. The image is
                required to be in RGB format since that is a requirement of the
                Matplotlib library. The image is also expected to be in the
                range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of
                shape (num_boxes, num_classes) that contains all of the confidence
                scores of the model. For a recognition task, the input shape can be
                (num_classes,). To plot true labels (ground_truth is True), preds is
                a list of int32 class ids of shape (num_boxes, true_class_ids) or
                (true_class_ids,).
            bboxes (Optional[tensor]): shape (num_boxes, 4). Contains the coordinates
                of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around
                text labels.
            ground_truth (bool): whether the provided bounding boxes are ground truth.
        Returns:
            An image with bounding box annotations and corresponding bbox
            labels plotted on it.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            # logger.error("Unsupported type of prediction input.")
            return

        if ground_truth:
            top_scores, top_classes = [None] * n_instances, preds
        elif self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)

        # Create labels for the top-k predicted classes with their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                _create_text_labels(
                    top_classes[i],
                    top_scores[i],
                    self.class_names,
                    ground_truth=ground_truth,
                )
            )
        frame_visualizer = ImgVisualizer(frame, meta=None)
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encountered {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            for i, box in enumerate(bboxes):
                text = text_labels[i]
                pred_class = top_classes[i]
                colors = [self._get_color(pred) for pred in pred_class]

                box_color = "r" if ground_truth else "g"
                line_style = "--" if ground_truth else "-."
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text,
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            text = text_labels[0]
            pred_class = top_classes[0]
            colors = [self._get_color(pred) for pred in pred_class]
            frame_visualizer.draw_multiple_text(
                text,
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )

        return frame_visualizer.output.get_image()
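
    # Usage sketch (comment only; `visualizer`, `frame`, `scores` and `boxes` are
    # assumed names, not defined here): for a (H, W, 3) uint8 RGB `frame`, a
    # (num_boxes, num_classes) score tensor `scores` and (num_boxes, 4) `boxes`,
    #     annotated = visualizer.draw_one_frame(frame, scores, bboxes=boxes)
    # returns a numpy image with the boxes and their top-k labels drawn on it.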

    def draw_clip_range(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        draw_range: Optional[List[int]] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes on a clip.
        Draw bounding boxes on the clip if bboxes is provided. Boxes gradually
        fade in and out over the clip, centered around the clip's central frame,
        within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that
                contains all of the confidence scores of the model. For a recognition
                task or for ground_truth labels, the input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4). Contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground truth.
            keyframe_idx (int): the index of the keyframe in the clip.
            draw_range (Optional[list[int]]): only draw frames in the range
                [start_idx, end_idx], inclusive, of the clip. If None, draw on
                the entire clip.
            repeat_frame (int): repeat each frame in draw_range `repeat_frame`
                times for a slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        if draw_range is None:
            draw_range = [0, len(frames) - 1]
        if draw_range is not None:
            draw_range[0] = max(0, draw_range[0])
            left_frames = frames[: draw_range[0]]
            right_frames = frames[draw_range[1] + 1 :]

        draw_frames = frames[draw_range[0] : draw_range[1] + 1]
        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        img_ls = (
            list(left_frames)
            + self.draw_clip(
                draw_frames,
                preds,
                bboxes=bboxes,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
                keyframe_idx=keyframe_idx - draw_range[0],
                repeat_frame=repeat_frame,
            )
            + list(right_frames)
        )

        return img_ls
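
    # Usage sketch (comment only; names are assumed): to annotate only frames 8-23 of a
    # 32-frame clip and slow that span down by a factor of 2,
    #     out_frames = visualizer.draw_clip_range(
    #         frames, scores, bboxes=boxes, keyframe_idx=16,
    #         draw_range=[8, 23], repeat_frame=2,
    #     )
    # returns the untouched leading/trailing frames plus the annotated, repeated middle.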

    def draw_clip(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes on a clip. Draw bounding boxes on
        the clip if bboxes is provided. Boxes gradually fade in and out over the clip,
        centered around the clip's central frame.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
                all of the confidence scores of the model. For a recognition task or for
                ground_truth labels, the input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4). Contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground truth.
            keyframe_idx (int): the index of the keyframe in the clip.
            repeat_frame (int): repeat each frame in the clip `repeat_frame`
                times for a slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."

        repeated_seq = range(0, len(frames))
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in repeated_seq
            )
        )

        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = len(repeated_seq) // 2
            half_right = (len(repeated_seq) + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
            half_left = mid
            half_right = len(repeated_seq) - mid

        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
            draw_img = self.draw_one_frame(
                frame,
                preds,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                draw_img = draw_img.astype("float32") / 255

            img_ls.append(draw_img)

        return img_ls
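
    # Worked example (comment only): with 4 frames, repeat_frame=1 and keyframe_idx=None,
    # half_left = half_right = 2, so alpha_ls = [0.0, 1.0, 1.0, 0.0]; the boxes are fully
    # opaque on the two middle frames and invisible on the first and last one.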

    def _adjust_frames_type(
        self, frames: torch.Tensor
    ) -> Tuple[np.ndarray, bool]:
        """
        Modify video data to have dtype uint8 and values in the range [0, 255].
        Args:
            frames (array-like): 4D array of shape (T, H, W, C).
        Returns:
            frames (ndarray): frames as a uint8 array with values in [0, 255].
            adjusted (bool): whether the original frames were adjusted (i.e. given as
                floats in [0, 1]); if so, the drawn frames are converted back to that
                range by the caller.
        """
        assert (
            frames is not None and len(frames) != 0
        ), "Frames do not contain any values"
        frames = np.array(frames)
        assert np.array(frames).ndim == 4, "Frames must have 4 dimensions"
        adjusted = False
        if frames.dtype in [np.float32, np.float64]:
            frames *= 255
            frames = frames.astype(np.uint8)
            adjusted = True

        return frames, adjusted

    def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
        """
        Compute a per-class threshold array based on `self.thres` and `self.lower_thres`.
        Args:
            common_class_names (Optional[list of str]): a list of common class names.
        """
        common_class_ids = []
        if common_class_names is not None:
            common_classes = set(common_class_names)

            for key, name in self.class_names.items():
                if name in common_classes:
                    common_class_ids.append(key)
        else:
            common_class_ids = list(range(self.num_classes))

        thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
        thres_array[common_class_ids] = self.thres
        self.thres = torch.from_numpy(thres_array)
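

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the original module). The class names,
# shapes and random inputs below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two hypothetical classes, one box, and a random 8-frame 64x64 RGB clip.
    dummy_class_names = {0: "walking", 1: "running"}
    visualizer = VideoVisualizer(
        num_classes=2,
        class_names=dummy_class_names,
        top_k=1,
        mode="top-k",
    )
    clip = torch.rand(8, 64, 64, 3)  # float frames in [0, 1], shape (T, H, W, C)
    scores = torch.tensor([[0.2, 0.8]])  # shape (num_boxes, num_classes)
    boxes = torch.tensor([[8.0, 8.0, 56.0, 56.0]])  # shape (num_boxes, 4)

    # Annotate the whole clip; the box fades in toward the central frame and out again.
    out_frames = visualizer.draw_clip_range(clip, scores, bboxes=boxes)
    print(len(out_frames), out_frames[0].shape)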