# Spaces: Sleeping
import importlib.util
import os
import subprocess
import sys
import warnings

# Install ultralytics only when it is missing. Using `sys.executable -m pip`
# installs into the interpreter that runs this script (a bare `pip` on PATH may
# belong to another environment), and check_call surfaces a failed install
# instead of silently ignoring the exit status as os.system did.
if importlib.util.find_spec("ultralytics") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ultralytics"])

import cv2
import gradio as gr
import numpy as np
from PIL.Image import Image
from ultralytics import SAM

# Silence all warnings globally (e.g. library deprecation noise in the logs).
warnings.filterwarnings("ignore")
class SAMModel:
    """Gradio demo around an Ultralytics SAM (MobileSAM) model.

    Offers two prompt modes, a single point and a bounding box. Each returns
    the model's plotted prediction as an RGB numpy array.
    """

    def __init__(self, model_path: str = 'mobile_sam.pt'):
        """Load the SAM checkpoint.

        Args:
            model_path: Checkpoint path/name passed to ``ultralytics.SAM``.
                Defaults to ``'mobile_sam.pt'``.
        """
        self.model = SAM(model_path)

    @staticmethod
    def _plot_rgb(result) -> np.ndarray:
        """Render the first prediction in *result* and convert it to RGB.

        ``Results.plot()`` yields a BGR image (OpenCV channel order), so the
        channels are swapped before handing the array to Gradio.
        """
        plotted = result[0].plot()
        return cv2.cvtColor(np.asarray(plotted), cv2.COLOR_BGR2RGB)

    def mobilesam_point_predict(self, image, x, y) -> np.ndarray:
        """Segment the object indicated by a single point prompt.

        Args:
            image: Input image (a PIL image when called from the UI).
            x: X coordinate of the prompt point in image coordinates.
            y: Y coordinate of the prompt point in image coordinates.

        Returns:
            RGB numpy array with the prediction drawn on the image.

        Raises:
            ValueError: If no image was provided.
        """
        if image is None:
            raise ValueError("Please upload an input image.")
        # labels=[1] marks the point as a positive (foreground) prompt.
        result = self.model.predict(image, points=[x, y], labels=[1])
        return self._plot_rgb(result)

    def mobile_bbox_predict(self, image: Image, bbox: str) -> np.ndarray:
        """Segment the object inside a bounding-box prompt.

        Args:
            image: Input image (a PIL image when called from the UI).
            bbox: Box given as ``"x1, y1, x2, y2"``. Exactly four
                comma-separated integers; whitespace around each is allowed.

        Returns:
            RGB numpy array with the prediction drawn on the image.

        Raises:
            ValueError: If no image was provided, or if *bbox* is not exactly
                four comma-separated integers.
        """
        if image is None:
            raise ValueError("Please upload an input image.")
        parts = bbox.split(',')
        if len(parts) != 4:
            raise ValueError(
                f"Expected 4 comma-separated integers 'x1, y1, x2, y2', got {bbox!r}")
        try:
            # int() tolerates surrounding whitespace such as " 437".
            bbox_list = [int(part) for part in parts]
        except ValueError as err:
            raise ValueError(
                f"Bounding box values must be integers, got {bbox!r}") from err
        result = self.model.predict(image, bboxes=bbox_list)
        return self._plot_rgb(result)

    def launch(self):
        """Build the two-tab Gradio interface and launch it."""
        # gr.inputs / gr.outputs and `default=` are removed in Gradio 4; the
        # top-level components with `value=` work in both Gradio 3 and 4.
        with gr.Blocks() as app:
            gr.Markdown("# SAM Model Demo")
            with gr.Tabs():
                # Point prompt: segment the object under one (x, y) point.
                with gr.TabItem("point-predict"):
                    with gr.Column():
                        point_inputs = [
                            gr.Image(type='pil', label='Input Image'),
                            gr.Number(value=900, label='X Coordinate'),
                            gr.Number(value=370, label='Y Coordinate'),
                        ]
                        point_output = gr.Image(label='Output Image')
                        point_predict_button = gr.Button("inference")
                    # Run point-prompted segmentation when the button is clicked.
                    point_predict_button.click(self.mobilesam_point_predict,
                                               inputs=point_inputs,
                                               outputs=point_output)
                # Box prompt: segment the object inside an x1, y1, x2, y2 box.
                with gr.TabItem("bbox-predict"):
                    bbox_inputs = [
                        gr.Image(type='pil', label='Input Image'),
                        gr.Textbox(lines=1,
                                   label="Bounding Box (x1, y1, x2, y2)",
                                   value="439, 437, 524, 709"),
                    ]
                    bbox_output = gr.Image(label='Output Image')
                    bbox_predict_button = gr.Button("inference")
                    # Run box-prompted segmentation when the button is clicked.
                    bbox_predict_button.click(self.mobile_bbox_predict,
                                              inputs=bbox_inputs,
                                              outputs=bbox_output)
        app.launch()
if __name__ == '__main__':
    # Script entry point: load the model, then serve the demo UI.
    SAMModel().launch()