Spaces:
Sleeping
Sleeping
| """ | |
| Created By: ishwor subedi | |
| Date: 2024-07-10 | |
| """ | |
| import os.path | |
| import cv2 | |
| import numpy as np | |
| import requests | |
| import wget | |
| from PIL import Image, ImageOps | |
| from tqdm import tqdm | |
| from ultralytics import YOLO | |
| from segment_anything import SamPredictor, sam_model_registry | |
class Segmentation:
    """Thin wrapper around an Ultralytics ``YOLO`` segmentation model."""

    def __init__(self, model_path: str = "artifacts/segmentation/yolov8x-seg.pt"):
        """Load the YOLO weights.

        Args:
            model_path: Path of the weights file handed to ``YOLO``. The default
                is the location that used to be hard-coded here, so existing
                ``Segmentation()`` callers are unaffected.
        """
        self.segmentation_model = YOLO(model=model_path)

    def segment_image(self, image_path: str, show: bool = True):
        """Run the segmentation model on one image.

        Args:
            image_path: Forwarded unchanged as the model's first argument.
            show: Forwarded as the model's ``show`` keyword. Defaults to
                ``True`` to keep the previous behaviour; pass ``False`` to skip
                the on-screen preview (e.g. on a machine without a display).

        Returns:
            Whatever the underlying model call returns, unmodified.
        """
        return self.segmentation_model(image_path, show=show)
class SegmentAnything:
    """Point-prompted mask generation using a Segment Anything ``vit_l`` checkpoint."""

    # Directory in which the checkpoint is looked up and into which it is downloaded.
    MODEL_DIR = "artifacts/segmentation"
    # Base URL; the checkpoint file name is appended to it.
    MODEL_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything"
    # Bytes requested per streamed chunk while downloading.
    DOWNLOAD_CHUNK_SIZE = 1024 * 1024
    # Seconds to wait for the server before giving up instead of hanging forever.
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, device="cpu"):
        """Fetch the checkpoint if needed, build the model and its predictor.

        Args:
            device: Passed to ``.to(device)`` on the constructed model.
        """
        self.model_name = "sam_vit_l_0b3195.pth"
        self.model_download()
        # Derive the path from ``model_name`` so the downloaded file and the
        # loaded file can never drift apart.
        self.sam = sam_model_registry["vit_l"](checkpoint=self._checkpoint_path()).to(device)
        self.samPredictor = SamPredictor(self.sam)

    def _checkpoint_path(self):
        """Return the local path of the checkpoint named by ``self.model_name``."""
        return f"{self.MODEL_DIR}/{self.model_name}"

    def model_download(self):
        """Download the checkpoint unless it is already present locally.

        The file is streamed to a temporary ``.part`` path and only renamed to
        its final name once it has been fully received, so an interrupted run
        cannot leave a truncated checkpoint that a later run would accept.

        Raises:
            requests.HTTPError: The server answered with an error status.
            RuntimeError: Fewer or more bytes than announced were received.
        """
        checkpoint = self._checkpoint_path()
        if os.path.exists(checkpoint):
            print(f"{self.model_name} model already exists.")
            return

        print(f"Downloading {self.model_name} model...")
        # ``open`` does not create missing parent directories.
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        url = f"{self.MODEL_BASE_URL}/{self.model_name}"
        partial = f"{checkpoint}.part"

        try:
            with requests.get(url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                # Without this an HTML error page would be saved as the model.
                response.raise_for_status()
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                received = 0
                with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                        open(partial, 'wb') as file:
                    for data in response.iter_content(self.DOWNLOAD_CHUNK_SIZE):
                        file.write(data)
                        received += len(data)
                        progress_bar.update(len(data))

            # A zero total means the server sent no content-length; nothing to compare.
            if total_size_in_bytes != 0 and received != total_size_in_bytes:
                raise RuntimeError(
                    f"ERROR, something went wrong: received {received} of "
                    f"{total_size_in_bytes} bytes for {self.model_name}"
                )
            os.replace(partial, checkpoint)
        finally:
            # After a successful rename the partial file no longer exists, so
            # this only fires on failure.
            if os.path.exists(partial):
                os.remove(partial)

    def generate_mask(self, image, selected_points, deselected_points):
        """Predict a mask for one prompt and return it inverted.

        Args:
            image: Image handed to the predictor's ``set_image``.
            selected_points: The prompt; it is wrapped in a list before being
                converted to an array, so a single ``(x, y)`` pair becomes a
                one-row coordinate array. Every row is labelled ``1``.
            deselected_points: Currently ignored. It is accepted only so that
                existing callers keep working; it has no effect on the result.

        Returns:
            A PIL image: the first mask returned by the predictor, passed
            through ``ImageOps.invert``.
        """
        self.samPredictor.set_image(image)
        points = np.array([selected_points])
        labels = np.ones(points.shape[0])
        masks, _, _ = self.samPredictor.predict(
            point_coords=points,
            point_labels=labels,
        )
        first_mask = Image.fromarray(masks[0, :, :])
        return ImageOps.invert(first_mask)
if __name__ == '__main__':
    # Manual smoke test: prompt SAM with one point and show the source image.
    import sys

    # An optional first CLI argument overrides the developer's sample image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png"

    # cv2.imread returns None instead of raising when the file cannot be read;
    # check before paying for the model load.
    image = cv2.imread(image_path)
    if image is None:
        raise SystemExit(f"Could not read image: {image_path}")

    segment_anything = SegmentAnything()

    # OpenCV loads BGR; convert *before* prompting so the predictor receives
    # RGB, and keep the BGR original for cv2.imshow, which expects BGR.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = segment_anything.generate_mask(rgb_image, (20, 20),
                                          (20, 20))
    maskimage = np.array(mask)
    print(maskimage.shape)

    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()