Spaces:
Running
Running
| from PIL import Image | |
| import numpy as np | |
| from .base_segmenter import BaseSegmenter | |
| from .painter import mask_painter, point_painter | |
| mask_color = 3 | |
| mask_alpha = 0.7 | |
| contour_color = 1 | |
| contour_width = 5 | |
| point_color_ne = 8 | |
| point_color_ps = 50 | |
| point_alpha = 0.9 | |
| point_radius = 15 | |
| contour_color = 2 | |
| contour_width = 5 | |
| class SamControler: | |
| def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device): | |
| """ | |
| initialize sam controler | |
| """ | |
| self.sam_controler = BaseSegmenter(sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device) | |
| self.onnx = model_type == "vit_t" | |
| def first_frame_click( | |
| self, | |
| image: np.ndarray, | |
| points: np.ndarray, | |
| labels: np.ndarray, | |
| multimask=True, | |
| mask_color=3, | |
| ): | |
| """ | |
| it is used in first frame in video | |
| return: mask, logit, painted image(mask+point) | |
| """ | |
| # self.sam_controler.set_image(image) | |
| neg_flag = labels[-1] | |
| if self.onnx: | |
| onnx_coord = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :] | |
| onnx_label = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32) | |
| onnx_coord = self.sam_controler.predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) | |
| prompts = { | |
| "point_coords": onnx_coord, | |
| "point_labels": onnx_label, | |
| "orig_im_size": np.array(image.shape[:2], dtype=np.float32), | |
| } | |
| else: | |
| prompts = { | |
| "point_coords": points, | |
| "point_labels": labels, | |
| } | |
| if neg_flag == 1: | |
| # find positive | |
| masks, scores, logits = self.sam_controler.predict( | |
| prompts, "point", multimask | |
| ) | |
| mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] | |
| prompts["mask_input"] = np.expand_dims(logit[None, :, :], 0) | |
| masks, scores, logits = self.sam_controler.predict( | |
| prompts, "both", multimask | |
| ) | |
| mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] | |
| else: | |
| # find neg | |
| masks, scores, logits = self.sam_controler.predict( | |
| prompts, "point", multimask | |
| ) | |
| mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] | |
| assert len(points) == len(labels) | |
| painted_image = mask_painter( | |
| image, | |
| mask.astype("uint8"), | |
| mask_color, | |
| mask_alpha, | |
| contour_color, | |
| contour_width, | |
| ) | |
| painted_image = point_painter( | |
| painted_image, | |
| np.squeeze(points[np.argwhere(labels > 0)], axis=1), | |
| point_color_ne, | |
| point_alpha, | |
| point_radius, | |
| contour_color, | |
| contour_width, | |
| ) | |
| painted_image = point_painter( | |
| painted_image, | |
| np.squeeze(points[np.argwhere(labels < 1)], axis=1), | |
| point_color_ps, | |
| point_alpha, | |
| point_radius, | |
| contour_color, | |
| contour_width, | |
| ) | |
| painted_image = Image.fromarray(painted_image) | |
| return mask, logit, painted_image | |