|
|
|
|
| import torch
|
| import numpy as np
|
| import argparse
|
| from PIL import Image
|
|
|
def convert_to_numpy(image):
    """Return ``image`` as a NumPy array.

    Args:
        image: A ``np.ndarray``, ``torch.Tensor`` or ``PIL.Image.Image``.

    Returns:
        A NumPy array.
        - NumPy inputs are copied, so the result never aliases the caller's array.
        - Tensors are detached and moved to the CPU first. For a tensor already
          on the CPU, the returned array shares memory with that tensor.
        - PIL images go through ``np.array``.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, np.ndarray):
        return image.copy()
    if isinstance(image, torch.Tensor):
        return image.detach().cpu().numpy()
    if isinstance(image, Image.Image):
        return np.array(image)
    # Raising a plain string is itself a TypeError in Python 3 and would hide
    # this message, so raise a proper exception instance.
    raise TypeError(
        f'Unsupported datatype {type(image)}, only support np.ndarray, torch.Tensor, Pillow Image.'
    )
|
|
|
class FlowAnnotator:
    """Estimate optical flow between consecutive frames with a RAFT model.

    ``forward`` returns, for every adjacent pair of input frames, the flow
    field as a NumPy array and its visualisation from the vendored
    ``flow_viz.flow_to_image`` helper.
    """

    def __init__(self, cfg, device=None):
        """Build the RAFT model and load its pretrained weights.

        Args:
            cfg: Mapping that must contain ``'PRETRAINED_MODEL'``, the path of a
                checkpoint loadable with ``torch.load(..., weights_only=True)``.
            device: Device to run the model on. When ``None``, CUDA is used if
                available, otherwise the CPU.
        """
        # Imported inside __init__, so the vendored RAFT package is only loaded
        # when an annotator is constructed.
        from .raft.raft import RAFT
        from .raft.utils.utils import InputPadder
        from .raft.utils import flow_viz

        # RAFT reads its configuration switches as attributes of this object.
        params = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
        pretrained_model = cfg['PRETRAINED_MODEL']
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = RAFT(params)
        state_dict = torch.load(pretrained_model, map_location="cpu", weights_only=True)
        # Strip 'module.' from the keys (presumably a checkpoint saved from a
        # DataParallel-wrapped model). Note: replace() removes every occurrence,
        # not only a leading prefix.
        self.model.load_state_dict({k.replace('module.', ''): v for k, v in state_dict.items()})
        self.model = self.model.to(self.device).eval()
        self.InputPadder = InputPadder
        self.flow_viz = flow_viz

    def forward(self, frames):
        """Compute optical flow for each pair of consecutive frames.

        Args:
            frames: Iterable of frames accepted by ``convert_to_numpy``.
                - Each frame is treated as ``(H, W, C)`` (channels last), since it
                  is permuted to ``(C, H, W)``.
                - Values are cast to ``uint8``, so pixels are presumably in the
                  0-255 range. TODO: confirm against callers.

        Returns:
            ``(flow_up_list, flow_up_vis_list)``.
            - For ``N`` input frames, each list has ``max(N - 1, 0)`` entries.
            - ``flow_up_list`` holds flow arrays in ``(H, W, channels)`` layout,
              at the input frame size.
            - ``flow_up_vis_list`` holds the matching ``flow_viz.flow_to_image``
              renderings.
        """
        flow_up_list, flow_up_vis_list = [], []
        previous = None
        with torch.no_grad():
            # Convert frames lazily so only two frames are resident on the
            # device at a time, instead of the whole clip.
            for frame in frames:
                current = self._frame_to_tensor(frame)
                if previous is not None:
                    flow_up = self._estimate_flow(previous, current)
                    flow_up_list.append(flow_up)
                    flow_up_vis_list.append(self.flow_viz.flow_to_image(flow_up))
                previous = current
        return flow_up_list, flow_up_vis_list

    def _frame_to_tensor(self, frame):
        """Convert one frame to a float ``(1, C, H, W)`` tensor on ``self.device``."""
        array = convert_to_numpy(frame).astype(np.uint8)
        return torch.from_numpy(array).permute(2, 0, 1).float()[None].to(self.device)

    def _estimate_flow(self, image1, image2):
        """Run RAFT on one frame pair and return the flow as an ``(H, W, channels)`` array."""
        padder = self.InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        _, flow_up = self.model(image1, image2, iters=20, test_mode=True)
        # Crop the padding back off so the flow matches the input frame size.
        # Without this, non-divisible sizes produced padded-size outputs.
        flow_up = padder.unpad(flow_up[0])
        return flow_up.permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
class FlowVisAnnotator(FlowAnnotator):
    """Flow annotator that returns only the flow visualisations."""

    def forward(self, frames):
        """Return one flow visualisation per input frame.

        For two or more frames, the base class yields one visualisation per
        consecutive pair (``N - 1`` in total). The first visualisation is
        repeated so the output has ``N`` entries. Zero or one input frame
        produces an empty list.
        """
        _, vis_frames = super().forward(frames)
        return [*vis_frames[:1], *vis_frames]