# NOTE: Hugging Face Spaces page residue ("Spaces: Sleeping") removed — it was
# scrape artifact text, not part of the Python module.
import numpy as np
import torch
from skimage import transform

from segment_anything import sam_model_registry

# SAM's ViT image encoder expects a fixed 1024x1024 RGB input.
MEDSAM_IMG_INPUT_SIZE = 1024
def load_model(checkpoint_path, device="cpu"):
    """Load a SAM ViT-B model from a checkpoint and switch it to eval mode.

    Args:
        checkpoint_path: Path to the ``vit_b`` SAM checkpoint file.
        device: Torch device spec to place the model on. Defaults to
            ``"cpu"``, preserving the original hard-coded behavior.

    Returns:
        Tuple ``(model, torch.device)`` — the device is returned so callers
        can move their input tensors to it.
    """
    model = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
    dev = torch.device(device)
    model = model.to(dev)
    model.eval()  # inference only; disables dropout/batch-norm updates
    return model, dev
def get_embedding(model, img_np, device):
    """Compute the SAM image-encoder embedding for a single image.

    Args:
        model: Loaded SAM model exposing ``image_encoder``.
        img_np: Image as a numpy array; assumed H x W x 3 uint8
            (channel-last) — TODO confirm against callers.
        device: Torch device the input tensor is moved to.

    Returns:
        The image embedding tensor produced by ``model.image_encoder``.
    """
    # Resize to the fixed 1024x1024 encoder input (bicubic, anti-aliased).
    img_1024 = transform.resize(
        img_np, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    # Min-max normalize to [0, 1]; the clip guards against division by zero
    # on a constant (flat) image.
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), 1e-8, None
    )
    # HWC -> 1 x C x H x W float tensor on the target device.
    img_tensor = (
        torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )
    # Inference only: no_grad avoids building an autograd graph and
    # holding activations in memory for no reason.
    with torch.no_grad():
        return model.image_encoder(img_tensor)
def run(model, embedding, box_1024, H, W):
    """Decode a binary mask from a precomputed embedding and a box prompt.

    Args:
        model: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.
        embedding: Image-encoder output for the 1024x1024 model input.
        box_1024: Box prompt(s) in 1024-scale coordinates, shape (B, 4).
        H: Original image height the mask is upsampled to.
        W: Original image width the mask is upsampled to.

    Returns:
        (H, W) uint8 numpy array: 1 inside the predicted mask, 0 outside.
    """
    box_torch = torch.as_tensor(
        box_1024, dtype=torch.float, device=embedding.device
    )
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

    # Inference only: no_grad skips autograd bookkeeping for both the
    # prompt encoder and the mask decoder.
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, _ = model.mask_decoder(
            image_embeddings=embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,  # single best mask per box
        )

    low_res_pred = torch.sigmoid(low_res_logits)
    # Upsample the low-resolution probability map to the original image size.
    low_res_pred = torch.nn.functional.interpolate(
        low_res_pred, size=(H, W), mode="bilinear", align_corners=False
    )
    # squeeze() drops all size-1 dims — assumes a single box (B == 1);
    # TODO confirm callers never pass a batch of boxes.
    low_res_pred = low_res_pred.squeeze().cpu().numpy()
    return (low_res_pred > 0.5).astype(np.uint8)