| |
|
|
| |
| from typing import Dict, List |
|
|
| import numpy as np |
| import PIL |
| import torch |
| from ..model import box_ops |
| from ..model.data_misc import FindStage, interpolate |
| from torchvision.transforms import v2 |
|
|
|
|
class Sam3Processor:
    """Inference helper that wraps ``model`` behind a ``state``-dict API.

    A ``state`` dict is created by :meth:`set_image` (or
    :meth:`set_image_batch`), which caches the backbone features. Prompt
    methods (:meth:`set_text_prompt`, :meth:`add_geometric_prompt`) then run
    grounding and write ``"boxes"``, ``"masks"``, ``"masks_logits"`` and
    ``"scores"`` back into that same dict.
    """

    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: Object exposing ``backbone.forward_image``,
                ``backbone.forward_text``, ``forward_grounding``,
                ``_get_dummy_prompt`` and ``inst_interactive_predictor``.
            resolution: Side length of the square the image is resized to.
            device: Device on which inputs and prompt tensors are placed.
            confidence_threshold: Detections whose score is not strictly
                greater than this value are discarded.
        """
        self.model = model
        self.resolution = resolution
        self.device = device
        # uint8 -> square resize -> float32 in [0, 1] -> normalised with
        # mean 0.5 / std 0.5 per channel.
        self.transform = v2.Compose(
            [
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(size=(resolution, resolution)),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.confidence_threshold = confidence_threshold

        # Fixed find-stage input reused for every grounding call; presumably
        # it selects image 0 / text 0 of the cached batch (verify against
        # FindStage). No box or point inputs are passed here.
        self.find_stage = FindStage(
            img_ids=torch.tensor([0], device=device, dtype=torch.long),
            text_ids=torch.tensor([0], device=device, dtype=torch.long),
            input_boxes=None,
            input_boxes_mask=None,
            input_boxes_label=None,
            input_points=None,
            input_points_mask=None,
        )

    @staticmethod
    def _get_image_size(image):
        """Return ``(height, width)`` of a tensor, NumPy array or PIL image.

        Tensors are read as ``(..., H, W)``. NumPy arrays are read as
        ``(H, W[, C])`` because ``v2.functional.to_image`` interprets arrays
        as channel-last; reading ``shape[-2:]`` on an HWC array would return
        ``(W, C)`` instead.

        Raises:
            ValueError: If ``image`` is none of the supported types.
        """
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, np.ndarray):
            height, width = image.shape[:2]
        elif isinstance(image, PIL.Image.Image):
            width, height = image.size
        else:
            raise ValueError("Image must be a PIL image or a tensor")
        return height, width

    def _adapt_sam2_features(self, backbone_out):
        """Run the two highest-resolution SAM2 FPN levels through conv_s0/conv_s1.

        Does nothing when the model has no interactive predictor or the
        backbone did not return ``"sam2_backbone_out"``. The FPN list inside
        ``backbone_out`` is modified in place.
        """
        predictor = self.model.inst_interactive_predictor
        if predictor is None or "sam2_backbone_out" not in backbone_out:
            return
        fpn = backbone_out["sam2_backbone_out"]["backbone_fpn"]
        mask_decoder = predictor.model.sam_mask_decoder
        fpn[0] = mask_decoder.conv_s0(fpn[0])
        fpn[1] = mask_decoder.conv_s1(fpn[1])

    @torch.inference_mode()
    def set_image(self, image, state=None):
        """Set the image on which predictions will be made.

        Args:
            image: A PIL image, a ``torch.Tensor`` laid out ``(..., H, W)``,
                or a NumPy array laid out ``(H, W[, C])``.
            state: Optional existing state dict to update; a new one is
                created when omitted.

        Returns:
            The state dict, with ``"original_height"``, ``"original_width"``
            and ``"backbone_out"`` set.

        Raises:
            ValueError: If ``image`` has an unsupported type.
        """
        if state is None:
            state = {}

        height, width = self._get_image_size(image)

        image = v2.functional.to_image(image).to(self.device)
        image = self.transform(image).unsqueeze(0)  # add a batch dimension

        state["original_height"] = height
        state["original_width"] = width
        state["backbone_out"] = self.model.backbone.forward_image(image)
        self._adapt_sam2_features(state["backbone_out"])
        return state

    @torch.inference_mode()
    def set_image_batch(self, images: List["PIL.Image.Image"], state=None):
        """Set a batch of images on which predictions will be made.

        Note that this stores the sizes under ``"original_heights"`` /
        ``"original_widths"`` (plural). The grounding path used by
        :meth:`set_text_prompt` and :meth:`add_geometric_prompt` reads
        ``"original_height"`` / ``"original_width"``, which this method does
        not set.

        Args:
            images: Non-empty list of PIL images.
            state: Optional existing state dict to update; a new one is
                created when omitted.

        Returns:
            The state dict, with ``"original_heights"``,
            ``"original_widths"`` and ``"backbone_out"`` set.

        Raises:
            ValueError: If ``images`` is not a non-empty list of PIL images.
        """
        if state is None:
            state = {}

        # Explicit raises rather than ``assert`` so the checks survive
        # ``python -O``.
        if not isinstance(images, list):
            raise ValueError("Images must be a list of PIL images or tensors")
        if not images:
            raise ValueError("Images list must not be empty")
        if not all(isinstance(image, PIL.Image.Image) for image in images):
            raise ValueError("Images must be a list of PIL images")

        state["original_heights"] = [image.height for image in images]
        state["original_widths"] = [image.width for image in images]

        # Every image is resized to the same square, so they can be stacked.
        batch = torch.stack(
            [
                self.transform(v2.functional.to_image(image).to(self.device))
                for image in images
            ],
            dim=0,
        )
        state["backbone_out"] = self.model.backbone.forward_image(batch)
        self._adapt_sam2_features(state["backbone_out"])
        return state

    @torch.inference_mode()
    def set_text_prompt(self, prompt: str, state: Dict):
        """Set the text prompt and run inference.

        Args:
            prompt: Text prompt passed to the text backbone.
            state: State dict previously filled by :meth:`set_image`.

        Returns:
            The state dict updated with the grounding results.

        Raises:
            ValueError: If no image features are present in ``state``.
        """
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before set_text_prompt")

        text_outputs = self.model.backbone.forward_text([prompt], device=self.device)
        # Merge the text outputs into the cached backbone outputs.
        state["backbone_out"].update(text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()

        return self._forward_grounding(state)

    @torch.inference_mode()
    def add_geometric_prompt(self, box: List, label: bool, state: Dict):
        """Add a box prompt and run inference.

        The image needs to be set, but not necessarily the text prompt.

        Args:
            box: Box in ``[center_x, center_y, width, height]`` format,
                normalized to the ``[0, 1]`` range.
            label: ``True`` for a positive box, ``False`` for a negative box.
            state: State dict previously filled by :meth:`set_image`.

        Returns:
            The state dict updated with the grounding results.

        Raises:
            ValueError: If no image features are present in ``state``.
        """
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before add_geometric_prompt")

        if "language_features" not in state["backbone_out"]:
            # No text prompt has been set: fall back to the placeholder
            # prompt "visual" so the text outputs exist.
            dummy_text_outputs = self.model.backbone.forward_text(
                ["visual"], device=self.device
            )
            state["backbone_out"].update(dummy_text_outputs)

        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()

        # One box and one label, each reshaped to a leading (1, 1) pair of dims.
        boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4)
        labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1)
        state["geometric_prompt"].append_boxes(boxes, labels)

        return self._forward_grounding(state)

    def reset_all_prompts(self, state: Dict):
        """Remove all prompts and results from ``state`` in place.

        Image features and the stored image size are kept, so new prompts
        can be issued without calling :meth:`set_image` again.
        """
        if "backbone_out" in state:
            backbone_out = state["backbone_out"]
            for key in ("language_features", "language_mask", "language_embeds"):
                backbone_out.pop(key, None)

        for key in ("geometric_prompt", "boxes", "masks", "masks_logits", "scores"):
            state.pop(key, None)

    @torch.inference_mode()
    def set_confidence_threshold(self, threshold: float, state=None):
        """Set the confidence threshold used to filter detections.

        Args:
            threshold: New threshold value.
            state: Optional state dict. When it already holds results
                (``"boxes"`` is present), grounding is re-run so the results
                reflect the new threshold.

        Returns:
            ``state`` (possibly ``None``), refreshed when results existed.
        """
        self.confidence_threshold = threshold
        if state is not None and "boxes" in state:
            # Filtering happens inside the grounding step, so existing
            # results are recomputed rather than re-filtered.
            return self._forward_grounding(state)
        return state

    @torch.inference_mode()
    def _forward_grounding(self, state: Dict):
        """Run grounding and store the filtered predictions in ``state``.

        Reads ``"backbone_out"``, ``"geometric_prompt"``,
        ``"original_height"`` and ``"original_width"`` from ``state``.

        Writes:
            ``"masks_logits"``: masks resized to the original image size
                after a sigmoid, i.e. values in [0, 1] despite the key name.
            ``"masks"``: boolean masks, ``masks_logits > 0.5``.
            ``"boxes"``: boxes converted by ``box_cxcywh_to_xyxy`` and
                multiplied by ``[w, h, w, h]`` (presumably normalized boxes
                becoming pixel coordinates; verify against the model).
            ``"scores"``: per-detection score.

        Returns:
            The same ``state`` dict.
        """
        outputs = self.model.forward_grounding(
            backbone_out=state["backbone_out"],
            find_input=self.find_stage,
            geometric_prompt=state["geometric_prompt"],
            find_target=None,
        )

        pred_boxes = outputs["pred_boxes"]
        pred_masks = outputs["pred_masks"]
        # Final score = per-query probability scaled by the presence probability.
        presence = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
        scores = (outputs["pred_logits"].sigmoid() * presence).squeeze(-1)

        # Keep only detections strictly above the confidence threshold.
        keep = scores > self.confidence_threshold
        scores = scores[keep]
        pred_masks = pred_masks[keep]
        pred_boxes = pred_boxes[keep]

        boxes = box_ops.box_cxcywh_to_xyxy(pred_boxes)

        img_h = state["original_height"]
        img_w = state["original_width"]
        scale_fct = torch.tensor([img_w, img_h, img_w, img_h], device=self.device)
        boxes = boxes * scale_fct[None, :]

        # Resize masks to the original image size, then squash with a sigmoid.
        mask_probs = interpolate(
            pred_masks.unsqueeze(1),
            (img_h, img_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()

        state["masks_logits"] = mask_probs
        state["masks"] = mask_probs > 0.5
        state["boxes"] = boxes
        state["scores"] = scores
        return state
|
|