Shri Jayaram committed on
Commit
35ea9eb
·
1 Parent(s): 1c9286b

initial commit

Browse files
Files changed (8) hide show
  1. FlowChart.png +0 -0
  2. app.py +457 -0
  3. arial.ttf +0 -0
  4. florence.py +59 -0
  5. requirements.txt +11 -0
  6. sam.py +46 -0
  7. sam2_hiera_s.yaml +117 -0
  8. sam2_hiera_small.pt +3 -0
FlowChart.png ADDED
app.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ from typing import Optional
4
+ import io
5
+ from io import BytesIO
6
+ import os
7
+ import cv2
8
+ import numpy as np
9
+ import supervision as sv
10
+ import matplotlib.pyplot as plt
11
+ from rembg import remove
12
+ import mediapipe as mp
13
+ import torch
14
+ from transformers import AutoProcessor, AutoModelForCausalLM
15
+ from transformers.dynamic_module_utils import get_imports
16
+ from unittest.mock import patch
17
+ from scipy.spatial import distance as dist
18
+
19
# Must be the first Streamlit call in the app.
st.set_page_config(layout="wide", page_title="Ring Size Measurement")

# Inner finger diameter in millimetres -> ring size.
ring_size_dict = dict(
    [
        (14.0, 3), (14.4, 3.5), (14.8, 4), (15.2, 4.5),
        (15.6, 5), (16.0, 5.5), (16.45, 6), (16.9, 6.5),
        (17.3, 7), (17.7, 7.5), (18.2, 8), (18.6, 8.5),
        (19.0, 9), (19.4, 9.5), (19.8, 10), (20.2, 10.5),
        (20.6, 11), (21.0, 11.5), (21.4, 12), (21.8, 12.5),
        (22.2, 13), (22.6, 13.5),
    ]
)
44
+
45
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+ # def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
48
+ # if not str(filename).endswith("modeling_florence2.py"):
49
+ # return get_imports(filename)
50
+ # imports = get_imports(filename)
51
+ # imports.remove("flash_attn")
52
+ # return imports
53
+
54
+ # def load_model():
55
+ # model_id = "microsoft/Florence-2-base-ft"
56
+ # processor = AutoProcessor.from_pretrained(model_id, torch_dtype=torch.qint8, trust_remote_code=True)
57
+
58
+ # try:
59
+ # os.mkdir("temp")
60
+ # except:
61
+ # pass
62
+
63
+ # with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
64
+ # model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", trust_remote_code=True)
65
+
66
+ # Qmodel = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
67
+ # return Qmodel.to(device), processor
68
+
69
+ # if 'model_loaded' not in st.session_state:
70
+ # st.session_state.model_loaded = False
71
+
72
+ # if not st.session_state.model_loaded:
73
+ # with st.spinner('Loading model...'):
74
+ # st.session_state.model, st.session_state.processor = load_model()
75
+ # st.session_state.model_loaded = True
76
+ # st.write("Model loaded complete")
77
+ from florence import load_florence_model, run_florence_inference, \
78
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
79
+ from sam import load_sam_image_model, run_sam_inference
80
+
81
# Select the compute device once at import time.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Process-wide bfloat16 autocast context, entered for the app's lifetime
    # (deliberately never exited — matches the original global side effect).
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 fast paths only on compute capability 8 (Ampere) or newer.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
89
+
90
+ FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
91
+ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
92
+
93
def calculate_pixel_per_metric(image, known_diameter_of_coin=25):
    """Estimate the pixels-per-millimetre calibration ratio from a ruler.

    Florence-2 grounds the prompt "ruler" in the image, SAM 2 segments it,
    and the tallest contour of the resulting mask is compared against the
    ruler's assumed real-world length (160 mm).

    Args:
        image: BGR image as a numpy array (OpenCV convention).
        known_diameter_of_coin: kept for backward compatibility; the current
            implementation measures a ruler and does not read this value.

    Returns:
        (pixel_per_metric, mm_per_pixel, annotated_image) where
        pixel_per_metric is pixels per millimetre, mm_per_pixel is always
        None (pass-through placeholder), and annotated_image is an RGBA PIL
        image of the segmented ruler with the ratio drawn on it.

    Raises:
        ValueError: if no ruler mask or contour could be found.
    """

    @torch.inference_mode()
    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def get_obj_mask(image_input, text_input) -> Optional[Image.Image]:
        """Return a binary PIL mask of the first object matching text_input."""
        if image_input is None:
            st.warning("Please upload an image.")
            return None
        if not text_input:
            st.warning("Please enter a text prompt.")
            return None

        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text_input,
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size,
        )
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        if len(detections) == 0:
            st.warning("No objects detected.")
            return None
        return Image.fromarray(detections.mask[0].astype("uint8") * 255)

    def plot_bbox(the_obj, mask, known_length=160):
        """Cut the masked object out, compute px/mm and annotate the result.

        Returns (annotated RGBA image, pixels_per_metric, None).
        """
        if mask is None:
            # get_obj_mask already warned in the UI; fail loudly so the
            # caller never unpacks undefined values.
            raise ValueError("No ruler mask available; cannot calibrate.")

        input_array = np.array(the_obj)
        mask_array = np.array(mask.convert("L"))
        binary_mask = (mask_array > 0).astype("uint8")

        # RGBA output: object pixels opaque, background fully transparent.
        output_array = np.zeros((*input_array.shape[:2], 4), dtype=np.uint8)
        output_array[binary_mask == 1, :3] = input_array[binary_mask == 1]
        output_array[binary_mask == 1, 3] = 255

        contours, _ = cv2.findContours(
            binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            # BUG FIX: previously fell through and raised NameError on
            # undefined locals when the mask had no contours.
            raise ValueError("Ruler mask produced no contours; cannot calibrate.")

        # Use the tallest bounding box as the ruler's pixel length.
        m_ht = max(cv2.boundingRect(contour)[3] for contour in contours)
        pixels_per_metric = m_ht / known_length
        text = f"Pixel per Metric -> {pixels_per_metric:.2f} px/mm Actual Ht.: {m_ht} px Known Length : {known_length} mm"
        print(text)

        output_image = Image.fromarray(output_array, 'RGBA')
        draw = ImageDraw.Draw(output_image)
        font = ImageFont.truetype("arial.ttf", 34)
        draw.text((10, 10), text, fill=(255, 255, 255, 255), font=font)
        return output_image, pixels_per_metric, None

    def finding_ruler(img, task_prompt, text_input=None):
        known_length = 160  # ruler assumed to be roughly 16 cm (160 mm)
        mask = get_obj_mask(img, text_input=text_input)
        image_with_bboxes, value_1, value_2 = plot_bbox(img, mask, known_length)
        return value_1, value_2, image_with_bboxes

    # Convert the OpenCV BGR array into an RGB PIL image for the models.
    image_for_model = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
    image_for_model = Image.fromarray(image_for_model)

    text_input = "ruler"
    task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
    pixel_per_metric, mm_per_pixel, marked_image_buf = finding_ruler(
        image_for_model, task_prompt, text_input
    )
    return pixel_per_metric, mm_per_pixel, marked_image_buf
218
+
219
def process_image(image):
    """Remove the background from *image* via rembg and return the result."""
    return remove(image)
221
+
222
def calculate_pip_width(image, original_img, pixel_per_metric):
    """Measure the ring-finger width (mm) at the PIP joint and its midpoint.

    Args:
        image: background-removed BGR image of the hand.
        original_img: untouched BGR image; annotated in place with the
            measurement segments and returned.
        pixel_per_metric: calibration ratio in pixels per millimetre.

    Returns:
        (annotated original image, largest width in whole mm,
         contour-overlay image, copy of the background-removed image).

    Raises:
        ValueError: if no hand contour or no hand landmarks are found.
            (BUG FIX: previously the function fell through and implicitly
            returned None, crashing the caller on 4-tuple unpacking.)
    """

    def cal_size(xA, yA, xB, yB, color_circle, color_line, img):
        # Draw the measurement segment and label it with the width in mm.
        d = dist.euclidean((xA, yA), (xB, yB))
        cv2.circle(img, (int(xA), int(yA)), 5, color_circle, -1)
        cv2.circle(img, (int(xB), int(yB)), 5, color_circle, -1)
        cv2.line(img, (int(xA), int(yA)), (int(xB), int(yB)), color_line, 2)
        d_mm = d / pixel_per_metric
        # Fixed 1.5 mm subtraction kept from the original implementation —
        # presumably a correction for contour overshoot; TODO confirm.
        d_mm = d_mm - 1.5
        cv2.putText(img, "{:.1f}".format(d_mm), (int(xA - 15), int(yA - 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.65, (255, 255, 255), 2)
        return d_mm

    def process_point(point, cnt, m1, b):
        # Walk along the perpendicular line y = m1*x + b in both x
        # directions until leaving the hand contour; the two exit points
        # are the finger's edges at this cross-section.
        x1, x2 = point[0], point[0]
        y1 = m1 * x1 + b
        y2 = m1 * x2 + b

        result = 1.0
        while result > 0:
            result = cv2.pointPolygonTest(cnt, (x1, y1), False)
            x1 += 1
            y1 = m1 * x1 + b
        x1 -= 1

        result = 1.0
        while result > 0:
            result = cv2.pointPolygonTest(cnt, (x2, y2), False)
            x2 -= 1
            y2 = m1 * x2 + b
        x2 += 1

        return x1, y1, x2, y2

    og_img = original_img.copy()
    imgH, imgW, _ = image.shape
    imgcpy = image.copy()
    # With the background removed, any non-black pixel belongs to the hand.
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary_image = cv2.threshold(image_gray, 1, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour_image = np.zeros_like(image_gray)
    cv2.drawContours(contour_image, contours, -1, (255), thickness=cv2.FILLED)
    cv2.drawContours(imgcpy, contours, -1, (0, 255, 0), 2)

    marked_img = image.copy()

    if len(contours) > 0:
        cnt = max(contours, key=cv2.contourArea)
        frame2 = cv2.cvtColor(og_img, cv2.COLOR_BGR2RGB)
        handsLM = mp.solutions.hands.Hands(max_num_hands=1,
                                           min_detection_confidence=0.8,
                                           min_tracking_confidence=0.8)
        pr = handsLM.process(frame2)
        if pr.multi_hand_landmarks:
            for hand_landmarks in pr.multi_hand_landmarks:
                # Collect pixel coordinates for every detected landmark.
                lmlist = []
                for idx, landmark in enumerate(hand_landmarks.landmark):
                    lmlist.append([idx, int(landmark.x * imgW), int(landmark.y * imgH)])

                if len(lmlist) != 0:
                    # MediaPipe landmarks 14/13 are the ring-finger PIP and
                    # MCP joints.
                    pip_joint = [lmlist[14][1], lmlist[14][2]]
                    mcp_joint = [lmlist[13][1], lmlist[13][2]]

                    midpoint_x = (pip_joint[0] + mcp_joint[0]) / 2
                    midpoint_y = (pip_joint[1] + mcp_joint[1]) / 2
                    midpoint = [midpoint_x, midpoint_y]

                    # Slope of the finger axis, then the perpendicular
                    # through the PIP joint. NOTE(review): a perfectly
                    # vertical finger (equal x) would divide by zero here —
                    # behavior preserved from the original.
                    m2 = (pip_joint[1] - mcp_joint[1]) / (pip_joint[0] - mcp_joint[0])
                    m1 = -1 / m2
                    b = pip_joint[1] - m1 * pip_joint[0]
                    x1_pip, y1_pip, x2_pip, y2_pip = process_point(pip_joint, cnt, m1, b)

                    # Same construction through the PIP-MCP midpoint.
                    m2 = (midpoint_y - mcp_joint[1]) / (midpoint_x - mcp_joint[0])
                    m1 = -1 / m2
                    b = midpoint_y - m1 * midpoint_x
                    x1_mid, y1_mid, x2_mid, y2_mid = process_point(midpoint, cnt, m1, b)

                    d_mm_pip = cal_size(x1_pip, y1_pip, x2_pip, y2_pip, (255, 0, 0), (255, 0, 255), original_img)
                    d_mm_mid = cal_size(x1_mid, y1_mid, x2_mid, y2_mid, (0, 255, 0), (0, 0, 255), original_img)

                    # Take the wider of the two cross-sections.
                    largest_d_mm = max(int(d_mm_mid), int(d_mm_pip))
                    return original_img, largest_d_mm, imgcpy, marked_img

    raise ValueError("Could not measure finger width: no hand contour or hand landmarks detected.")
308
+
309
def mark_hand_landmarks(image_path):
    """Draw MediaPipe hand landmarks and highlight the ring-finger joints.

    Despite the parameter name, *image_path* is a BGR numpy image; it is
    annotated in place and also returned.
    """
    hands_module = mp.solutions.hands
    detector = hands_module.Hands()
    drawer = mp.solutions.drawing_utils

    img = image_path
    results = detector.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    if results.multi_hand_landmarks:
        for hand_landmarks in results.multi_hand_landmarks:
            drawer.draw_landmarks(img, hand_landmarks, hands_module.HAND_CONNECTIONS)

            # Landmarks 13/14 are the ring-finger MCP and PIP joints.
            mcp = hand_landmarks.landmark[13]
            pip = hand_landmarks.landmark[14]

            img_height, img_width, _ = img.shape
            for joint in (mcp, pip):
                center = (int(joint.x * img_width), int(joint.y * img_height))
                cv2.circle(img, center, 10, (255, 0, 0), -1)

    return img
336
+
337
def show_resized_image(images, titles, scale=0.5):
    """Render up to six images in a fixed 2x3 matplotlib grid as a PNG.

    Args:
        images: list of BGR numpy images (at most 6 are shown).
        titles: matching list of subplot titles.
        scale: resize factor applied to each image before plotting.

    Returns:
        io.BytesIO positioned at 0 containing the rendered PNG figure.
    """
    num_images = len(images)

    fig, axes = plt.subplots(2, 3, figsize=(17, 13))
    axes = axes.flatten()

    # Blank out the unused cells of the fixed grid.
    for ax in axes[num_images:]:
        ax.axis('off')

    # Removed the debug counter/print that previously ran in this loop.
    for ax, img, title in zip(axes, images, titles):
        resized_image = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        ax.imshow(cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    img_stream = BytesIO()
    plt.savefig(img_stream, format='png')
    img_stream.seek(0)
    plt.close(fig)  # release figure memory between reruns
    return img_stream
360
+
361
def get_ring_size(mm_value, chart=None):
    """Map a finger diameter in millimetres to a ring size.

    Args:
        mm_value: measured inner diameter in millimetres.
        chart: optional mapping of diameter-mm -> ring size; defaults to the
            module-level ring_size_dict (backward compatible).

    Returns:
        The ring size for an exact diameter match, otherwise the size of
        the closest listed diameter.
    """
    if chart is None:
        chart = ring_size_dict
    try:
        return chart[mm_value]
    except KeyError:
        closest_mm = min(chart.keys(), key=lambda x: abs(x - mm_value))
        return chart[closest_mm]
367
+
368
# Page config is already set at the top of the file; Streamlit allows
# set_page_config only once per app, hence the duplicate stays disabled.
# st.set_page_config(layout="wide", page_title="Ring Size Measurement")

# --- Static page copy: title, instructions and workflow description. ---
st.write("## Determine Your Ring Size")
st.write(
    "📏 Upload an image of your hand to measure the finger width and determine your ring size. The measurement will be displayed along with a visual breakdown of the image processing flow."
)
st.sidebar.write("## Upload :gear:")
#~~
st.write("### Workflow Overview")
st.image("FlowChart.png", caption="Workflow Overview", use_column_width=True)

st.write("### Detailed Workflow")
st.write("1. **Florence-2 Model:** Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks.We utilize this model to detect the scale within the image and mark a bounding box which we can use to find the approximate full measurement of scale.")
st.write("2. **Pixel Per Metric Ratio:** The Pixel Per Metric Ratio is used to convert pixel measurements into real-world units. By comparing the pixel length obtained from image analysis (i.e., Hough Circle) with the known real-world measurement of the reference object (coin), we get the ratio. This ratio then allows us to accurately scale and size estimation of objects within the image.")
st.write("3. **Background Removal:** Removing the background first ensures that only the relevant subject is highlighted. We start by converting the image to grayscale and applying thresholding to distinguish the subject from the background. Erosion and dilation then clean up the image, improving the detection of specific features like individual fingers.")
st.write("4. **Contour Detection:** We use Contour Detection to find the largest contour, which allows us to outline or draw a boundary around the subject (i.e., hand). This highlights the object's shape and edges, improving the precision of the subject.")
st.write("5. **Finding Hand Landmarks:** This involves using the MediaPipe library to identify key points on the hand, such as the PIP (Proximal Interphalangeal) and MCP (Metacarpophalangeal) joints of the ring finger. This enables precise tracking and analysis of finger positions and movements.")
st.write("6. **Determining Finger Width:** Here we use the slope formula `[y = mx + b]` with PIP and MCP points to measure the finger's width. We project outward perpendicularly from the PIP point towards the MCP point, then apply a point polygon test to accurately determine the pixel width of the finger.")
st.write("7. **Predicting Ring Size:** Predicting Ring Size involves calculating the finger’s diameter using the Pixel Per Metric Ratio and the largest width measurement at the PIP or MCP joint. This diameter is then used to predict the appropriate ring size.")
#~~

# Reject uploads above this size before any expensive processing.
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB
389
+
390
def process_image_and_get_results(upload):
    """Run the full ring-size measurement pipeline on an uploaded image.

    Steps: mark hand landmarks, calibrate px/mm from the ruler, remove the
    background, measure the ring-finger width at the PIP/MCP joints, and
    map that width to a ring size.

    Args:
        upload: file-like object accepted by PIL.Image.open.

    Returns:
        dict with the annotated images plus 'width_mm' and 'ring_size'.

    Raises:
        ValueError: if the calibration image is missing.
        TypeError: if the calibration image has an unexpected type.
    """
    # Force RGB so RGBA/paletted uploads don't break the 3-channel BGR
    # conversion below.
    image = Image.open(upload).convert("RGB")
    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    original_img = image_np.copy()
    og_img1 = image_np.copy()
    og_img2 = image_np.copy()
    img_1 = image_np.copy()
    hand_lms = mark_hand_landmarks(img_1)

    pixel_per_metric, mm_per_pixel, image_with_coin_info = calculate_pixel_per_metric(image_np)
    processed_image = process_image(og_img1)
    image_with_pip_width, width_mm, contour_image, pip_mark_img = calculate_pip_width(
        processed_image, original_img, pixel_per_metric
    )

    # BUG FIX: validate BEFORE converting — np.array(None) yields a 0-d
    # array, so the original None check after conversion could never fire.
    if image_with_coin_info is None:
        raise ValueError("Image is None, cannot resize.")
    image_with_coin_info = np.array(image_with_coin_info)
    if not isinstance(image_with_coin_info, (np.ndarray, cv2.UMat)):
        raise TypeError(f"Invalid image type: {type(image_with_coin_info)}. Expected numpy array or cv2.UMat.")

    ring_size = get_ring_size(width_mm)
    return {
        "processed_image": image_with_pip_width,
        "original_image": og_img2,
        "hand_lm_marked_image": hand_lms,
        "image_with_coin_info": image_with_coin_info,
        "contour_image": contour_image,
        "width_mm": width_mm,
        "ring_size": ring_size,
    }
422
+
423
def show_how_it_works(processed_image):
    """Render the 'How It Works' explainer section around *processed_image*."""
    st.write("## How It Works")
    st.write("Here's a step-by-step breakdown of how your image is processed to determine your ring size:")
    st.image(processed_image, caption="Image Processing Flow", use_column_width=True)
427
+
428
# Two-column layout: original upload on the left, processed result on the right.
col1, col2 = st.columns(2)
my_upload = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if my_upload is not None:
    # Enforce the size cap before running the (expensive) pipeline.
    if my_upload.size > MAX_FILE_SIZE:
        st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
    else:
        st.write("## Image Processing Flow")
        results = process_image_and_get_results(my_upload)

        # Pipeline images are OpenCV BGR; Streamlit expects RGB.
        col1.write("Uploaded Image :camera:")
        col1.image(cv2.cvtColor(results["original_image"], cv2.COLOR_BGR2RGB), caption="Uploaded Image")

        col2.write("Processed Image :wrench:")
        col2.image(cv2.cvtColor(results["processed_image"], cv2.COLOR_BGR2RGB), caption="Processed Image with PIP Width")

        st.write(f"📏 The width of your finger is {results['width_mm']:.2f} mm, and the estimated ring size is {results['ring_size']:.1f}.")

        # Optional section: a single figure with every intermediate image.
        if st.button("How it Works"):
            st.write("## How It Works")
            st.write("Here's a step-by-step breakdown of how your image is processed to determine your ring size:")
            print("here")
            img_stream = show_resized_image(
                [results["original_image"], results["image_with_coin_info"], results["contour_image"], results["hand_lm_marked_image"], results["processed_image"]],
                ['Original Image', 'Image with Scale Info', 'Contour Boundary Image', 'Hand Landmarks', 'Ring Finger Width'],
                scale=0.5
            )
            st.image(img_stream, caption="Processing Flow", use_column_width=True)
else:
    st.info("Please upload an image to get started.")
arial.ttf ADDED
Binary file (312 kB). View file
 
florence.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union, Any, Tuple, Dict
3
+ from unittest.mock import patch
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoModelForCausalLM, AutoProcessor
8
+ from transformers.dynamic_module_utils import get_imports
9
+
10
+ FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
11
+ # FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
12
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
13
+ FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
14
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
15
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
16
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
17
+
18
+
19
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    Florence-2's remote modeling code declares a hard import on flash_attn
    even though it is optional; strip it so the model loads on machines
    without flash-attn installed.
    """
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # BUG FIX: list.remove raises ValueError when the entry is absent
    # (e.g. a model revision that already dropped it); guard the removal.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
26
+
27
+
28
def load_florence_model(
    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
    """Load the Florence-2 model (eval mode, on *device*) and its processor."""
    # Patch transformers' import scanner so the optional flash_attn
    # dependency declared by the remote code does not abort loading.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = (
            AutoModelForCausalLM
            .from_pretrained(checkpoint, trust_remote_code=True)
            .to(device)
            .eval()
        )
        processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return model, processor
37
+
38
+
39
def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    """Run one Florence-2 generation.

    Returns:
        (raw generated text, task-specific parsed response dict).
    """
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    # Special tokens are kept (skip_special_tokens=False) and the raw text
    # is handed to the processor's task-aware post-processing step.
    raw_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(raw_text, task=task, image_size=image.size)
    return raw_text, parsed
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ Pillow
3
+ numpy
4
+ opencv-python
5
+ supervision
6
+ matplotlib
7
+ rembg
8
+ mediapipe
9
+ torch
10
+ transformers
11
+ scipy
sam.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import supervision as sv
5
+ import torch
6
+ from PIL import Image
7
+ from sam2.build_sam import build_sam2, build_sam2_video_predictor
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+
10
+ SAM_CHECKPOINT = "sam2_hiera_small.pt"
11
+ SAM_CONFIG = "sam2_hiera_s.yaml"
12
+ # SAM_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
13
+ # SAM_CONFIG = "sam2_hiera_l.yaml"
14
+
15
+
16
def load_sam_image_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
    """Build a SAM 2 image predictor from the given config and checkpoint."""
    sam_model = build_sam2(config, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam_model=sam_model)
    return predictor
23
+
24
+
25
def load_sam_video_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> Any:
    """Build a SAM 2 video predictor from the given config and checkpoint."""
    predictor = build_sam2_video_predictor(config, checkpoint, device=device)
    return predictor
31
+
32
+
33
def run_sam_inference(
    model: Any,
    image: Image,
    detections: sv.Detections
) -> sv.Detections:
    """Segment each detection box with SAM 2 and attach boolean masks.

    Args:
        model: a SAM2ImagePredictor.
        image: PIL image to segment.
        detections: supervision detections whose xyxy boxes prompt SAM.

    Returns:
        The same detections object with .mask set to an (N, H, W) bool array.
    """
    image = np.array(image.convert("RGB"))
    model.set_image(image)
    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)

    # With multimask_output=False a 4-D result appears to be (N, 1, H, W);
    # drop only the singleton mask axis. BUG FIX: the previous
    # np.squeeze(mask) also removed the batch axis when N == 1, producing a
    # 2-D array where supervision expects (N, H, W).
    if mask.ndim == 4:
        mask = np.squeeze(mask, axis=1)

    detections.mask = mask.astype(bool)
    return detections
sam2_hiera_s.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
117
+
sam2_hiera_small.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95949964d4e548409021d47b22712d5f1abf2564cc0c3c765ba599a24ac7dce3
3
+ size 184309650