# MCP-MedSAM / models/lite_medsam.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transform import TwoWayTransformer
class MedSAM_Lite(nn.Module):
    """Lightweight MedSAM: an image encoder, a prompt encoder and a mask
    decoder composed into a single box-prompted segmentation module.

    The three sub-modules are injected via the constructor, so this class
    only defines the wiring between them, not their architectures.
    """

    def __init__(self, image_encoder, mask_decoder, prompt_encoder):
        """Store the injected sub-modules.

        Args:
            image_encoder: maps an image batch to dense embeddings.
            mask_decoder: fuses image and prompt embeddings into masks.
            prompt_encoder: encodes box prompts (and supplies the dense
                positional encoding via ``get_dense_pe``).
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, boxes):
        """Segment ``image`` using box prompts.

        Args:
            image: input image batch.
            boxes: box prompts for the prompt encoder.

        Returns:
            Tuple ``(low_res_masks, iou_predictions)`` from the mask decoder.
        """
        img_emb = self.image_encoder(image)  # (B, 256, 64, 64)
        # Only boxes are used as prompts; points and mask prompts are disabled.
        sparse_emb, dense_emb = self.prompt_encoder(
            points=None,
            boxes=boxes,
            masks=None,
        )
        low_res_masks, iou_pred = self.mask_decoder(
            image_embeddings=img_emb,                     # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_emb,          # (B, 2, 256)
            dense_prompt_embeddings=dense_emb,            # (B, 256, 64, 64)
            multimask_output=False,
        )  # (B, 1, 256, 256)
        return low_res_masks, iou_pred

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """Crop away padding, then resize masks to the original image size.

        Args:
            masks: (B, C, H, W) mask logits/probabilities.
            new_size: (h, w) of the valid (un-padded) region to keep.
            original_size: (H, W) target size for the output.

        Returns:
            Masks bilinearly resized to ``original_size``.
        """
        valid = masks[:, :, : new_size[0], : new_size[1]]
        return F.interpolate(
            valid,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )