"""Minimal MedSAM inference helpers.

Load a SAM ViT-B checkpoint, compute an image embedding once, then run
box-prompted mask decoding against that cached embedding.
"""

import numpy as np
import torch
from skimage import transform
from segment_anything import sam_model_registry

# Side length (pixels) that MedSAM's ViT image encoder expects, and the
# coordinate frame that box prompts must be expressed in.
MEDSAM_IMG_INPUT_SIZE = 1024


def load_model(checkpoint_path, device="cpu"):
    """Load a MedSAM (SAM ``vit_b``) checkpoint in eval mode.

    Args:
        checkpoint_path: Path to the SAM ViT-B checkpoint file.
        device: Torch device spec to place the model on. Defaults to
            ``"cpu"``, preserving the original behavior.

    Returns:
        Tuple ``(model, torch.device)``; the device is returned so callers
        can move their input tensors to the same place.
    """
    model = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
    torch_device = torch.device(device)
    model = model.to(torch_device)
    model.eval()
    return model, torch_device


@torch.no_grad()
def get_embedding(model, img_np, device):
    """Resize, normalize, and encode an image with the SAM image encoder.

    Args:
        model: Loaded SAM model (uses ``model.image_encoder``).
        img_np: Input image as an H x W x 3 numpy array — the
            ``permute(2, 0, 1)`` below requires a trailing channel axis of
            size 3 (assumption; confirm against callers).
        device: Torch device the input tensor is moved to.

    Returns:
        The image embedding tensor produced by ``model.image_encoder``.
    """
    # Cubic resize to the encoder's fixed input size.
    # NOTE(review): the uint8 cast is lossy if img_np is a float image in
    # [0, 1]; callers are presumably passing uint8-range data — confirm.
    img_resized = transform.resize(
        img_np,
        (MEDSAM_IMG_INPUT_SIZE, MEDSAM_IMG_INPUT_SIZE),
        order=3,
        preserve_range=True,
        anti_aliasing=True,
    ).astype(np.uint8)
    # Min-max normalize to [0, 1]; clip the denominator so a constant
    # image does not divide by zero.
    img_resized = (img_resized - img_resized.min()) / np.clip(
        img_resized.max() - img_resized.min(), 1e-8, None
    )
    # HWC -> NCHW float tensor for the encoder.
    img_tensor = (
        torch.tensor(img_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )
    return model.image_encoder(img_tensor)


@torch.no_grad()
def run(model, embedding, box_1024, H, W):
    """Decode a binary mask from a cached embedding and a box prompt.

    Args:
        model: Loaded SAM model (uses ``prompt_encoder`` and ``mask_decoder``).
        embedding: Image embedding from :func:`get_embedding`.
        box_1024: Box prompt(s), ``(B, 4)`` or ``(B, 1, 4)``, in the
            1024 x 1024 encoder coordinate frame.
        H: Output mask height (original image height).
        W: Output mask width (original image width).

    Returns:
        ``np.uint8`` binary mask of shape ``(H, W)`` (probability > 0.5).
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=embedding.device)
    if box_torch.ndim == 2:
        # Prompt encoder expects a per-box middle axis: (B, 1, 4).
        box_torch = box_torch[:, None, :]
    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = model.mask_decoder(
        image_embeddings=embedding,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    low_res_prob = torch.sigmoid(low_res_logits)
    # Upsample probabilities to the original image resolution, then threshold.
    full_res_prob = torch.nn.functional.interpolate(
        low_res_prob, size=(H, W), mode="bilinear", align_corners=False
    )
    full_res_prob = full_res_prob.squeeze().cpu().numpy()
    return (full_res_prob > 0.5).astype(np.uint8)