"""Gradio demo: segment the object at the centre of an uploaded image with SAM."""

import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2  # noqa: F401  -- no longer needed by this script; kept so the file's import set is unchanged
from segment_anything import sam_model_registry, SamPredictor

# Load model
checkpoint = "sam_vit_h_4b8939.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device)
predictor = SamPredictor(sam)


def segment_image(input_img):
    """Return a binary mask for the object under the centre pixel of the image.

    A single positive (foreground) point prompt is placed at the image centre
    and the predictor is asked for one mask.

    Args:
        input_img: The uploaded picture as a ``PIL.Image.Image``. Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        A single-channel ``PIL.Image.Image`` of the same width/height as the
        input, where mask pixels are 255 and all other pixels are 0.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    # Normalise grayscale / palette / RGBA uploads to 3-channel RGB so the
    # array is always H x W x 3.
    image = np.array(input_img.convert("RGB"))

    # SamPredictor.set_image defaults to image_format="RGB", so the RGB array
    # is passed as-is. (Converting to BGR first would swap the red and blue
    # channels the model sees.)
    predictor.set_image(image)

    h, w = image.shape[:2]
    # Point prompts are (x, y) pixel coordinates; label 1 marks foreground.
    input_point = np.array([[w // 2, h // 2]])
    input_label = np.array([1])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # multimask_output=False yields exactly one mask; scale it to 0/255 so it
    # renders as a visible black-and-white image.
    mask = masks[0].astype(np.uint8) * 255
    return Image.fromarray(mask)


# UI
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Segment Anything Model",
    description="Upload an image and get a segmentation mask.",
)

if __name__ == "__main__":
    iface.launch()