| import gradio as gr |
| import cv2 |
| import torch |
| import numpy as np |
| from PIL import Image |
| from torchvision import transforms |
| from segment_anything import SamAutomaticMaskGenerator, sam_model_registry |
| |
|
|
def load_model(model_type):
    """Build a SAM automatic mask generator for the requested model variant.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the SAM
            variant. The weights are read from the relative path
            ``sam_{model_type}_checkpoint.pth``.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    """
    model = sam_model_registry[model_type](
        checkpoint=f"sam_{model_type}_checkpoint.pth"
    )
    # Fall back to the CPU so the app still runs on machines without CUDA;
    # unconditionally requesting 'cuda' raises there.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)
|
|
def segment_and_classify(image, model_type):
    """Segment an image with SAM and return the first segment found.

    Args:
        image: PIL image to segment.
        model_type: SAM variant name, forwarded to ``load_model``.

    Returns:
        A PIL RGB image the same size as the input that keeps only the
        pixels covered by the first mask returned by the generator; every
        other pixel is black.

    Raises:
        ValueError: If no image was supplied or the generator produced no
            masks.
    """
    if image is None:
        raise ValueError("No image provided.")

    # The SAM mask generator expects an HWC uint8 array in RGB order, so
    # normalise the mode (drops alpha, expands greyscale) and feed it as-is.
    # Converting to BGR both fed the model swapped channels and produced an
    # output image with red and blue exchanged.
    image_rgb = np.array(image.convert("RGB"))

    mask_generator = load_model(model_type)
    masks = mask_generator.generate(image_rgb)
    if not masks:
        raise ValueError("No segments were found in the image.")

    # Only the first mask is returned, so there is no need to materialise
    # every segment. Broadcasting the H x W mask over the channel axis
    # zeroes the pixels outside the segment.
    mask = masks[0]['segmentation']
    segment = image_rgb * mask[:, :, None]
    return Image.fromarray(segment)
|
|
# Input widgets: the uploaded picture (handed to the callback as a PIL
# image) and the SAM variant to run.
_image_input = gr.inputs.Image(type="pil")
_model_choice = gr.inputs.Dropdown(
    ['vit_h', 'vit_b', 'vit_l'],
    label="Model Type",
)

# Output widget: the callback's PIL image result.
_segment_output = gr.outputs.Image(type="pil")

iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[_image_input, _model_choice],
    outputs=_segment_output,
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts.",
)

# Start the web UI as soon as the module is executed.
iface.launch()