isLinXu commited on
Commit
4a0fdfe
·
1 Parent(s): 3e0ac75

update app

Browse files
Files changed (2) hide show
  1. app.py +80 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib.util
import os
import subprocess
import sys
import warnings

# Bootstrap fallback: ultralytics is already pinned in requirements.txt, so
# only install it at startup when the import would fail (e.g. the Space was
# started without installing requirements). Use the current interpreter's
# pip via an argument list (no shell) instead of os.system with a shell
# string.
if importlib.util.find_spec("ultralytics") is None:
    subprocess.run([sys.executable, "-m", "pip", "install", "ultralytics"],
                   check=False)

import cv2
import gradio as gr
import numpy as np
from PIL.Image import Image
from ultralytics import SAM

# Silence third-party warning chatter (deprecation notices, etc.) so the
# Space log stays readable.
warnings.filterwarnings("ignore")
12
+
13
class SAMModel:
    """Gradio demo around Ultralytics MobileSAM: point- and box-prompted
    image segmentation."""

    def __init__(self):
        # Weights file is expected to sit next to app.py in the Space repo;
        # ultralytics will download it on first use if absent.
        model_path = 'mobile_sam.pt'
        self.model = SAM(model_path)

    def mobilesam_point_predict(self, image, x, y):
        """Segment ``image`` from a single foreground point prompt.

        Args:
            image: PIL image coming from the Gradio Image component.
            x: point x coordinate in pixels (gr.Number delivers a float).
            y: point y coordinate in pixels.

        Returns:
            np.ndarray: annotated result image in RGB order for display.
        """
        # labels=[1] marks the point as foreground.
        result = self.model.predict(image, points=[x, y], labels=[1])
        plotted = result[0].plot()
        # Ultralytics plots in BGR (OpenCV convention); convert for Gradio.
        plotted = cv2.cvtColor(np.array(plotted), cv2.COLOR_BGR2RGB)
        return plotted

    def mobile_bbox_predict(self, image: Image, bbox: str) -> np.ndarray:
        """Segment ``image`` from a box prompt given as ``"x1, y1, x2, y2"``.

        Args:
            image: PIL image coming from the Gradio Image component.
            bbox: comma-separated box corners; int() tolerates surrounding
                whitespace, so "a, b" splits cleanly.

        Returns:
            np.ndarray: annotated result image in RGB order for display.
        """
        bbox_list = list(map(int, bbox.split(',')))
        result = self.model.predict(image, bboxes=bbox_list)
        plotted = result[0].plot()
        # Same BGR -> RGB conversion as the point-prompt path.
        plotted = cv2.cvtColor(np.array(plotted), cv2.COLOR_BGR2RGB)
        return plotted

    def launch(self):
        """Build the Gradio UI and launch it with a public share link."""
        with gr.Blocks() as app:
            # Header
            gr.Markdown("# SAM Model Demo")

            with gr.Tabs():
                # Point-prompt tab.
                # NOTE: gr.inputs.* / gr.outputs.* and the default= kwarg are
                # deprecated in Gradio 3.x (removed in 4.x); use the top-level
                # components with value= instead.
                with gr.TabItem("point-predict"):
                    with gr.Column():
                        point_inputs = [
                            gr.Image(type='pil', label='Input Image'),
                            gr.Number(value=900, label='X Coordinate'),
                            gr.Number(value=370, label='Y Coordinate'),
                        ]
                        point_output = gr.Image(type='pil', label='Output Image')
                        point_predict_button = gr.Button("inference")

                        # Run point-prompted segmentation on click.
                        point_predict_button.click(self.mobilesam_point_predict,
                                                   inputs=point_inputs,
                                                   outputs=point_output)

                # Box-prompt tab.
                with gr.TabItem("bbox-predict"):
                    image_input = gr.Image(type='pil')
                    text_input = gr.Textbox(
                        lines=1,
                        label="Bounding Box (x1, y1, x2, y2)",
                        value="439, 437, 524, 709",
                    )
                    image_output = gr.Image(type='pil')
                    # Was misleadingly reusing the name point_predict_button.
                    bbox_predict_button = gr.Button("inference")

                    # Run box-prompted segmentation on click.
                    bbox_predict_button.click(self.mobile_bbox_predict,
                                              inputs=[image_input, text_input],
                                              outputs=image_output)

        app.launch(share=True)
76
+
77
+
78
if __name__ == '__main__':
    # Build the demo wrapper and start the Gradio server.
    SAMModel().launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim