dheena committed
Commit 7a75e77 · 1 Parent(s): 58597bf
.vscode/launch.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Streamlit App",
+             "type": "debugpy",
+             "request": "launch",
+             "module": "streamlit",
+             "args": [
+                 "run",
+                 "${workspaceFolder}/src/streamlit_app.py"
+             ],
+             "console": "integratedTerminal",
+             "justMyCode": false
+         }
+     ]
+ }
+
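Note: a launch configuration of this shape starts Streamlit as a module under the debugpy debugger, equivalent to running `python -m streamlit run src/streamlit_app.py` from the workspace root; "justMyCode": false additionally lets breakpoints resolve inside library code such as Streamlit itself.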
ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
+ size 353976522
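Note: this is a Git LFS pointer file, not the checkpoint itself; a clone without LFS support contains only this stub, and `git lfs pull` fetches the actual weights (353976522 bytes, roughly 354 MB).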
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.21 kB).
 
src/__pycache__/segmentation.cpython-310.pyc ADDED
Binary file (5.96 kB).
 
src/image-segmentation.py ADDED
@@ -0,0 +1,555 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[30]:
+
+
+ import random
+ from dataclasses import dataclass
+ from typing import Any, List, Dict, Optional, Union, Tuple
+ import os
+
+ import cv2
+ import torch
+ import requests
+ import numpy as np
+ from PIL import Image
+ import clip
+ import plotly.express as px
+ from datetime import datetime
+ import matplotlib.pyplot as plt
+ import plotly.graph_objects as go
+ from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+
+ # In[2]:
+
+
+ @dataclass
+ class BoundingBox:
+     xmin: int
+     ymin: int
+     xmax: int
+     ymax: int
+
+     @property
+     def xyxy(self) -> List[float]:
+         return [self.xmin, self.ymin, self.xmax, self.ymax]
+
+ @dataclass
+ class DetectionResult:
+     score: float
+     label: str
+     box: BoundingBox
+     mask: Optional[np.ndarray] = None  # np.ndarray is the array type; np.array is a function
+
+     @classmethod
+     def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
+         return cls(score=detection_dict['score'],
+                    label=detection_dict['label'],
+                    box=BoundingBox(xmin=detection_dict['box']['xmin'],
+                                    ymin=detection_dict['box']['ymin'],
+                                    xmax=detection_dict['box']['xmax'],
+                                    ymax=detection_dict['box']['ymax']))
+
+
+ # In[3]:
+
+
+ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
+     # Convert PIL Image to OpenCV format
+     image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
+     image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)
+
+     # Iterate over detections and add bounding boxes and masks
+     for detection in detection_results:
+         label = detection.label
+         score = detection.score
+         box = detection.box
+         mask = detection.mask
+
+         # Sample a random color for each detection
+         color = np.random.randint(0, 256, size=3)
+
+         # Draw bounding box and label
+         cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color.tolist(), 2)
+         cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2)
+
+         # If a mask is available, draw its contours
+         if mask is not None:
+             # Convert mask to uint8
+             mask_uint8 = (mask * 255).astype(np.uint8)
+             contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+             cv2.drawContours(image_cv2, contours, -1, color.tolist(), 2)
+
+     return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
+
+ def plot_detections(
+     image: Union[Image.Image, np.ndarray],
+     detections: List[DetectionResult],
+     save_name: Optional[str] = None
+ ) -> None:
+     annotated_image = annotate(image, detections)
+     plt.imshow(annotated_image)
+     plt.axis('off')
+     if save_name:
+         plt.savefig(save_name, bbox_inches='tight')
+     plt.show()
+
+
+
+ # In[4]:
+
+
+ def random_named_css_colors(num_colors: int) -> List[str]:
+     """
+     Returns a list of randomly selected named CSS colors.
+
+     Args:
+     - num_colors (int): Number of random colors to generate.
+
+     Returns:
+     - list: List of randomly selected named CSS colors.
+     """
+     # List of named CSS colors
+     named_css_colors = [
+         'aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond',
+         'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
+         'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey',
+         'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
+         'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
+         'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite',
+         'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory',
+         'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow',
+         'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray',
+         'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine',
+         'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise',
+         'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive',
+         'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip',
+         'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown',
+         'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey',
+         'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white',
+         'whitesmoke', 'yellow', 'yellowgreen'
+     ]
+
+     # Sample random named CSS colors
+     return random.sample(named_css_colors, min(num_colors, len(named_css_colors)))
+
+ def plot_detections_plotly(
+     image: np.ndarray,
+     detections: List[DetectionResult],
+     class_colors: Optional[Dict[int, str]] = None
+ ) -> None:
+     # If class_colors is not provided, generate a random color per detection
+     if class_colors is None:
+         num_detections = len(detections)
+         colors = random_named_css_colors(num_detections)
+         class_colors = {i: colors[i] for i in range(num_detections)}
+
+     fig = px.imshow(image)
+
+     # Add bounding boxes
+     shapes = []
+     annotations = []
+     for idx, detection in enumerate(detections):
+         label = detection.label
+         box = detection.box
+         score = detection.score
+         mask = detection.mask
+
+         polygon = mask_to_polygon(mask)
+
+         fig.add_trace(go.Scatter(
+             x=[point[0] for point in polygon] + [polygon[0][0]],
+             y=[point[1] for point in polygon] + [polygon[0][1]],
+             mode='lines',
+             line=dict(color=class_colors[idx], width=2),
+             fill='toself',
+             name=f"{label}: {score:.2f}"
+         ))
+
+         xmin, ymin, xmax, ymax = box.xyxy
+         shape = [
+             dict(
+                 type="rect",
+                 xref="x", yref="y",
+                 x0=xmin, y0=ymin,
+                 x1=xmax, y1=ymax,
+                 line=dict(color=class_colors[idx])
+             )
+         ]
+         annotation = [
+             dict(
+                 x=(xmin + xmax) // 2, y=(ymin + ymax) // 2,
+                 xref="x", yref="y",
+                 text=f"{label}: {score:.2f}",
+             )
+         ]
+
+         shapes.append(shape)
+         annotations.append(annotation)
+
+     # Update layout
+     button_shapes = [dict(label="None", method="relayout", args=["shapes", []])]
+     button_shapes = button_shapes + [
+         dict(label=f"Detection {idx+1}", method="relayout", args=["shapes", shape]) for idx, shape in enumerate(shapes)
+     ]
+     button_shapes = button_shapes + [dict(label="All", method="relayout", args=["shapes", sum(shapes, [])])]
+
+     fig.update_layout(
+         xaxis=dict(visible=False),
+         yaxis=dict(visible=False),
+         # margin=dict(l=0, r=0, t=0, b=0),
+         showlegend=True,
+         updatemenus=[
+             dict(
+                 type="buttons",
+                 direction="up",
+                 buttons=button_shapes
+             )
+         ],
+         legend=dict(
+             orientation="h",
+             yanchor="bottom",
+             y=1.02,
+             xanchor="right",
+             x=1
+         )
+     )
+
+     # Show plot
+     fig.show()
+
+
+ # In[5]:
+
+
+ def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
+     # Find contours in the binary mask
+     contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     # Find the contour with the largest area
+     largest_contour = max(contours, key=cv2.contourArea)
+
+     # Extract the vertices of the contour
+     polygon = largest_contour.reshape(-1, 2).tolist()
+
+     return polygon
+
+ def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
+     """
+     Convert a polygon to a segmentation mask.
+
+     Args:
+     - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
+     - image_shape (tuple): Shape of the image (height, width) for the mask.
+
+     Returns:
+     - np.ndarray: Segmentation mask with the polygon filled.
+     """
+     # Create an empty mask
+     mask = np.zeros(image_shape, dtype=np.uint8)
+
+     # Convert polygon to an array of points
+     pts = np.array(polygon, dtype=np.int32)
+
+     # Fill the polygon with white color (255)
+     cv2.fillPoly(mask, [pts], color=(255,))
+
+     return mask
+
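+ # A minimal illustrative sketch (not part of the pipeline): the round trip
+ # below keeps only the largest connected region of a mask and refills it.
+ # toy = np.zeros((8, 8), dtype=np.uint8); toy[2:6, 2:6] = 1
+ # poly = mask_to_polygon(toy)                 # vertices of the largest contour
+ # refined = polygon_to_mask(poly, toy.shape)  # filled mask with values 0/255
+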
+ def load_image(image_str: str) -> Image.Image:
+     if image_str.startswith("http"):
+         image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
+     else:
+         image = Image.open(image_str).convert("RGB")
+
+     return image
+
+ def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
+     boxes = []
+     for result in results:
+         xyxy = result.box.xyxy
+         boxes.append(xyxy)
+
+     return [boxes]
+
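+ # Note: the extra nesting in [boxes] appears to match the SAM processor's
+ # input_boxes layout, which expects one list of boxes per image in the batch.
+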
+ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
+     masks = masks.cpu().float()
+     masks = masks.permute(0, 2, 3, 1)
+     masks = masks.mean(axis=-1)
+     masks = (masks > 0).int()
+     masks = masks.numpy().astype(np.uint8)
+     masks = list(masks)
+
+     if polygon_refinement:
+         for idx, mask in enumerate(masks):
+             shape = mask.shape
+             polygon = mask_to_polygon(mask)
+             mask = polygon_to_mask(polygon, shape)
+             masks[idx] = mask
+
+     return masks
+
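+ # The permute/mean/threshold sequence above collapses SAM's per-box mask
+ # proposals (typically three channels) into a single binary mask.
+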
+
+ # In[6]:
+
+
+ def detect(
+     image: Image.Image,
+     labels: List[str],
+     threshold: float = 0.3,
+     detector_id: Optional[str] = None
+ ) -> List[DetectionResult]:
+     """
+     Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
+     """
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
+     object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
+
+     labels = [label if label.endswith(".") else label + "." for label in labels]
+
+     results = object_detector(image, candidate_labels=labels, threshold=threshold)
+     results = [DetectionResult.from_dict(result) for result in results]
+
+     return results
+
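+ # Grounding DINO is prompted with free-text phrases; the normalization above
+ # appends a trailing period to each label, the prompt format the model expects.
+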
+ def segment(
+     image: Image.Image,
+     detection_results: List[DetectionResult],
+     polygon_refinement: bool = False,
+     segmenter_id: Optional[str] = None
+ ) -> List[DetectionResult]:
+     """
+     Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
+     """
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
+
+     segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
+     processor = AutoProcessor.from_pretrained(segmenter_id)
+
+     boxes = get_boxes(detection_results)
+     inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
+
+     outputs = segmentator(**inputs)
+     masks = processor.post_process_masks(
+         masks=outputs.pred_masks,
+         original_sizes=inputs.original_sizes,
+         reshaped_input_sizes=inputs.reshaped_input_sizes
+     )[0]
+
+     masks = refine_masks(masks, polygon_refinement)
+
+     for detection_result, mask in zip(detection_results, masks):
+         detection_result.mask = mask
+
+     return detection_results
+
+ def grounded_segmentation(
+     image: Union[Image.Image, str],
+     labels: List[str],
+     threshold: float = 0.3,
+     polygon_refinement: bool = False,
+     detector_id: Optional[str] = None,
+     segmenter_id: Optional[str] = None
+ ) -> Tuple[np.ndarray, List[DetectionResult]]:
+     if isinstance(image, str):
+         image = load_image(image)
+
+     detections = detect(image, labels, threshold, detector_id)
+     detections = segment(image, detections, polygon_refinement, segmenter_id)
+
+     return image, detections
+
+
+ # In[7]:
+
+
+ # save clipped images
+ def cut_image(image, mask, box):
+     ny_image = np.array(image)
+     cut = cv2.bitwise_and(ny_image, ny_image, mask=mask.astype(np.uint8) * 255)
+     x0, y0, x1, y1 = map(int, box.xyxy)
+     cropped = cut[y0:y1, x0:x1]
+     cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
+     return cropped_bgr
+
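+ # Note: cv2.bitwise_and keeps every pixel where the mask is non-zero, so the
+ # * 255 scaling is cosmetic; the masked image is then cropped to the detection
+ # box and converted to BGR for cv2.imwrite.
+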
+
+ # In[8]:
+
+
+ image_url = "/home/dheena/Downloads/fashion/images (1).jpeg"
+ labels = ["a dress"]
+ threshold = 0.3
+
+ detector_id = "IDEA-Research/grounding-dino-tiny"
+ segmenter_id = "facebook/sam-vit-base"
+
+
+ # In[9]:
+
+
+ # image, detections = grounded_segmentation(
+ #     image=image_url,
+ #     labels=labels,
+ #     threshold=threshold,
+ #     polygon_refinement=True,
+ #     detector_id=detector_id,
+ #     segmenter_id=segmenter_id
+ # )
+ # current = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
+ # cropped_image = cut_image(image, detections[0].mask, detections[0].box)
+ # cv2.imwrite("/home/dheena/Downloads/fashion/output/" + current, cropped_image)
+
+
+ # In[44]:
+
+
+ # plot_detections(np.array(image), detections, "test.png")
+
+
+ # In[60]:
+
+
+ # model imports
+ import faiss
+ import torch
+ import clip
+ from openai import OpenAI
+ from torch.utils.data import DataLoader
+
+ # helper imports
+ from tqdm import tqdm
+ import os
+ import numpy as np
+ from typing import List, Tuple
+
+ # visualization imports
+ from PIL import Image
+ from fastapi import FastAPI
+ from typing import List
+ import matplotlib.pyplot as plt
+
+ client = OpenAI()
+
+ # Set device
+ device = "cpu"
+ model, preprocess = clip.load("ViT-B/32", device=device)
+
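+ # clip.load("ViT-B/32") downloads the checkpoint to a local cache on first
+ # use; the ViT-B-32.pt file added in this commit is presumably that cached
+ # checkpoint, committed so the app can run without re-downloading.
+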
+ # # Directory path
+ # direc = '/home/dheena/Downloads/fashion/output/'
+
+ # def get_image(filepath: str) -> Image.Image:
+ #     """Safely load and convert an image file to RGB PIL format."""
+ #     try:
+ #         return Image.open(filepath).convert("RGB")
+ #     except Exception as e:
+ #         print(f"Failed to load {filepath}: {e}")
+ #         return None
+
+ # def get_all_images_from_dir(directory: str) -> List[Tuple[str, Image.Image]]:
+ #     """Load all supported images from a directory, with paths."""
+ #     supported_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
+ #     image_data = []
+
+ #     for root, _, files in os.walk(directory):
+ #         for file in files:
+ #             if file.lower().endswith(supported_exts):
+ #                 full_path = os.path.join(root, file)
+ #                 try:
+ #                     img = Image.open(full_path).convert("RGB")
+ #                     image_data.append((full_path, img))
+ #                 except Exception as e:
+ #                     print(f"Error loading {full_path}: {e}")
+ #     return image_data
+
+ def get_image_features(image: Image.Image) -> np.ndarray:
+     """Extract CLIP features from an image."""
+     image_input = preprocess(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         image_features = model.encode_image(image_input).float()
+     return image_features.cpu().numpy()
+
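+ # ViT-B/32 image embeddings are 512-dimensional, matching the index below.
+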
+ # FAISS setup
+ index = faiss.IndexFlatIP(512)
+ meta_data_store = []
+
+ def save_image_in_index(image_features: np.ndarray, metadata: dict):
+     """Normalize features and add to index."""
+     faiss.normalize_L2(image_features)
+     index.add(image_features)
+     meta_data_store.append(metadata)
+
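+ # With L2-normalized vectors, inner product equals cosine similarity, so this
+ # flat index performs an exact cosine-similarity search. A hypothetical
+ # indexing sketch (the path is illustrative only):
+ # feats = get_image_features(Image.open("example.jpg").convert("RGB"))
+ # save_image_in_index(feats, {"image_path": "example.jpg"})
+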
+ def process_image_embedding(image_url: str, labels=['clothes']) -> Image.Image:
+     """Segment the query image and return the cropped region as a PIL image."""
+     search_image, search_detections = grounded_segmentation(image=image_url, labels=labels)
+     cropped_image = cut_image(search_image, search_detections[0].mask, search_detections[0].box)
+
+     # Convert to valid RGB
+     if cropped_image.dtype != np.uint8:
+         cropped_image = (cropped_image * 255).astype(np.uint8)
+     if cropped_image.ndim == 2:
+         cropped_image = np.stack([cropped_image] * 3, axis=-1)
+
+     pil_image = Image.fromarray(cropped_image)
+     return pil_image
+
+ def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
+     """Find top-k similar images from the index."""
+     processed_image = process_image_embedding(image_url)
+     image_search_embedding = get_image_features(processed_image)
+     faiss.normalize_L2(image_search_embedding)
+     distances, indices = index.search(image_search_embedding.reshape(1, -1), k)
+
+     results = []
+     for i, dist in zip(indices[0], distances[0]):
+         if i < len(meta_data_store):
+             results.append({
+                 'metadata': meta_data_store[i],
+                 'score': float(dist)
+             })
+     return results
+
+ # def display_similar_images(results: List[dict]):
+ #     """Display retrieved images using matplotlib."""
+ #     for item in results:
+ #         img = get_image(item['metadata']['image_path'])
+ #         if img:
+ #             print(f"Score: {item['score']:.4f}")
+ #             plt.imshow(img)
+ #             plt.axis('off')
+ #             plt.show()
+
+ # In[73]:
+
+
+ app = FastAPI()
+
+ @app.get("/similar_images")
+ def get_similar_images(image_url: str, k: int = 10):
+     results = get_top_k_results(image_url, k)
+     # display_similar_images(results)  # Optional visualization call
+     return {
+         "results": [
+             {
+                 "metadata": item["metadata"],
+                 "score": item["score"]
+             }
+             for item in results
+         ]
+     }
+
+ # Example usage:
+ # results = get_top_k_results("/home/dheena/Downloads/fashion/temp/KPR-120-Wine_2_1024x1024.webp")
+ # display_similar_images(results)
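+ # Serving sketch (assumption, not part of the commit): a FastAPI app is
+ # typically run with uvicorn, e.g. `uvicorn image_segmentation:app --reload`;
+ # the file would need a hyphen-free name (image_segmentation.py) to be
+ # importable as a module.
+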
+
+
+ # In[54]:
+
+
+
+
src/segmentation.py CHANGED
@@ -1,7 +1,6 @@
 
 from dataclasses import dataclass
 from typing import Any, List, Dict, Optional, Union, Tuple
- import os
 
 import cv2
 import torch