| | |
| | |
| |
|
| | |
| | |
| | import os |
| | import numpy as np |
| | import imageio |
| | import torch |
| |
|
| | from matplotlib import cm |
| | import torch.nn.functional as F |
| | import torchvision.transforms as transforms |
| | import matplotlib.pyplot as plt |
| | from PIL import Image, ImageDraw |
| | |
| | |
| | import torchvision |
| | from einops import rearrange |
| |
|
| |
|
def read_video_from_path(path):
    """Decode the video file at ``path`` and return its frames as a NumPy array.

    The file is read with ``torchvision.io.read_video`` using second-based
    timestamps and the ``THWC`` output layout, so the returned array is laid
    out as (frames, height, width, channels). The audio frames and the
    metadata dictionary returned by torchvision are discarded.

    Args:
        path: Location of the video file, passed straight to torchvision.

    Returns:
        The decoded video frames converted with ``Tensor.numpy()``.
    """
    decoded = torchvision.io.read_video(
        filename=path,
        pts_unit="sec",
        output_format="THWC",
    )
    # read_video returns (video_frames, audio_frames, info); only the video is used.
    frames = decoded[0]
    return frames.numpy()
| |
|
| |
|
| |
|
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
    """Draw a circle centred on ``coord`` onto a PIL image, in place.

    Args:
        rgb: Image handed to ``ImageDraw.Draw``; it is modified and returned.
        coord: Two-element centre of the circle (PIL reads it as ``(x, y)``).
        radius: Circle radius in pixels.
        color: Iterable of colour components, used for the outline and, when
            ``visible`` is truthy, for the fill as well.
        visible: When falsy the circle is left hollow (outline only).

    Returns:
        The same image object that was passed in.
    """
    centre_x, centre_y = coord[0], coord[1]
    # Bounding box of the ellipse: top-left and bottom-right corners.
    bounding_box = [
        (centre_x - radius, centre_y - radius),
        (centre_x + radius, centre_y + radius),
    ]
    outline_color = tuple(color)
    fill_color = outline_color if visible else None
    ImageDraw.Draw(rgb).ellipse(bounding_box, fill=fill_color, outline=outline_color)
    return rgb
| |
|
| |
|
def draw_line(rgb, coord_y, coord_x, color, linewidth):
    """Draw a straight segment between two points onto a PIL image, in place.

    Despite their names, ``coord_y`` and ``coord_x`` are both two-element
    points: the segment starts at ``coord_y`` and ends at ``coord_x``.

    Args:
        rgb: Image handed to ``ImageDraw.Draw``; it is modified and returned.
        coord_y: Start point of the segment.
        coord_x: End point of the segment.
        color: Iterable of colour components for the line.
        linewidth: Line width in pixels.

    Returns:
        The same image object that was passed in.
    """
    start_a, start_b = coord_y[0], coord_y[1]
    end_a, end_b = coord_x[0], coord_x[1]
    canvas = ImageDraw.Draw(rgb)
    canvas.line((start_a, start_b, end_a, end_b), fill=tuple(color), width=linewidth)
    return rgb
| |
|
| |
|
def add_weighted(rgb, alpha, original, beta, gamma):
    """Blend two images as ``rgb * alpha + original * beta + gamma``.

    The blend is computed in float64 and saturated to the ``[0, 255]`` range
    before being converted to ``uint8``, so results that fall outside the
    valid pixel range are clamped instead of wrapping around. Fractional
    results are truncated by the integer cast.

    Args:
        rgb: First image (array-like of pixel values).
        alpha: Weight applied to ``rgb``.
        original: Second image, broadcast-compatible with ``rgb``.
        beta: Weight applied to ``original``.
        gamma: Scalar offset added to the weighted sum.

    Returns:
        A ``uint8`` NumPy array holding the blended image.
    """
    blended = (
        np.asarray(rgb, dtype=np.float64) * alpha
        + np.asarray(original, dtype=np.float64) * beta
        + gamma
    )
    # Saturate before the cast: casting out-of-range floats to uint8 wraps.
    return np.clip(blended, 0, 255).astype("uint8")
| |
|
| |
|
class Visualizer:
    """Overlay point tracks on a video and optionally save the rendering.

    The video is expected as a ``(B, T, C, H, W)`` tensor with ``C == 3`` and
    the tracks as a ``(B, T, N, 2)`` tensor; only the first batch element is
    drawn. Track coordinates are used as ``(x, y)`` pixel positions.
    """

    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,
    ):
        """Store the rendering options.

        Args:
            save_dir: Directory in which ``save_video`` writes ``<filename>.mp4``.
            grayscale: Convert the video to grayscale (replicated to three
                channels) before drawing.
            pad_value: Number of pixels of white (255) border added on every
                side of each frame; tracks are shifted by the same amount.
            fps: Frame rate used when saving or logging the video.
            mode: Colouring scheme. ``"rainbow"`` and ``"cool"`` select a
                matplotlib colormap; ``"optical_flow"`` colours tracks with
                ``flow_vis``. No colormap is created for other values.
            linewidth: Width of trace lines; points are drawn with radius
                ``2 * linewidth``.
            show_first_frame: How many times the first frame is repeated at
                the start of the output (no repetition when not positive).
            tracks_leave_trace: ``0`` draws no traces, a positive value draws
                traces over that many previous frames, a negative value draws
                traces starting from the first frame.
        """
        self.mode = mode
        self.save_dir = save_dir
        # plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
        # Matplotlib 3.7 and removed in 3.9.
        if mode == "rainbow":
            self.color_map = plt.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = plt.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        gt_tracks: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        filename: str = "video",
        writer=None,
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        """Draw the tracks on the video and optionally save the result.

        Args:
            video: Video tensor of shape ``(B, T, C, H, W)``.
            tracks: Predicted tracks of shape ``(B, T, N, 2)``.
            visibility: Optional per-point visibility, indexed as
                ``visibility[0, t, i]``; invisible points are drawn hollow.
            gt_tracks: Optional ground-truth tracks, drawn as red crosses
                when traces are enabled.
            segm_mask: Optional segmentation mask; it is sampled at the track
                positions of ``query_frame`` to obtain one label per point.
                NOTE(review): indexed as ``segm_mask[0, query_frame][y, x]``,
                so a ``(B, T, H, W)``-like layout is assumed - confirm.
            filename: Base name of the saved video (without extension).
            writer: Optional logger exposing ``add_video``; when given, the
                video is logged instead of written to disk.
            step: Global step passed to ``writer.add_video``.
            query_frame: First frame from which points are drawn.
            save_video: Whether to call ``save_video`` on the result.
            compensate_for_camera_motion: Subtract the mean motion of
                background points (mask ``<= 0``) from the traces and draw
                only foreground points. Requires ``segm_mask``.

        Returns:
            The rendered video as a ``uint8`` tensor of shape
            ``(1, T', 3, H, W)``.
        """
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            # Look up the mask label under every track at the query frame.
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        # Pad the two spatial dimensions with a white border.
        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            # Back to three channels so coloured tracks can be drawn.
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        """Log the video to ``writer`` or write it to ``save_dir`` as MP4.

        Args:
            video: Rendered video tensor of shape ``(B, T, C, H, W)``; when
                writing to disk the batch size must be 1.
            filename: Tag used for the writer, or base name of the MP4 file.
            writer: Optional logger exposing ``add_video`` (presumably a
                TensorBoard ``SummaryWriter`` - confirm against callers).
            step: Global step passed to ``writer.add_video``.
        """
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            assert video.shape[0] == 1
            # write_video is given frames in (T, H, W, C) layout.
            video = rearrange(video[0], 'T C H W -> T H W C')
            torchvision.io.write_video(save_path, video, fps=self.fps)

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        """Render traces and points for every frame of the first batch item.

        Args:
            video: Video tensor of shape ``(B, T, C, H, W)`` with ``C == 3``.
            tracks: Tracks of shape ``(B, T, N, 2)``; truncated to integers.
            visibility: Optional visibility, indexed as ``visibility[0, t, i]``.
            segm_mask: Optional per-point labels of length ``N`` (as produced
                by ``visualize``); values ``> 0`` mark foreground points.
            gt_tracks: Optional ground-truth tracks for the same frames.
            query_frame: Frame whose positions define the rainbow colours and
                from which drawing starts.
            compensate_for_camera_motion: See ``visualize``; requires
                ``segm_mask``.

        Returns:
            A ``uint8`` tensor of shape ``(1, T', 3, H, W)`` where the first
            frame is repeated ``show_first_frame`` times when that is positive.
        """
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # T, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # T, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        # Work on copies so the input frames are never drawn on.
        res_video = [rgb.copy() for rgb in video]
        vector_colors = np.zeros((T, N, 3))

        # Pick one colour per (frame, point).
        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                # Colour depends on the point's y position at the query frame.
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # Colour changes with time.
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                # Background points are white; foreground points are coloured
                # by their y position in the first frame.
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # Colour changes with the segmentation class.
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # Draw the traces.
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                # A negative trace length means "trace from the first frame".
                first_ind = (
                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    # Mean displacement of the background points relative to
                    # frame t, removed from every trace.
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])

        # Draw the points.
        for t in range(query_frame, T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i]
                # Points with a zero coordinate are skipped.
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or segm_mask[i] > 0:
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visibile,
                        )
            res_video[t] = np.array(img)

        # Construct the final sequence, holding the first frame if requested.
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,
        tracks: np.ndarray,
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        """Draw the trace segments of every point onto one frame.

        Args:
            rgb: Frame as an ``(H, W, 3)`` array.
            tracks: Track positions of shape ``(T, N, 2)`` covering the trace
                window, oldest first.
            vector_colors: Colours of shape ``(T, N, 3)`` matching ``tracks``.
            alpha: Ignored; the blend weight is recomputed for every step
                (kept for interface compatibility).

        Returns:
            The frame with traces drawn, as a NumPy array.
        """
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            # Older segments get a smaller weight, so they appear more faded.
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                # Segments starting at a zero coordinate are skipped.
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                # Fade this step's segments into the image from before the step.
                rgb = Image.fromarray(
                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,
        gt_tracks: np.ndarray,
    ):
        """Mark every ground-truth position with a red cross on one frame.

        Args:
            rgb: Frame as an ``(H, W, 3)`` array.
            gt_tracks: Ground-truth positions of shape ``(T, N, 2)``.

        Returns:
            The frame with crosses drawn, as a NumPy array.
        """
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        length = self.linewidth * 3
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                # Use a separate name for the point: rebinding gt_tracks here
                # would break the indexing on every following iteration.
                point = gt_tracks[t][i]
                # Only strictly positive positions are drawn.
                if point[0] > 0 and point[1] > 0:
                    px, py = int(point[0]), int(point[1])
                    # First diagonal of the cross.
                    rgb = draw_line(
                        rgb,
                        (px + length, py + length),
                        (px - length, py - length),
                        color,
                        self.linewidth,
                    )
                    # Second diagonal of the cross.
                    rgb = draw_line(
                        rgb,
                        (px - length, py + length),
                        (px + length, py - length),
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb
| |
|