# segment / inference.py
# Provenance: uploaded by Shengxiao0709 ("Update inference.py", commit a4fce5b, verified).
import torch
import numpy as np
from skimage import transform
from segment_anything import sam_model_registry
# Side length (pixels) of the square input MedSAM's image encoder expects.
# NOTE(review): the functions below hard-code 1024 rather than referencing
# this constant — keep the two in sync if the input size ever changes.
MEDSAM_IMG_INPUT_SIZE = 1024
def load_model(checkpoint_path):
    """Load a ViT-B SAM model from *checkpoint_path* onto the CPU.

    Args:
        checkpoint_path: path to the SAM/MedSAM checkpoint file.

    Returns:
        A ``(model, device)`` pair: the model in eval mode on the CPU,
        and the corresponding ``torch.device``.
    """
    device = torch.device("cpu")
    net = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
    net = net.to("cpu")
    net.eval()  # inference only — disable dropout/batch-norm updates
    return net, device
@torch.no_grad()
def get_embedding(model, img_np, device, input_size=None):
    """Compute the MedSAM image embedding for a single image.

    Args:
        model: loaded SAM model exposing ``image_encoder``.
        img_np: input image array, channels last — assumed H x W x 3 from the
            ``permute(2, 0, 1)`` below; TODO confirm with callers.
        device: torch device the input tensor is moved to before encoding.
        input_size: square side length the image is resized to. Defaults to
            ``MEDSAM_IMG_INPUT_SIZE`` (1024), preserving the original behavior.

    Returns:
        The image-encoder output tensor for the resized, normalized image.
    """
    if input_size is None:
        # Use the file-level constant instead of a duplicated literal 1024.
        input_size = MEDSAM_IMG_INPUT_SIZE
    # Bicubic resize to the encoder's square input; cast back to uint8 to
    # match the original MedSAM preprocessing pipeline.
    resized = transform.resize(
        img_np, (input_size, input_size), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    # Min-max normalize to [0, 1]; clip the denominator so a constant image
    # cannot cause division by zero.
    normed = (resized - resized.min()) / np.clip(resized.max() - resized.min(), 1e-8, None)
    # HWC -> 1CHW float tensor on the target device.
    tensor = torch.tensor(normed).float().permute(2, 0, 1).unsqueeze(0).to(device)
    return model.image_encoder(tensor)
@torch.no_grad()
def run(model, embedding, box_1024, H, W):
    """Decode a binary mask from a precomputed embedding and a box prompt.

    Args:
        model: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.
        embedding: image-encoder output for the image being segmented.
        box_1024: box prompt(s) in 1024-scale coordinates, shape (B, 4).
        H, W: height and width the mask is upsampled to (original image size).

    Returns:
        A uint8 numpy array of zeros and ones, shape squeezed from (B, 1, H, W).
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=embedding.device)
    if boxes.ndim == 2:
        # The prompt encoder expects (B, num_boxes_per_image, 4).
        boxes = boxes.unsqueeze(1)

    sparse, dense = model.prompt_encoder(points=None, boxes=boxes, masks=None)
    logits, _ = model.mask_decoder(
        image_embeddings=embedding,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,  # single best mask per prompt
    )

    # Probabilities -> upsample to the original resolution -> hard threshold.
    probs = torch.sigmoid(logits)
    probs = torch.nn.functional.interpolate(
        probs, size=(H, W), mode="bilinear", align_corners=False
    )
    mask = probs.squeeze().cpu().numpy()
    return (mask > 0.5).astype(np.uint8)