krishanwalia30 commited on
Commit
87d1643
·
verified ·
1 Parent(s): d661221

Upload 7 files

Browse files
Files changed (7) hide show
  1. YoloV8_Sam.ipynb +0 -0
  2. app.py +143 -0
  3. detectObjects.py +18 -0
  4. packages.txt +1 -0
  5. requirements.txt +129 -0
  6. sam_vit_b_01ec64.pth +3 -0
  7. yolov8n.pt +3 -0
YoloV8_Sam.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
3
+ # from IPython.display import display, Image
4
+ import cv2
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from rembg import remove
8
+ from PIL import Image
9
+
10
+
11
+ # Content of detectObjects.py file
12
+ # import detectObjects
13
+ import ultralytics
14
+ from ultralytics import YOLO
15
+
16
+ model = YOLO('yolov8n.pt')
17
+ sam_checkpoint = "sam_vit_b_01ec64.pth"
18
+ model_type = "vit_b"
19
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
20
+ predictor = SamPredictor(sam)
21
+
22
+ def detected_objects(filename:str):
23
+ results = model.predict(source=filename, conf=0.25)
24
+
25
+ categories = results[0].names
26
+
27
+ dc = []
28
+ for i in range(len(results[0])):
29
+ cat = results[0].boxes[i].cls
30
+ dc.append(categories[int(cat)])
31
+
32
+ print(dc)
33
+ return results, dc
34
+
35
+ def show_mask(mask, ax, random_color=False):
36
+ if random_color:
37
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
38
+ else:
39
+ color = np.array([30/255, 144/255, 255/255, 0.6])
40
+ h, w = mask.shape[-2:]
41
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
42
+ ax.imshow(mask_image)
43
+
44
+ def show_points(coords, labels, ax, marker_size=375):
45
+ pos_points = coords[labels==1]
46
+ neg_points = coords[labels==0]
47
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
48
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
49
+
50
+ def show_box(box, ax):
51
+ x0, y0 = box[0], box[1]
52
+ w, h = box[2] - box[0], box[3] - box[1]
53
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
54
+
55
+ st.title('Extract Objects From Image')
56
+
57
+ uploaded_file = st.file_uploader('Upload an image')
58
+
59
+ if uploaded_file is not None:
60
+ # To read file as bytes:
61
+ bytes_data = uploaded_file.getvalue()
62
+ with open('uploaded_file.png','wb') as file:
63
+ file.write(uploaded_file.getvalue())
64
+
65
+ # Detect objects in the uploaded image
66
+ # results, dc = detectObjects.detected_objects('uploaded_file.png')
67
+ results, dc = detected_objects('uploaded_file.png')
68
+
69
+ st.write(dc)
70
+
71
+ option = st.selectbox("Which object would you like to extract?", tuple(dc))
72
+ # print(option)
73
+ index_of_the_choosen_detected_object = tuple(dc).index(option)
74
+
75
+ if st.button('Extract'):
76
+ for result in results:
77
+ boxes = result.boxes
78
+
79
+ bbox=boxes.xyxy.tolist()[index_of_the_choosen_detected_object]
80
+ # sam_checkpoint = "sam_vit_b_01ec64.pth"
81
+ # model_type = "vit_b"
82
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
83
+ # predictor = SamPredictor(sam)
84
+
85
+ image = cv2.cvtColor(cv2.imread('uploaded_file.png'), cv2.COLOR_BGR2RGB)
86
+ predictor.set_image(image)
87
+
88
+ input_box = np.array(bbox)
89
+
90
+ masks, _, _ = predictor.predict(
91
+ point_coords=None,
92
+ point_labels=None,
93
+ box=input_box[None, :],
94
+ multimask_output=False,
95
+ )
96
+
97
+ # plt.figure(figsize=(10, 10))
98
+ # st.image(image)
99
+ # plt.imshow(image)
100
+ # show_mask(masks[0], plt.gca())
101
+ # show_box(input_box, plt.gca())
102
+ # plt.axis('off')
103
+ # plt.show()
104
+
105
+ segmentation_mask = masks[0]
106
+ binary_mask = np.where(segmentation_mask > 0.5, 1, 0)
107
+
108
+ white_background = np.ones_like(image) * 255
109
+
110
+ new_image = white_background * (1 - binary_mask[..., np.newaxis]) + image * binary_mask[..., np.newaxis]
111
+
112
+
113
+ plt.imsave('extracted_image.jpg', new_image.astype(np.uint8))
114
+ # st.image('extracted_image.jpg')
115
+
116
+ # Store path of the image in the variable input_path
117
+ input_path = 'extracted_image.jpg'
118
+
119
+ # Store path of the output image in the variable output_path
120
+ output_path = 'finalExtracted.png'
121
+
122
+ # Processing the image
123
+ input = Image.open(input_path)
124
+
125
+ # Removing the background from the given Image
126
+ output = remove(input)
127
+
128
+ #Saving the image in the given path
129
+ output.save(output_path)
130
+ # st.image(output_path)
131
+
132
+ with open("finalExtracted.png", "rb") as file:
133
+ btn = st.download_button(
134
+ label="Download final image",
135
+ data=file,
136
+ file_name="finalExtracted.png",
137
+ mime="image/png",
138
+ )
139
+
140
+ # bbox=boxes.xyxy.tolist()[0]
141
+
142
+
143
+
detectObjects.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ultralytics
2
+ from ultralytics import YOLO
3
+
4
+ model = YOLO('yolov8n.pt')
5
+
6
+ def detected_objects(filename:str):
7
+ results = model.predict(source=filename, conf=0.25)
8
+
9
+ categories = results[0].names
10
+
11
+ dc = []
12
+ for i in range(len(results[0])):
13
+ cat = results[0].boxes[i].cls
14
+ dc.append(categories[int(cat)])
15
+
16
+ print(dc)
17
+ return results, dc
18
+
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libgl1
requirements.txt ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # altair==5.4.0
2
+ # asttokens==2.4.1
3
+ # attrs==24.2.0
4
+ # blinker==1.8.2
5
+ # cachetools==5.5.0
6
+ # certifi==2024.7.4
7
+ # charset-normalizer==3.3.2
8
+ # click==8.1.7
9
+ # coloredlogs==15.0.1
10
+ # comm==0.2.1
11
+ # contourpy==1.2.1
12
+ # cycler==0.12.1
13
+ # debugpy==1.8.0
14
+ # decorator==5.1.1
15
+ # # exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
16
+ # executing==2.0.1
17
+ # filelock==3.15.4
18
+ # flatbuffers==24.3.25
19
+ # fonttools==4.53.1
20
+ # fsspec==2024.6.1
21
+ # gitdb==4.0.11
22
+ # GitPython==3.1.43
23
+ # humanfriendly==10.0
24
+ # idna==3.8
25
+ # imageio==2.35.1
26
+ # # importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1724187233579/work
27
+ # intel-openmp==2021.4.0
28
+ # ipykernel==6.29.0
29
+ # ipython==8.20.0
30
+ # jedi==0.19.1
31
+ # Jinja2==3.1.4
32
+ # jsonschema==4.23.0
33
+ # jsonschema-specifications==2023.12.1
34
+ # jupyter_client==8.6.0
35
+ # jupyter_core==5.7.1
36
+ # kiwisolver==1.4.5
37
+ # lazy_loader==0.4
38
+ # llvmlite==0.43.0
39
+ # markdown-it-py==3.0.0
40
+ # MarkupSafe==2.1.5
41
+ # matplotlib==3.9.2
42
+ # matplotlib-inline==0.1.6
43
+ # mdurl==0.1.2
44
+ # mkl==2021.4.0
45
+ # mpmath==1.3.0
46
+ # narwhals==1.5.5
47
+ # nest-asyncio==1.5.9
48
+ # networkx==3.3
49
+ # numba==0.60.0
50
+ # numpy==1.26.4
51
+ # onnxruntime==1.19.0
52
+ # opencv-contrib-python-headless==4.10.0.84
53
+ # opencv-python==4.10.0.84
54
+ # opencv-python-headless==4.10.0.84
55
+ # packaging==23.2
56
+ # pandas==2.2.2
57
+ # parso==0.8.3
58
+ # # pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
59
+ # pillow==10.4.0
60
+ # # platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
61
+ # pooch==1.8.2
62
+ # prompt-toolkit==3.0.43
63
+ # protobuf==5.27.3
64
+ # psutil==5.9.7
65
+ # pure-eval==0.2.2
66
+ # py-cpuinfo==9.0.0
67
+ # pyarrow==17.0.0
68
+ # pydeck==0.9.1
69
+ # Pygments==2.17.2
70
+ # PyMatting==1.1.12
71
+ # pyparsing==3.1.3
72
+ # pyreadline3==3.4.1
73
+ # python-dateutil==2.8.2
74
+ # pytz==2024.1
75
+ # pywin32==306
76
+ # PyYAML==6.0.2
77
+ # pyzmq==25.1.2
78
+ # referencing==0.35.1
79
+ # rembg==2.0.58
80
+ # requests==2.32.3
81
+ # rich==13.7.1
82
+ # rpds-py==0.20.0
83
+ # scikit-image==0.24.0
84
+ # scipy==1.14.1
85
+ # seaborn==0.13.2
86
+ # segment_anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
87
+ # six==1.16.0
88
+ # smmap==5.0.1
89
+ # stack-data==0.6.3
90
+ # streamlit==1.37.1
91
+ # sympy==1.13.2
92
+ # tbb==2021.13.1
93
+ # tenacity==8.5.0
94
+ # tifffile==2024.8.10
95
+ # toml==0.10.2
96
+ # torch==2.3.0
97
+ # torchvision==0.18.0
98
+ # tornado==6.4
99
+ # tqdm==4.66.5
100
+ # traitlets==5.14.1
101
+ # # typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
102
+ # tzdata==2024.1
103
+ # ultralytics==8.2.81
104
+ # ultralytics-thop==2.0.5
105
+ # urllib3==2.2.2
106
+ # watchdog==4.0.2
107
+ # wcwidth==0.2.13
108
+ # wget==3.2
109
+ # # zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1723591248676/work
110
+
111
+
112
+
113
+ # New Requirements.txt file for Deployment
114
+
115
+ matplotlib
116
+
117
+ # opencv-python-headless
118
+ # opencv-contrib-python-headless==4.10.0.84
119
+ opencv-contrib-python-headless
120
+ # opencv-python==4.10.0.84
121
+ # opencv-python-headless==4.10.0.84
122
+
123
+ segment_anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
124
+
125
+ rembg
126
+
127
+ streamlit
128
+
129
+ ultralytics==8.2.81
sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
yolov8n.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f59b3d833e2ff32e194b5bb8e08d211dc7c5bdf144b90d2c8412c47ccfc83b36
3
+ size 6549796