Spaces:
Runtime error
Runtime error
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
# Load model
# Official ViT-H SAM weights, as linked from the segment-anything README.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
checkpoint = "sam_vit_h_4b8939.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_h"

# The checkpoint is large and often not shipped next to the app. Fetch it on
# first start rather than letting sam_model_registry fail when it tries to open
# a missing file. download_url_to_file writes to a temp file and moves it into
# place, so an interrupted download does not leave a truncated checkpoint.
if not os.path.isfile(checkpoint):
    torch.hub.download_url_to_file(CHECKPOINT_URL, checkpoint, progress=True)

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
# Shared predictor; segment_image() calls set_image()/predict() on it per request.
predictor = SamPredictor(sam)
def segment_image(input_img):
    """Segment the object under the image centre with SAM.

    A single foreground point at the centre of the image is used as the prompt,
    and only the top-scoring mask is kept (``multimask_output=False``).

    Args:
        input_img: PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading an image.

    Returns:
        A single-channel PIL image where masked pixels are 255 and all other
        pixels are 0, or ``None`` if no image was provided.
    """
    if input_img is None:
        return None

    # SamPredictor.set_image defaults to image_format="RGB" and expects an HWC
    # uint8 array. PIL already gives RGB, so no BGR conversion is wanted.
    # convert("RGB") also normalises RGBA / grayscale / palette uploads to
    # three channels, which the shape unpack below and SAM both require.
    image = np.array(input_img.convert("RGB"))
    predictor.set_image(image)

    h, w = image.shape[:2]
    # Point prompts are (x, y) pixel coordinates; label 1 marks foreground.
    input_point = np.array([[w // 2, h // 2]])
    input_label = np.array([1])
    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    # Boolean mask -> 0/255 uint8 so it renders as a black/white image.
    mask = masks[0].astype(np.uint8) * 255
    return Image.fromarray(mask)
# UI
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Segment Anything Model",
    description="Upload an image and get a segmentation mask.",
)

# Start the web server only when this file is executed as a script, so the
# module (and segment_image) can be imported without blocking on launch().
if __name__ == "__main__":
    iface.launch()