| | import torch |
| | import torch.nn.functional as F |
| | import cv2 |
| | from PIL import Image |
| | import numpy as np |
| |
|
| |
|
class Colors:
    """Fixed colour palette addressed by integer index.

    The palette is built once from a hard-coded tuple of hex codes and is
    exposed as ``self.palette`` (list of RGB tuples) with its length in
    ``self.n``.
    """

    def __init__(self):
        """Convert the built-in hex codes into RGB tuples."""
        hex_codes = (
            "00FF00",
            "FF3838",
            "FF701F",
            "FFB21D",
            "CFD231",
            "48F90A",
            "92CC17",
            "3DDB86",
            "1A9334",
            "00D4BB",
            "2C99A8",
            "00C2FF",
            "344593",
            "6473FF",
            "0018EC",
            "8438FF",
            "520085",
            "CB38FF",
            "FF95C8",
            "FF37C7",
        )
        # hex2rgb skips the first character, so every code gets a "#" prefix.
        self.palette = [self.hex2rgb("#" + code) for code in hex_codes]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return the colour for index ``i``, wrapping around the palette.

        Args:
            i: Index; converted with ``int()`` and taken modulo the palette size.
            bgr: When truthy, return the channels in (b, g, r) order.

        Returns:
            A 3-tuple of ints, RGB by default.
        """
        rgb = self.palette[int(i) % self.n]
        if bgr:
            return rgb[2], rgb[1], rgb[0]
        return rgb

    @staticmethod
    def hex2rgb(h):
        """Parse a ``"#RRGGBB"`` string into an ``(r, g, b)`` tuple of ints."""
        channels = []
        # Character 0 is the "#" marker; each channel is two hex digits.
        for start in (1, 3, 5):
            channels.append(int(h[start : start + 2], 16))
        return tuple(channels)
| |
|
| |
|
# Module-level palette instance; calling colors(i) returns the colour for index i.
colors = Colors()
| |
|
| |
|
def is_ascii(s=""):
    """Return True when ``str(s)`` consists only of ASCII characters.

    Args:
        s: Any object; it is converted with ``str()`` before checking.

    Returns:
        bool: True if no character was lost in an ASCII round trip.
    """
    text = str(s)
    # Non-ASCII characters encode to bytes that the lenient ASCII decode
    # drops, so any such character makes the round-tripped string shorter.
    ascii_only = text.encode().decode("ascii", "ignore")
    return len(text) == len(ascii_only)
| |
|
| |
|
def clip_boxes(boxes, shape):
    """Clip box coordinates in place so they lie inside an image.

    Columns 0 and 2 are limited to ``[0, shape[1]]`` and columns 1 and 3 to
    ``[0, shape[0]]`` (presumably x1, y1, x2, y2 boxes and a (height, width)
    shape -- confirm against callers).

    Args:
        boxes: ``torch.Tensor`` or array-like with a ``.clip`` method,
            indexed as ``[:, column]``; modified in place.
        shape: Sequence whose first two items bound y and x respectively.

    Returns:
        None.
    """
    if not isinstance(boxes, torch.Tensor):
        # Array path: fancy indexing copies, so clip and assign back.
        x_cols, y_cols = [0, 2], [1, 3]
        boxes[:, x_cols] = boxes[:, x_cols].clip(0, shape[1])
        boxes[:, y_cols] = boxes[:, y_cols].clip(0, shape[0])
        return

    # Tensor path: each column slice is a view, so clamp_ edits `boxes` itself.
    for column, limit_axis in ((0, 1), (1, 0), (2, 1), (3, 0)):
        boxes[:, column].clamp_(0, shape[limit_axis])
| |
|
| |
|
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """Rescale boxes from ``img1_shape`` coordinates to ``img0_shape`` coordinates.

    The padding is subtracted, the result is divided by the resize gain, and
    the boxes are finally clipped to ``img0_shape`` with ``clip_boxes``.

    Args:
        img1_shape: Shape of the resized/padded image; items 0 and 1 are used
            (presumably height, width).
        boxes: Box array/tensor with coordinates in columns 0-3; modified in
            place.
        img0_shape: Shape of the target image; items 0 and 1 are used.
        ratio_pad: Optional ``((gain, ...), (pad_x, pad_y))``. When omitted,
            gain and padding are derived from the two shapes assuming a
            uniform resize that is centred inside ``img1_shape``.

    Returns:
        The same ``boxes`` object after in-place rescaling and clipping.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        height_ratio = img1_shape[0] / img0_shape[0]
        width_ratio = img1_shape[1] / img0_shape[1]
        # The smaller ratio is the one that was applied to both axes.
        gain = min(height_ratio, width_ratio)
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = (pad_x, pad_y)

    boxes[:, [0, 2]] -= pad[0]  # x coordinates
    boxes[:, [1, 3]] -= pad[1]  # y coordinates
    boxes[:, :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
| |
|
| |
|
def crop_mask(masks, boxes):
    """Zero out every mask pixel that falls outside its bounding box.

    Args:
        masks: Tensor of shape ``[n, h, w]``.
        boxes: Tensor of shape ``[n, 4]`` holding ``x1, y1, x2, y2`` in the
            same pixel grid as ``masks``.

    Returns:
        Tensor of shape ``[n, h, w]``: ``masks`` multiplied by a per-box
        rectangle mask that is true for ``x1 <= x < x2`` and ``y1 <= y < y2``.
    """
    _, height, width = masks.shape
    # Each coordinate becomes [n, 1, 1] so it broadcasts against the grids.
    left, top, right, bottom = torch.chunk(boxes[:, :, None], 4, 1)

    grid_kwargs = {"device": masks.device, "dtype": left.dtype}
    xs = torch.arange(width, **grid_kwargs)[None, None]  # [1, 1, w]
    ys = torch.arange(height, **grid_kwargs)[None, :, None]  # [1, h, 1]

    inside_x = (xs >= left) & (xs < right)  # [n, 1, w]
    inside_y = (ys >= top) & (ys < bottom)  # [n, h, 1]
    return masks * (inside_x & inside_y)
| |
|
| |
|
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """Build binary instance masks from prototypes, cropping before upsampling.

    Args:
        protos: Prototype tensor of shape ``[mask_dim, mask_h, mask_w]``.
        masks_in: Mask coefficients of shape ``[n, mask_dim]``.
        bboxes: Boxes of shape ``[n, 4]`` in the coordinates of ``shape``;
            not modified (a clone is rescaled).
        shape: ``(h, w)`` of the input image.
        upsample: When true, bilinearly resize the cropped masks to ``shape``.

    Returns:
        Tensor of shape ``[n, mask_h, mask_w]`` (or ``[n, h, w]`` when
        ``upsample`` is true) thresholded in place at 0.5.
    """
    c, mh, mw = protos.shape
    ih, iw = shape

    flat_protos = protos.float().view(c, -1)
    masks = (masks_in @ flat_protos).sigmoid().view(-1, mh, mw)

    # Bring the boxes into the prototype grid so the crop happens at low resolution.
    x_scale = mw / iw
    y_scale = mh / ih
    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= x_scale
    downsampled_bboxes[:, 1] *= y_scale
    downsampled_bboxes[:, 2] *= x_scale
    downsampled_bboxes[:, 3] *= y_scale

    masks = crop_mask(masks, downsampled_bboxes)
    if upsample:
        # interpolate expects a batch dimension; add it, then strip it again.
        resized = F.interpolate(
            masks[None], shape, mode="bilinear", align_corners=False
        )
        masks = resized[0]
    return masks.gt_(0.5)
| |
|
| |
|
def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
    """Remove padding from an image/mask array and resize it to ``im0_shape``.

    Args:
        im1_shape: Shape of the padded array; items 0 and 1 are used
            (presumably height, width).
        masks: NumPy array with at least two dimensions, ``[h, w]`` or
            ``[h, w, num]``.
        im0_shape: Target shape; items 0 and 1 are used.
        ratio_pad: Optional pair whose second item is ``(pad_x, pad_y)``.
            When omitted the padding is derived from the two shapes.

    Returns:
        Array resized by ``cv2.resize`` to ``(im0_shape[0], im0_shape[1])``,
        always with a trailing channel axis.

    Raises:
        ValueError: If ``masks`` has fewer than two dimensions.
    """
    if ratio_pad is not None:
        pad = ratio_pad[1]
    else:
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])
        pad_x = (im1_shape[1] - im0_shape[1] * gain) / 2
        pad_y = (im1_shape[0] - im0_shape[0] * gain) / 2
        pad = (pad_x, pad_y)

    # Bounds of the unpadded region inside the padded array.
    top = int(pad[1])
    left = int(pad[0])
    bottom = int(im1_shape[0] - pad[1])
    right = int(im1_shape[1] - pad[0])

    if len(masks.shape) < 2:
        raise ValueError(
            f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}'
        )

    cropped = masks[top:bottom, left:right]
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(cropped, (im0_shape[1], im0_shape[0]))

    if resized.ndim == 2:
        # Restore the channel axis for 2-D results.
        resized = resized[:, :, None]
    return resized
| |
|
| |
|
class Annotator:
    """Draw boxes, labels and segmentation masks onto an image.

    Drawing is done with PIL when ``pil=True`` or when ``example`` contains
    non-ASCII text; otherwise the NumPy image is drawn on directly with cv2.
    """

    def __init__(
        self,
        im,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        example="abc",
    ):
        """Prepare the image and drawing state.

        Args:
            im: Contiguous NumPy image (``im.data.contiguous`` and
                ``im.shape`` are read).
            line_width: Line width in pixels; derived from the image size
                when falsy.
            font_size: PIL font size; derived from the image size when falsy.
            font: Font name handed to ``check_pil_font`` for ASCII examples.
            pil: Force PIL drawing.
            example: Sample label text; non-ASCII content switches to PIL and
                to the ``Arial.Unicode.ttf`` font.
        """
        assert (
            im.data.contiguous
        ), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
        non_ascii = not is_ascii(example)
        self.pil = pil or non_ascii
        if self.pil:
            # Sets self.im (as a PIL image) and self.draw.
            self.fromarray(im)
            # NOTE(review): check_pil_font is neither defined nor imported in
            # the visible part of this file -- confirm it is provided elsewhere.
            self.font = check_pil_font(
                font="Arial.Unicode.ttf" if non_ascii else font,
                size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12),
            )
        else:
            self.im = im
        # Needed by both the PIL and the cv2 drawing paths.
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)

    def _text_size(self, text):
        """Return ``(width, height)`` of ``text`` rendered with ``self.font``.

        ``getsize`` was removed from Pillow fonts in version 10, so
        ``getbbox`` is preferred and ``getsize`` is only a fallback for old
        Pillow releases that lack ``getbbox``.
        """
        if hasattr(self.font, "getbbox"):
            _, _, right, bottom = self.font.getbbox(text)
            return right, bottom
        return self.font.getsize(text)

    def box_label(
        self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
    ):
        """Draw one box, optionally with a filled label tag.

        Args:
            box: Sequence ``(x1, y1, x2, y2)``.
            label: Text to draw; nothing is written when empty.
            color: Box outline / label background colour.
            txt_color: Label text colour.
        """
        if self.pil or not is_ascii(label):
            # NOTE(review): self.draw only exists after PIL initialisation or
            # fromarray(); a non-ASCII label on a cv2-mode annotator reaches
            # this branch without it.
            self.draw.rectangle(box, width=self.lw, outline=color)
            if label:
                w, h = self._text_size(label)
                # Put the tag above the box when it fits, otherwise inside it.
                outside = box[1] - h >= 0
                tag_top = box[1] - h if outside else box[1]
                tag_bottom = box[1] + 1 if outside else box[1] + h + 1
                self.draw.rectangle(
                    (box[0], tag_top, box[0] + w + 1, tag_bottom),
                    fill=color,
                )
                self.draw.text(
                    (box[0], tag_top),
                    label,
                    fill=txt_color,
                    font=self.font,
                )
        else:
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(
                self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
            )
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(
                    label, 0, fontScale=self.lw / 3, thickness=tf
                )[0]
                # Put the tag above the box when it fits, otherwise inside it.
                outside = p1[1] - h >= 3
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.lw / 3,
                    txt_color,
                    thickness=tf,
                    lineType=cv2.LINE_AA,
                )

    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """Blend all instance masks onto the image at once.

        Args:
            masks (tensor): Predicted masks, shape ``[n, h, w]``.
            colors (List[List[int]]): One ``[r, g, b]`` (0-255) per mask.
            im_gpu (ndarray): NumPy image converted with ``torch.from_numpy``;
                laid out ``[3, h, w]``, presumably with values in [0, 1].
            alpha (float): Mask opacity: 0.0 fully transparent, 1.0 opaque.
            retina_masks (bool): When true the blended image is written as is;
                otherwise it is resized to the annotated image with
                ``scale_image``.
        """
        im_gpu = torch.from_numpy(im_gpu)

        if self.pil:
            # Work on a writable NumPy copy; converted back to PIL at the end.
            self.im = np.asarray(self.im).copy()

        if len(masks) == 0:
            # Nothing to blend: show the plain image and stop here. Blending
            # an empty stack would fail when indexing the last cumulative
            # transparency layer.
            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
        else:
            if im_gpu.device != masks.device:
                # from_numpy always yields a CPU tensor; follow the masks.
                im_gpu = im_gpu.to(masks.device)
            colors = (
                torch.tensor(colors, device=im_gpu.device, dtype=torch.float32)
                / 255.0
            )
            colors = colors[:, None, None]  # [n, 1, 1, 3]
            masks = masks.unsqueeze(3)  # [n, h, w, 1]
            masks_color = masks * (colors * alpha)  # [n, h, w, 3]

            # Cumulative transparency left after each successive mask layer.
            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # [n, h, w, 1]
            mcs = (masks_color * inv_alph_masks).sum(0) * 2  # [h, w, 3]

            im_gpu = im_gpu.flip(dims=[0])  # reverse channel order
            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # [h, w, 3]
            im_gpu = im_gpu * inv_alph_masks[-1] + mcs
            im_mask = (im_gpu * 255).byte().cpu().numpy()
            self.im[:] = (
                im_mask
                if retina_masks
                else scale_image(im_gpu.shape, im_mask, self.im.shape)
            )

        if self.pil:
            # Convert back to PIL and refresh the drawing handle.
            self.fromarray(self.im)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Draw a rectangle through the PIL drawing handle (PIL mode only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
        """Draw text through the PIL drawing handle (PIL mode only).

        Args:
            xy: ``(x, y)`` position; not modified.
            text: String to draw.
            txt_color: Text colour.
            anchor: ``"bottom"`` shifts the text up by its height so that
                ``xy`` refers to its bottom edge; any other value draws from
                ``xy`` downwards.
        """
        if anchor == "bottom":
            _, h = self._text_size(text)
            # Build a new position so tuples work and the caller's sequence
            # is left untouched.
            xy = (xy[0], xy[1] + 1 - h)
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        """Replace the image with ``im`` (PIL image or array) and rebuild the PIL drawing handle."""
        # Imported here because the module header only imports PIL.Image.
        from PIL import ImageDraw

        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        """Return the annotated image as a NumPy array."""
        return np.asarray(self.im)
| |
|