import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator


class SegmentAnything:
    """Thin wrapper around a Segment Anything (SAM) model plus mask-image helpers."""

    DEFAULT_CHECKPOINT = 'checkpoint/sam_vit_h_4b8939.pth'
    DEFAULT_MODEL_TYPE = 'vit_h'

    def __init__(self, sam_checkpoint=DEFAULT_CHECKPOINT, model_type=DEFAULT_MODEL_TYPE, device=None):
        """Load a SAM model from a checkpoint.

        Args:
            sam_checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry`` (e.g. ``'vit_h'``); must
                match the architecture stored in ``sam_checkpoint``.
            device: Device to move the model to. When ``None`` (the default),
                the model is moved to ``'cuda'`` if CUDA is available and is
                otherwise left where the registry built it.
        """
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        if device is None and torch.cuda.is_available():
            device = 'cuda'
        if device is not None:
            sam.to(device=device)
        self.sam = sam

    @staticmethod
    def _to_rgb_array(image):
        """Return ``image`` as a uint8 numpy array, converting non-RGB PIL images to RGB first.

        SAM's image encoder only accepts 3-channel input, so RGBA / palette /
        grayscale PIL images would otherwise be rejected.
        """
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        return np.array(image, dtype=np.uint8)

    def predict(self, image, point_coords, point_labels, box=None):
        """Run prompted segmentation on ``image``.

        ``point_coords``, ``point_labels`` and ``box`` are forwarded unchanged
        to ``SamPredictor.predict``; see its documentation for their expected
        formats.

        Returns:
            Whatever ``SamPredictor.predict`` returns for these prompts.
        """
        predictor = SamPredictor(self.sam)
        # NOTE: the image embedding is recomputed on every call.
        predictor.set_image(self._to_rgb_array(image))
        return predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)

    def generate(self, image):
        """Run automatic (prompt-free) mask generation on ``image``.

        Returns:
            Whatever ``SamAutomaticMaskGenerator.generate`` returns.
        """
        mask_generator = SamAutomaticMaskGenerator(self.sam)
        return mask_generator.generate(self._to_rgb_array(image))

    @staticmethod
    def makeMaskImage(mask, color):
        """Render a 2-D mask as an RGBA image: ``color`` where the mask is truthy, transparent elsewhere.

        The mask is indexed as ``mask[x, y]``: the resulting image has size
        ``mask.shape`` (width = ``mask.shape[0]``, height = ``mask.shape[1]``)
        and pixel ``(x, y)`` is set when ``mask[x, y]`` is truthy.
        NOTE(review): SAM returns masks indexed ``[row, col]``; passing those
        directly produces a transposed image — confirm callers pass ``mask.T``.

        Args:
            mask: 2-D array-like; any truthy element marks a foreground pixel.
            color: Fill colour for foreground pixels (e.g. an RGBA tuple).

        Raises:
            ValueError: if ``mask`` is not 2-D.
        """
        mask = np.asarray(mask)
        if mask.ndim != 2:
            raise ValueError(f'mask must be 2-D, got shape {mask.shape}')
        width, height = mask.shape
        image = Image.new('RGBA', (width, height))
        if mask.size:
            # fromarray maps array[row, col] -> pixel (col, row), so transpose to
            # keep the mask[x, y] -> pixel (x, y) convention.
            alpha = Image.fromarray(np.ascontiguousarray(mask.astype(bool).T).astype(np.uint8) * 255)
            image.paste(color, (0, 0), alpha)
        return image

    @staticmethod
    def makeNewImage(image, maskImage):
        """Cut ``image`` out through ``maskImage``.

        Pixels where ``maskImage`` has non-zero alpha are copied from ``image``;
        all others are fully transparent black. ``maskImage`` must have an alpha
        band and the same size as ``image``.

        Returns:
            A new RGBA image of ``image.size``.
        """
        # Binarize the alpha band: any non-zero alpha becomes fully opaque.
        alpha = maskImage.getchannel('A').point(lambda a: 255 if a > 0 else 0)
        newImage = Image.new('RGBA', image.size)
        newImage.paste(image, (0, 0), alpha)
        return newImage