Spaces:
Runtime error
Runtime error
| import types | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from distinctipy import distinctipy | |
| from segment_anything import (SamAutomaticMaskGenerator, SamPredictor, | |
| sam_model_registry) | |
| from torch.nn import functional as F | |
def get_color(n_colors=200):
    """Return a palette of visually distinct colours.

    Thin wrapper around ``distinctipy.get_colors``.

    Args:
        n_colors: Number of colours to generate. Defaults to 200, the
            previously hard-coded count, so existing callers are unaffected.

    Returns:
        Whatever ``distinctipy.get_colors(n_colors)`` returns (a sequence of
        ``n_colors`` colours).
    """
    return distinctipy.get_colors(n_colors)
def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Min-max normalize pixel values to [0, 1] and pad to a square input.

    Intended to be bound onto a SAM model as its ``preprocess`` method, so
    ``self`` is the model and ``self.image_encoder.img_size`` is the target
    side length.

    Args:
        x: Image tensor laid out channel-first: the last two dimensions are
            treated as height and width (e.g. ``(B, C, H, W)``). Integer
            tensors (such as ``uint8`` pixels) are converted to float first.

    Returns:
        Floating-point tensor with the same leading dimensions and the last
        two zero-padded on the bottom/right to ``img_size``. Assumes
        ``H, W <= img_size``; a larger input would get negative padding,
        which ``F.pad`` treats as cropping.
    """
    if not torch.is_floating_point(x):
        # Do the normalization in float so integer inputs neither truncate
        # nor depend on integer/float type-promotion rules.
        x = x.float()
    # A single min/max over the whole tensor (all batch items and channels).
    # The clip keeps a constant image from dividing by zero.
    lo = x.min()
    value_range = torch.clip(x.max() - lo, min=1e-8, max=None)
    x = (x - lo) / value_range
    h, w = x.shape[-2:]
    padh = self.image_encoder.img_size - h
    padw = self.image_encoder.img_size - w
    # F.pad's tuple is (left, right, top, bottom) for the last two dims.
    return F.pad(x, (0, padw, 0, padh))
def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth', model_type='vit_b'):
    """Build a SAM model with MedSAM-style preprocessing and wrap it.

    Args:
        checkpoint: Path to the weights file handed to the model builder.
        model_type: Key into ``sam_model_registry``; it must match the
            architecture stored in ``checkpoint``. Defaults to ``'vit_b'``,
            the previously hard-coded value.

    Returns:
        ``(predictor, mask_generator)``: a ``SamPredictor`` and a
        ``SamAutomaticMaskGenerator`` that share the same model instance.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = sam_model_registry[model_type](checkpoint=checkpoint)
    # Replace the model's preprocess function: bind medsam_preprocess as an
    # instance method so it receives the model as ``self``.
    model.preprocess = types.MethodType(medsam_preprocess, model)
    # NOTE(review): overrides the model's mask_threshold; confirm 0.5 is the
    # intended cut-off for the model's mask output.
    model.mask_threshold = 0.5
    model = model.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    predictor = SamPredictor(model)
    mask_generator = SamAutomaticMaskGenerator(model)
    return predictor, mask_generator
def show_everything(sorted_anns):
    """Render annotations from the automatic mask generator as an RGBA overlay.

    Args:
        sorted_anns: Sequence of annotation dicts. Only each dict's
            ``'segmentation'`` entry is used: a mask whose last two
            dimensions are ``(h, w)`` and which holds ``h * w`` elements.

    Returns:
        ``uint8`` array of shape ``(h, w, 4)``. Each mask is painted with a
        random RGB colour (drawn from NumPy's global RNG) and alpha 0.6. Where
        masks overlap, their contributions are summed and saturated at 255
        rather than overflowing. An empty input returns ``np.array([])``.
    """
    if len(sorted_anns) == 0:
        return np.array([])
    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        m = ann['segmentation']
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        overlay += m.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # Overlapping masks push channels above 1.0. Clip so the uint8 cast
    # saturates instead of wrapping around (e.g. alpha 1.2 -> 306 -> 50).
    overlay = np.clip(overlay, 0.0, 1.0) * 255
    return overlay.astype(np.uint8)
def show_click(masks, colors):
    """Render prompt-based masks as a single RGBA overlay.

    Args:
        masks: Sequence of masks. Each non-empty mask must hold ``h * w``
            elements, e.g. ``(h, w)`` or ``(1, h, w)``. Any non-zero value
            counts as inside. Empty arrays are skipped.
        colors: Sequence of colours paired with ``masks`` by position. Each
            is reshaped to ``(1, 1, -1)`` and broadcast against the 4 output
            channels, so it must have 1 or 4 elements. Presumably RGBA floats
            in [0, 1]; TODO confirm with callers.

    Returns:
        ``uint8`` array of shape ``(h, w, 4)``, where ``h, w`` come from the
        first non-empty mask. Overlapping layers are summed and saturated at
        255 instead of wrapping around. If no non-empty mask is given,
        returns ``np.array([])``.
    """
    first = next((m for m in masks if np.size(m) > 0), None)
    if first is None:
        return np.array([])
    h, w = np.shape(first)[-2:]
    # Accumulate in a wide dtype so overlapping layers cannot overflow uint8.
    total = np.zeros((h, w, 4), dtype=np.int64)
    for mask, color in zip(masks, colors):
        if np.size(mask) == 0:
            continue
        binary = (np.asarray(mask).reshape(h, w, 1).astype(np.uint8) > 0).astype(np.uint8)
        layer = binary * 255 * np.asarray(color).reshape(1, 1, -1)
        # Each layer is truncated to uint8 on its own, before summing.
        total += layer.astype(np.uint8)
    return np.clip(total, 0, 255).astype(np.uint8)
def model_predict_masks_click(model, input_points, input_labels):
    """Predict a single mask from click prompts.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            multimask_output=...)`` that returns a 3-tuple whose first item is
            the masks. Presumably a ``SamPredictor`` with an image already
            set; verify against callers.
        input_points: Click coordinates, one entry per point, as a list or
            array. ``None`` or an empty sequence/array means "no clicks".
        input_labels: One label per click, passed through as
            ``point_labels``.

    Returns:
        The masks returned by ``model.predict`` with
        ``multimask_output=False``, or ``np.array([])`` when there are no
        clicks.
    """
    # len() works for lists and ndarrays alike; ``== []`` is ambiguous (or
    # raises) when input_points is a NumPy array.
    if input_points is None or len(input_points) == 0:
        return np.array([])
    masks, _, _ = model.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks
def model_predict_masks_box(model, center_point, center_label, input_box):
    """Predict one mask per box prompt and merge them.

    Prompts are processed for ``i in range(len(center_label))``.
    ``center_point`` and ``input_box`` are indexed by the same ``i``, so they
    must be at least as long as ``center_label``.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            box=..., multimask_output=...)`` that returns a 3-tuple whose
            first item is the mask. Presumably a ``SamPredictor`` with an
            image set; verify against callers.
        center_point: Per-prompt point. An empty entry (``[]``, an empty
            array, or ``None``) skips that prompt.
        center_label: Per-prompt label(s), converted with ``np.array``.
        input_box: Per-prompt box, converted with ``np.array``.

    Returns:
        The predicted masks combined with ``+``, which is a logical OR for
        boolean masks. Returns ``np.array([])`` when every prompt was
        skipped.
    """
    merged = None
    for i, label in enumerate(center_label):
        point = center_point[i]
        # len() works for lists and arrays; ``== []`` breaks on ndarrays.
        if point is None or len(point) == 0:
            continue
        mask, _, _ = model.predict(
            point_coords=np.array([point]),
            point_labels=np.array(label),
            box=np.array(input_box[i]),
            multimask_output=False,
        )
        # Seed explicitly instead of relying on ``np.array([]) + mask``
        # raising, which a bare ``except`` used to catch.
        merged = mask if merged is None else merged + mask
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return np.array([]) if merged is None else merged
def model_predict_masks_everything(mask_generator, image):
    """Segment the whole image with an automatic mask generator.

    Args:
        mask_generator: Object exposing ``generate(image)``; presumably a
            ``SamAutomaticMaskGenerator``.
        image: Passed unchanged to ``mask_generator.generate``.

    Returns:
        Exactly what ``mask_generator.generate(image)`` returned.
    """
    generated = mask_generator.generate(image)
    gpu_present = torch.cuda.is_available()
    if gpu_present:
        # Release cached CUDA memory once generation has finished.
        torch.cuda.empty_cache()
    return generated