Spaces:
Sleeping
Sleeping
| from collections import defaultdict | |
| import os | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import cv2 | |
| import torch | |
| from torch import Tensor, nn | |
| import torch.nn.functional as F | |
| import pytorch_lightning as pl | |
| import numpy as np | |
| from .configs.base_config import base_cfg | |
| from .rgbd_model import RGBDModel | |
class ModelPL(pl.LightningModule):
    """Lightning wrapper around ``RGBDModel`` exposing batch inference helpers.

    The wrapped network is built from ``cfg`` and stored as ``self.model``.
    The ``__inference_v*`` helpers turn raw network outputs into per-image
    numpy maps resized to each image's original size.
    """

    def __init__(self, cfg: base_cfg):
        super().__init__()
        self.cfg = cfg
        self.model = RGBDModel(cfg)

    def forward(self, images: Tensor, depths: Tensor):
        """Run the wrapped ``RGBDModel`` on a batch of images and depths."""
        # Call the module rather than ``.forward`` so registered nn.Module
        # hooks are executed.
        return self.model(images, depths)

    @staticmethod
    def _logits_to_uint8(logits: Tensor, image_size: Tuple[int, int]) -> np.ndarray:
        """Resize one sample's logits, apply sigmoid and min-max scale to uint8.

        Args:
            logits: a single sample's output (no batch dim); a batch dim is
                added for ``F.interpolate`` and squeezed away afterwards.
            image_size: target size read as ``(width, height)`` -- it is passed
                to ``F.interpolate`` as ``(image_size[1], image_size[0])``.

        Returns:
            A ``np.uint8`` array with all size-1 dimensions squeezed out.
        """
        # Upcast before resizing/normalising: in float16 the 1e-8 epsilon below
        # rounds to 0, so a constant (e.g. saturated) map produced 0/0 = NaN,
        # and casting NaN to uint8 is undefined.
        resized = F.interpolate(
            logits.unsqueeze(0).float(),
            size=(image_size[1], image_size[0]),
            mode="bilinear",
            align_corners=False,
        )
        res: np.ndarray = resized.sigmoid().detach().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        return (res * 255).astype(np.uint8)

    def __inference_v1(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Post-process the single ``"sod"`` output into one uint8 map per image.

        Returns:
            One entry per image, each a one-element list holding its map.
        """
        return [
            [self._logits_to_uint8(output, image_size)]
            for output, image_size in zip(outputs["sod"], image_sizes)
        ]

    def __inference_v2(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Post-process a multi-channel ``"sod"`` output into per-image label maps.

        Each sample is resized to its image size (``image_size`` read as
        ``(width, height)``), then ``argmax`` over the channel dimension yields
        an integer label map.

        Returns:
            One entry per image, each a one-element list holding its label map.
        """
        res_lst: List[List[np.ndarray]] = []
        for output, image_size in zip(outputs["sod"], image_sizes):
            resized = F.interpolate(
                output.unsqueeze(0),
                size=(image_size[1], image_size[0]),
                mode="bilinear",
                align_corners=False,
            )
            # dim=1 is the channel axis of the re-batched (1, C, ...) tensor.
            res: np.ndarray = torch.argmax(resized, dim=1).cpu().numpy().squeeze()
            res_lst.append([res])
        return res_lst

    def __inference_v3v5(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Post-process per-class outputs ``"sod0"``..``"sod{num_classes-1}"``.

        Returns:
            ``result[b][i]`` is the uint8 map of class ``i`` for image ``b``.
        """
        return [
            [
                self._logits_to_uint8(outputs[f"sod{i}"][bi], image_size)
                for i in range(self.cfg.num_classes)
            ]
            for bi, image_size in enumerate(image_sizes)
        ]

    def inference(
        self,
        image_sizes: List[Tuple[int, int]],
        images: Tensor,
        depths: Optional[Tensor],
        max_gts: Optional[List[int]],
    ) -> List[List[np.ndarray]]:
        """Run per-class inference and return uint8 maps for every image.

        Only ``cfg.ground_truth_version == 6`` is supported: ``self.model.inference``
        is called once per class index ``0..cfg.num_classes-1`` and each call's
        ``"sod"`` output is post-processed.

        Args:
            image_sizes: one ``(width, height)`` per image; outputs are resized
                to these sizes.
            images: batch of images; moved to ``self.device``.
            depths: batch of depth maps, or ``None``; moved to ``self.device``
                when given.
            max_gts: forwarded unchanged to ``self.model.inference``.

        Returns:
            ``result[b][i]`` is the uint8 map of class ``i`` for image ``b``.

        Raises:
            AssertionError: if ``len(image_sizes) != len(images)``.
            ValueError: if ``cfg.ground_truth_version`` is not 6.
        """
        # Note: this switches the wrapped model to eval mode and does not
        # restore the previous mode afterwards.
        self.model.eval()
        assert len(image_sizes) == len(
            images
        ), "The number of image_sizes must equal to the number of images"
        if self.cfg.ground_truth_version != 6:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers still catch this.
            raise ValueError(
                f"Unsupported ground_truth_version {self.cfg.ground_truth_version}"
            )

        gpu_images: Tensor = images.to(self.device)
        # ``depths`` is Optional: pass None through instead of crashing on
        # ``None.to``. NOTE(review): confirm RGBDModel.inference accepts None.
        gpu_depths: Optional[Tensor] = (
            depths.to(self.device) if depths is not None else None
        )
        batch_size = gpu_images.shape[0]

        outputs: Dict[str, Tensor] = {}
        # no_grad: inference only, so don't build autograd graphs.
        with torch.no_grad(), torch.autocast("cuda", enabled=self.cfg.is_fp16):
            for class_idx in range(self.cfg.num_classes):
                # The third argument repeats the class index once per sample.
                outputs[f"sod{class_idx}"] = self.model.inference(
                    gpu_images, gpu_depths, [class_idx] * batch_size, max_gts
                )["sod"]
        return self.__inference_v3v5(outputs, image_sizes)