Files changed (1)
  1. app.py +0 -233
app.py DELETED
@@ -1,233 +0,0 @@
- import os
- import urllib.request
- from functools import lru_cache
- from random import randint
- from typing import Any, Callable, Dict, List, Tuple
-
- import clip
- import cv2
- import gradio as gr
- import numpy as np
- import PIL.Image
- import torch
- from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
-
- CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
- CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
- CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
- MODEL_TYPE = "default"
- MAX_WIDTH = MAX_HEIGHT = 1024
- TOP_K_OBJ = 100
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
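- # Download the SAM ViT-H checkpoint into ~/.cache/SAM on first use; lru_cache
- # keeps a single generator instance alive across Gradio calls.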
- @lru_cache
- def load_mask_generator() -> SamAutomaticMaskGenerator:
-     if not os.path.exists(CHECKPOINT_PATH):
-         os.makedirs(CHECKPOINT_PATH)
-     checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
-     if not os.path.exists(checkpoint):
-         urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
-     sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(device)
-     mask_generator = SamAutomaticMaskGenerator(sam)
-     return mask_generator
-
-
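- # Load CLIP once and reuse it; returns the model together with its image
- # preprocessing transform.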
- @lru_cache
- def load_clip(
-     name: str = "ViT-B/32",
- ) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
-     model, preprocess = clip.load(name, device=device)
-     return model.to(device), preprocess
-
-
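- # Downscale so the longer side is at most 1024 px, preserving the aspect
- # ratio; smaller images pass through unchanged.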
- def adjust_image_size(image: np.ndarray) -> np.ndarray:
-     height, width = image.shape[:2]
-     if height > width:
-         if height > MAX_HEIGHT:
-             height, width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
-     else:
-         if width > MAX_WIDTH:
-             height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
-     image = cv2.resize(image, (width, height))
-     return image
-
-
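- # Score a crop against the [query, background] prompts with CLIP: softmax
- # over the two prompts, returning the probability assigned to the query.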
- @torch.no_grad()
- def get_score(crop: PIL.Image.Image, texts: List[str]) -> torch.Tensor:
-     model, preprocess = load_clip()
-     preprocessed = preprocess(crop).unsqueeze(0).to(device)
-     tokens = clip.tokenize(texts).to(device)
-     logits_per_image, _ = model(preprocessed, tokens)
-     similarity = logits_per_image.softmax(-1).cpu()
-     return similarity[0, 0]
-
-
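- # Cut a single mask out of the image: zero the background, crop to the
- # bounding box, and pad to a square so the object survives CLIP's
- # resize-and-center-crop preprocessing undistorted.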
- def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
-     x, y, w, h = mask["bbox"]
-     masked = image * np.expand_dims(mask["segmentation"], -1)
-     crop = masked[y : y + h, x : x + w]
-     if h > w:
-         top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
-     else:
-         top, bottom, left, right = (w - h) // 2, (w - h) // 2, 0, 0
-     # padding
-     crop = cv2.copyMakeBorder(
-         crop,
-         top,
-         bottom,
-         left,
-         right,
-         cv2.BORDER_CONSTANT,
-         value=(0, 0, 0),
-     )
-     crop = PIL.Image.fromarray(crop)
-     return crop
-
-
- def get_texts(query: str) -> List[str]:
-     return [f"a picture of {query}", "a picture of background"]
-
-
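- # Keep at most TOP_K_OBJ of the largest masks, then drop any that fall below
- # the SAM quality thresholds or, when a query is given, below the CLIP
- # similarity threshold.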
- def filter_masks(
-     image: np.ndarray,
-     masks: List[Dict[str, Any]],
-     predicted_iou_threshold: float,
-     stability_score_threshold: float,
-     query: str,
-     clip_threshold: float,
- ) -> List[Dict[str, Any]]:
-     filtered_masks: List[Dict[str, Any]] = []
-
-     for mask in sorted(masks, key=lambda mask: mask["area"])[-TOP_K_OBJ:]:
-         if (
-             mask["predicted_iou"] < predicted_iou_threshold
-             or mask["stability_score"] < stability_score_threshold
-             or image.shape[:2] != mask["segmentation"].shape[:2]
-             or (
-                 query
-                 and get_score(crop_image(image, mask), get_texts(query))
-                 < clip_threshold
-             )
-         ):
-             continue
-
-         filtered_masks.append(mask)
-
-     return filtered_masks
-
-
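- # Blend each mask into the image in a random bright color and outline it with
- # a contour; the array is RGB at this point, so (0, 0, 255) renders as blue.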
- def draw_masks(
-     image: np.ndarray, masks: List[Dict[str, Any]], alpha: float = 0.7
- ) -> np.ndarray:
-     for mask in masks:
-         color = [randint(127, 255) for _ in range(3)]
-
-         # draw mask overlay
-         colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
-         colored_mask = np.moveaxis(colored_mask, 0, -1)
-         masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
-         image_overlay = masked.filled()
-         image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
-
-         # draw contour
-         contours, _ = cv2.findContours(
-             np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
-         )
-         cv2.drawContours(image, contours, -1, (0, 0, 255), 2)
-     return image
-
-
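- # Gradio callback: run SAM over the (downscaled) image, filter the masks, and
- # return the image with the surviving masks drawn on top.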
- def segment(
-     predicted_iou_threshold: float,
-     stability_score_threshold: float,
-     clip_threshold: float,
-     image_path: str,
-     query: str,
- ) -> PIL.Image.Image:
-     mask_generator = load_mask_generator()
-     image = cv2.imread(image_path, cv2.IMREAD_COLOR)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     # reduce the size to save gpu memory
-     image = adjust_image_size(image)
-     masks = mask_generator.generate(image)
-     masks = filter_masks(
-         image,
-         masks,
-         predicted_iou_threshold,
-         stability_score_threshold,
-         query,
-         clip_threshold,
-     )
-     image = draw_masks(image, masks)
-     image = PIL.Image.fromarray(image)
-     return image
-
-
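- # Gradio UI: three threshold sliders, an image upload, and a free-text query;
- # the examples preload sample images with tuned thresholds.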
- demo = gr.Interface(
-     fn=segment,
-     inputs=[
-         gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
-         gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
-         gr.Slider(0, 1, value=0.85, label="clip_threshold"),
-         gr.Image(type="filepath"),
-         "text",
-     ],
-     outputs="image",
-     allow_flagging="never",
-     title="Segment Anything with CLIP",
-     examples=[
-         [
-             0.9,
-             0.8,
-             0.99,
-             os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
-             "dog",
-         ],
-         [
-             0.9,
-             0.8,
-             0.75,
-             os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
-             "building",
-         ],
-         [
-             0.9,
-             0.8,
-             0.998,
-             os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
-             "strawberry",
-         ],
-         [
-             0.9,
-             0.8,
-             0.75,
-             os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
-             "horse",
-         ],
-         [
-             0.9,
-             0.8,
-             0.99,
-             os.path.join(os.path.dirname(__file__), "examples/bears.jpg"),
-             "bear",
-         ],
-         [
-             0.9,
-             0.8,
-             0.99,
-             os.path.join(os.path.dirname(__file__), "examples/cats.jpg"),
-             "cat",
-         ],
-         [
-             0.9,
-             0.8,
-             0.99,
-             os.path.join(os.path.dirname(__file__), "examples/fish.jpg"),
-             "fish",
-         ],
-     ],
- )
-
- if __name__ == "__main__":
-     demo.launch(share=True)