Spaces:
Runtime error
Runtime error
| import torch | |
| from typing import List | |
| import tops | |
| from torchvision.transforms.functional import InterpolationMode, resize | |
| from densepose.data.utils import get_class_to_mesh_name_mapping | |
| from densepose import add_densepose_config | |
| from densepose.structures import DensePoseEmbeddingPredictorOutput | |
| from densepose.vis.extractor import DensePoseOutputsExtractor | |
| from densepose.modeling import build_densepose_embedder | |
| from detectron2.config import get_cfg | |
| from detectron2.data.transforms import ResizeShortestEdge | |
| from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer | |
| from detectron2.modeling import build_model | |
# Registry of supported DensePose CSE models: each key is the URL of a
# detectron2 config file and each value is the URL of the checkpoint that is
# loaded for that config.
model_urls = {
    # ResNet-101 FPN, DeepLab head, soft embeddings (s1x schedule).
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml":
        "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
    # ResNet-50 FPN (s1x schedule).
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml":
        "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
}
def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
    """Paste per-instance masks into full-image boolean masks.

    Each mask ``S[i]`` is resized (nearest neighbour) to the size of its box
    and written into an otherwise all-False ``(H, W)`` canvas at the box
    location. Every resized value ``> 0`` counts as foreground.

    Args:
        boxes_XYXY: Tensor with one ``(x0, y0, x1, y1)`` row per instance in
            image pixel coordinates. Values are converted with ``.long()``.
        S: Rank-3 tensor holding one mask per instance along dim 0.
        imshape: ``(H, W)`` of the target image.

    Returns:
        Bool tensor of shape ``(N, H, W)`` on ``S.device`` where ``N`` is the
        number of boxes. Boxes with zero (or negative) width or height leave
        their canvas empty instead of failing inside ``resize``.

    Raises:
        AssertionError: If ``S`` is not rank 3 or a box lies outside the image.
    """
    assert len(S.shape) == 3
    H, W = imshape
    N = len(boxes_XYXY)
    segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
    # Convert all coordinates to Python ints in one transfer; unpacking 0-d
    # tensors would force a device sync on every comparison and slice below.
    int_boxes = boxes_XYXY.long().tolist()
    for i, (x0, y0, x1, y1) in enumerate(int_boxes):
        assert x0 >= 0 and y0 >= 0
        assert x1 <= W
        assert y1 <= H
        h = y1 - y0
        w = x1 - x0
        if h <= 0 or w <= 0:
            # Degenerate box: there is no pixel to fill and resize() rejects
            # an empty target size, so keep this instance's mask all-False.
            continue
        resized = resize(S[i:i + 1], (h, w), interpolation=InterpolationMode.NEAREST)
        segmentation[i:i + 1, y0:y1, x0:x1] = resized > 0
    return segmentation
class CSEDetector:
    """Detector that predicts DensePose CSE segmentations and embeddings.

    Builds a detectron2 model from a DensePose CSE config, loads the
    checkpoint registered for that config in ``model_urls`` and runs
    single-image inference through :meth:`forward` (also reachable by calling
    the instance).
    """

    # Config files fetched when the caller does not pass ``cfg_2_download``.
    # Stored as a tuple so the default cannot be mutated between instances.
    # Presumably these are the main config plus the base configs it inherits
    # from -- confirm against the config files.
    _DEFAULT_CFG_DOWNLOADS = (
        "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
        "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
        "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml",
    )
    # The only mesh whose vertex embeddings this detector exposes.
    _MESH_NAME = "smpl_27554"

    def __init__(
        self,
        cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
        cfg_2_download: List[str] = None,
        score_thres: float = 0.9,
        nms_thresh: float = None,
    ) -> None:
        """Download config and weights, then build the model in eval mode.

        Args:
            cfg_url: URL of the config to build from; must be a key of
                ``model_urls``.
            cfg_2_download: Additional config URLs to download before
                ``cfg_url`` is merged. ``None`` selects the three default CSE
                config files.
            score_thres: Written to ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
            nms_thresh: If given, written to
                ``cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST``; otherwise the value
                from the config is kept.

        Raises:
            AssertionError: If ``cfg_url`` has no registered checkpoint.
        """
        # Fail fast on an unsupported config before touching the network.
        assert cfg_url in model_urls, cfg_url
        if cfg_2_download is None:
            cfg_2_download = self._DEFAULT_CFG_DOWNLOADS

        # NOTE(review): the extent of each capture_log_stdout() block below
        # was reconstructed from a listing that had lost its indentation;
        # confirm which statements are meant to run inside each block.
        with tops.logger.capture_log_stdout():
            cfg = get_cfg()
        self.device = tops.get_device()
        add_densepose_config(cfg)
        cfg_path = tops.download_file(cfg_url)
        for extra_cfg_url in cfg_2_download:
            tops.download_file(extra_cfg_url)
        with tops.logger.capture_log_stdout():
            cfg.merge_from_file(cfg_path)

        model_path = tops.download_file(model_urls[cfg_url])
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        if nms_thresh is not None:
            cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
        cfg.MODEL.WEIGHTS = str(model_path)
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()

        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
        self.model.eval()
        DetectionCheckpointer(self.model).load(str(model_path))

        self.input_format = cfg.INPUT.FORMAT
        self.densepose_extractor = DensePoseOutputsExtractor()
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.embedder = build_densepose_embedder(cfg)
        # Vertex embeddings for every mesh the embedder provides, moved to
        # the inference device once at construction time.
        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
            if self.embedder.has_embeddings(mesh_name)
        }
        self.cfg = cfg
        self.embed_map = self.mesh_vertex_embeddings[self._MESH_NAME]
        tops.logger.log("CSEDetector built.")

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        """Resize ``im`` to the test-time input size from the config.

        The target size comes from ``ResizeShortestEdge.get_output_shape``
        with ``cfg.INPUT.MIN_SIZE_TEST`` / ``cfg.INPUT.MAX_SIZE_TEST``; the
        image is resized bilinearly with antialiasing. ``im.shape[1:]`` is
        read as ``(H, W)``.
        """
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)

    @torch.no_grad()
    def forward(self, im):
        """Run detection on a single image.

        Runs under ``torch.no_grad()`` so no autograd graph is built or kept
        alive by the returned tensors.

        Args:
            im: ``torch.uint8`` image tensor. Dim 0 is treated as the channel
                axis (it is flipped when the model expects BGR) and
                ``im.shape[1:]`` as ``(H, W)``.

        Returns:
            A dict with the keys

            * ``instance_segmentation``: per-instance mask in box space,
              one ``(h, w)`` map per detection (argmax over the two coarse
              segmentation channels).
            * ``instance_embedding``: per-instance embedding predicted by the
              DensePose head.
            * ``embed_map``: vertex embeddings of the ``smpl_27554`` mesh.
            * ``bbox_XYXY``: ``(N, 4)`` long tensor of rounded boxes.
            * ``im_segmentation``: ``(N, H, W)`` bool masks in image space.
            * ``scores``: ``(N,)`` detection scores.

            Detections whose rounded box has zero width or height are
            dropped. With no detections every tensor has a leading dim of 0.

        Raises:
            AssertionError: If ``im`` is not uint8, the extractor output is
                not a ``DensePoseEmbeddingPredictorOutput``, or the first
                detected class does not map to the ``smpl_27554`` mesh.
        """
        assert im.dtype == torch.uint8
        if self.input_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[1:]
        im = self.resize_im(im)
        output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
        scores = output.get("scores")
        if len(scores) == 0:
            # Same keys and ranks as the non-empty result below. The 112x112
            # map size and 16 embedding channels are hard-coded here --
            # NOTE(review): confirm they match the loaded config.
            return dict(
                instance_segmentation=torch.empty((0, 112, 112), dtype=torch.bool, device=im.device),
                instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
                embed_map=self.embed_map,
                bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
                im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
                scores=torch.empty((0,), dtype=torch.float, device=im.device),
            )

        pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
        assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
        # Coarse segmentation is Nx2xHxW (2 because only 2 classes); argmax
        # over the class dim gives one NxHxW label map.
        S = pred_densepose.coarse_segm.argmax(dim=1)
        E = pred_densepose.embedding
        mesh_name = self.class_to_mesh_name[classes[0]]
        assert mesh_name == self._MESH_NAME

        # Convert (x, y, w, h) boxes to integer (x0, y0, x1, y1) corners.
        x0, y0, w, h = boxes_xywh[:, :4].unbind(dim=1)
        boxes_XYXY = torch.stack((x0, y0, x0 + w, y0 + h), dim=-1)
        boxes_XYXY = boxes_XYXY.round_().long()

        # Rounding can collapse a thin box to zero width or height; such
        # boxes cannot receive a resized mask, so drop them everywhere.
        non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
        S = S[non_empty_boxes]
        E = E[non_empty_boxes]
        boxes_XYXY = boxes_XYXY[non_empty_boxes]
        scores = scores[non_empty_boxes]

        im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
        return dict(
            instance_segmentation=S,
            instance_embedding=E,
            embed_map=self.embed_map,
            bbox_XYXY=boxes_XYXY,
            im_segmentation=im_segmentation,
            scores=scores.view(-1),
        )