Spaces:
Sleeping
Sleeping
| from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
class SegmentAnything:
    """Wrapper around a Segment Anything (SAM) model.

    The model is built once in ``__init__``; ``predict`` runs prompt-based
    segmentation and ``generate`` runs automatic whole-image mask generation.
    """

    def __init__(self, sam_checkpoint='checkpoint/sam_vit_h_4b8939.pth',
                 model_type='vit_h', device=None):
        """Load the SAM weights and place the model on a device.

        Args:
            sam_checkpoint: Path to the SAM checkpoint file. Defaults to the
                path this class has always loaded.
            model_type: Key into ``sam_model_registry``; it must match the
                architecture of ``sam_checkpoint`` (the default checkpoint is
                paired with 'vit_h').
            device: Torch device to move the model to. When ``None``, the
                model is moved to 'cuda' if CUDA is available and is otherwise
                left where ``sam_model_registry`` created it.
        """
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        if device is None and torch.cuda.is_available():
            device = 'cuda'
        if device is not None:
            sam.to(device=device)
        self.sam = sam

    @staticmethod
    def _to_rgb_array(image):
        """Convert *image* to the HxWx3 uint8 array SAM's predictors expect.

        ``image`` may be anything ``np.array`` accepts (a PIL image or an
        ndarray). 2-D (grayscale) input is replicated into three channels and
        a fourth (alpha) channel is dropped; three-channel input passes
        through unchanged. Any other shape is returned as-is for SAM to
        reject. Palette-mode ('P') PIL images yield palette indices, not
        colors, so convert those to 'RGB' before calling.
        """
        array = np.array(image, dtype=np.uint8)
        if array.ndim == 2:
            # Grayscale: repeat the single channel as R, G and B.
            array = np.stack([array, array, array], axis=-1)
        elif array.ndim == 3 and array.shape[2] == 4:
            # RGBA: SAM only accepts 3 channels, so discard alpha.
            array = np.ascontiguousarray(array[..., :3])
        return array

    def predict(self, image, point_coords, point_labels, box=None):
        """Segment *image* from point and/or box prompts.

        A fresh ``SamPredictor`` is created on every call, so the image
        embedding is recomputed each time and no per-image predictor state is
        shared between calls.

        Args:
            image: Image to segment; see ``_to_rgb_array`` for accepted forms.
            point_coords: Point prompts, forwarded unchanged to
                ``SamPredictor.predict``.
            point_labels: Labels for ``point_coords``, forwarded unchanged.
            box: Optional box prompt, forwarded unchanged.

        Returns:
            The result of ``SamPredictor.predict``, unchanged.
        """
        predictor = SamPredictor(self.sam)
        predictor.set_image(self._to_rgb_array(image))
        return predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)

    def generate(self, image):
        """Run SAM's automatic mask generator over the whole of *image*.

        Args:
            image: Image to segment; see ``_to_rgb_array`` for accepted forms.

        Returns:
            The result of ``SamAutomaticMaskGenerator.generate``, unchanged.
        """
        mask_generator = SamAutomaticMaskGenerator(self.sam)
        return mask_generator.generate(self._to_rgb_array(image))
def makeMaskImage(mask, color):
    """Render a 2-D mask as an RGBA overlay on a transparent canvas.

    The mask is indexed as ``mask[x, y]`` (first axis = image x / width,
    second axis = image y / height), so the returned image has size
    ``mask.shape``. Pixels where the mask is truthy are set to *color*;
    every other pixel is transparent black ``(0, 0, 0, 0)``.

    NOTE(review): masks from SamPredictor are presumably laid out (H, W);
    such a mask must be transposed (``mask.T``) before being passed here, or
    the overlay comes out transposed -- confirm against the callers.

    Args:
        mask: 2-D array-like; non-zero / truthy entries are painted.
        color: Any pixel value ``Image.putpixel`` accepts for an 'RGBA' image,
            e.g. an ``(r, g, b, a)`` tuple.

    Returns:
        A new 'RGBA' ``PIL.Image.Image``.

    Raises:
        ValueError: If *mask* is not 2-D.
    """
    xy_mask = np.asarray(mask).astype(bool)
    if xy_mask.ndim != 2:
        raise ValueError('mask must be 2-D, got shape %r' % (xy_mask.shape,))
    width, height = xy_mask.shape
    if not xy_mask.any():
        # Nothing to paint: return the blank canvas Image.new produces.
        return Image.new('RGBA', (width, height))
    # Let PIL interpret *color* exactly as putpixel would, by painting one
    # scratch pixel and reading the resulting RGBA value back.
    swatch = Image.new('RGBA', (1, 1))
    swatch.putpixel((0, 0), color)
    # Image arrays are row-major (y, x), so transpose the (x, y) mask and
    # paint every selected pixel in one vectorised assignment rather than
    # one Python-level putpixel call per pixel.
    pixels = np.zeros((height, width, 4), dtype=np.uint8)
    pixels[xy_mask.T] = swatch.getpixel((0, 0))
    return Image.fromarray(pixels)
def makeNewImage(image, maskImage):
    """Cut *image* out through the alpha channel of *maskImage*.

    Pixels where *maskImage* has non-zero alpha are copied from *image*;
    every other pixel of the result is transparent black ``(0, 0, 0, 0)``.
    Any partial alpha is treated as fully opaque, so the cut-out has hard
    edges.

    Args:
        image: Source ``PIL.Image.Image``; must be the same size as
            *maskImage*, otherwise ``Image.paste`` raises ``ValueError``.
        maskImage: Four-band ``PIL.Image.Image`` (such as the output of
            ``makeMaskImage``) whose fourth band is used as alpha. It is not
            modified.

    Returns:
        A new 'RGBA' ``PIL.Image.Image`` the size of *image*.
    """
    newImage = Image.new('RGBA', image.size)
    # Binarise the alpha band with a 256-entry lookup table instead of a
    # getpixel/putpixel round trip per pixel.
    alpha = maskImage.getchannel(3).point(lambda a: 255 if a > 0 else 0)
    newImage.paste(image, (0, 0), alpha)
    return newImage