| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .mask_decoder import MaskDecoder |
| from .prompt_encoder import PromptEncoder |
| from .transform import TwoWayTransformer |
|
|
class MedSAM_Lite(nn.Module):
    """Lightweight MedSAM pipeline: image encoder + box-prompt encoder + mask decoder."""

    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder
                 ):
        """Store the three sub-modules that make up the segmentation pipeline.

        Args:
            image_encoder: backbone that maps an image batch to embeddings.
            mask_decoder: decoder that turns embeddings + prompts into masks.
            prompt_encoder: encoder for the (box) prompts, also provides the
                dense positional encoding via ``get_dense_pe()``.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, boxes):
        """Predict low-resolution masks and IoU scores for box prompts.

        Args:
            image: input image batch, passed straight to the image encoder.
            boxes: box prompts, passed straight to the prompt encoder.

        Returns:
            Tuple ``(low_res_masks, iou_predictions)`` from the mask decoder.
        """
        img_embed = self.image_encoder(image)

        # Boxes are the only prompt type used here; points/masks are disabled.
        sparse_embed, dense_embed = self.prompt_encoder(
            points=None,
            boxes=boxes,
            masks=None,
        )

        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=img_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multimask_output=False,  # single best mask per prompt
        )

        return low_res_masks, iou_predictions

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """Crop mask logits to the valid region, then resize to the original image size.

        Args:
            masks: (B, C, H, W) mask logits (4-D, as required by bilinear
                ``F.interpolate`` and the spatial crop below).
            new_size: (h, w) valid extent of the masks before padding.
            original_size: (h, w) of the original input image.

        Returns:
            Masks cropped to ``new_size`` and bilinearly resized to
            ``original_size``. Runs under ``torch.no_grad()``.
        """
        valid_h, valid_w = new_size[0], new_size[1]
        cropped = masks[:, :, :valid_h, :valid_w]

        resized = F.interpolate(
            cropped,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return resized
| |
|
|