ivoryzhang committed on
Commit
ab40ac4
·
1 Parent(s): 12e87f8

first commit

Browse files
Files changed (5) hide show
  1. app.py +26 -0
  2. heinsight.py +303 -0
  3. models/best_content.pt +3 -0
  4. models/best_vessel.pt +3 -0
  5. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Gradio demo: upload an image of a vial and run HeinSight detection on it.
import gradio as gr

from heinsight import HeinSight, HeinSightConfig

# Build the detector once at import time so every request reuses the same
# loaded models instead of reloading the weight files per call.
heinsight = HeinSight(
    vial_model_path="models/best_vessel.pt",
    contents_model_path="models/best_content.pt",
    config=HeinSightConfig(),
)

# Gradio UI
demo = gr.Interface(
    fn=heinsight.image_demo,
    inputs=[
        gr.Image(type="pil"),
        # The slider starts at 0 (no cap excluded), so the minimum has to be
        # 0.0 as well; the initial value must lie inside [minimum, maximum].
        gr.Slider(0.0, 1.0, step=0.01, value=0, label="Cap Size Ratio"),
    ],
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.JSON(label="Detection Info"),  # or gr.Textbox() if you prefer plain text
    ],
    title="HeinSight",
    description="Upload an image with vials to detect their contents",
)


if __name__ == "__main__":
    demo.launch()
heinsight.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from itertools import chain
3
+ from random import randint
4
+
5
+ import cv2
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from ultralytics import YOLO
11
+ from PIL import Image
12
+ matplotlib.use('Agg')
13
+
14
def highlight_vial_body(frame, vial_location, cap_ratio=0.2):
    """
    Highlight only the vial body by dimming the background and tinting the cap.

    Args:
        frame (np.ndarray): Original BGR image. It is not modified.
        vial_location (sequence): (x1, y1, x2, y2) pixel corners of the vial
            bounding box (top-left and bottom-right), as integers.
        cap_ratio (float): Fraction (0-1) of the vial height, measured from the
            top of the bounding box, that is considered to be the cap.

    Returns:
        masked_frame (np.ndarray): New image of the same shape as ``frame`` in
            which the background is blended 50/50 with gray, the cap region is
            blended 50/50 with red, and the vial body keeps its original pixels.
    """
    x1, y1, x2, y2 = vial_location
    # Rows [y1, body_y_start) are the cap, rows [body_y_start, y2) are the body.
    cap_height = int((y2 - y1) * cap_ratio)
    body_y_start = y1 + cap_height

    overlay = frame.copy()
    # Gray over the whole frame: after blending this dims everything.
    cv2.rectangle(overlay, (0, 0), (frame.shape[1], frame.shape[0]), (128, 128, 128), thickness=-1)
    # Red over the cap region so the excluded part of the vial stands out.
    cv2.rectangle(overlay, (x1, y1), (x2, body_y_start), (0, 0, 255), thickness=-1)
    masked = cv2.addWeighted(overlay, 0.5, frame, 0.5, 0)

    # The gray layer covers the body too; put the original body pixels back so
    # the body really is the only un-masked region.
    masked[body_y_start:y2, x1:x2] = frame[body_y_start:y2, x1:x2]
    return masked
41
+
42
+
43
class HeinSightConfig:
    """Configuration for the HeinSight system."""

    # Not read by any code visible in this file. NOTE(review): confirm meaning
    # of -1 against the callers that consume it.
    NUM_ROWS = -1
    # When true, process_vial_frame renders the matplotlib dashboard and
    # returns that rendering instead of the plain annotated vial crop.
    SAVE_PLOT_VIDEO = True
    # Content class names that content_detection treats as liquid phases.
    LIQUID_CONTENT = ["Homo", "Hetero"]
    # Fraction of the vial bounding-box height, from the top, treated as cap.
    CAP_RATIO = 0.3
    # Not read by any code visible in this file. NOTE(review): confirm usage.
    STATUS_RULE = 0.7
    # Optional fixed vial bounding box; when set, find_vial skips detection.
    DEFAULT_VIAL_LOCATION = None
    # Optional fixed vial height, copied onto the HeinSight instance on reset.
    DEFAULT_VIAL_HEIGHT = None
53
+
54
class HeinSight:
    """
    The core of the HeinSight system, responsible for computer vision and analysis.

    Two YOLO models are used: one locates the vial in a frame, the other
    detects the contents inside the cropped vial body. Per-frame colour and
    turbidity values are accumulated on the instance until clear_cache() is
    called.
    """

    def __init__(self, vial_model_path: str, contents_model_path: str, config: "HeinSightConfig | None" = None):
        """
        :param vial_model_path: path to the YOLO weights that detect the vial
        :param contents_model_path: path to the YOLO weights that detect the vial contents
        :param config: settings object; a fresh HeinSightConfig is created when omitted
        """
        self.fig, self.axs = plt.subplots(2, 2, figsize=(8, 6), height_ratios=[2, 1], constrained_layout=True)
        self._set_axes()
        # Create the default per instance: image_demo writes CAP_RATIO onto the
        # config, so a default built in the signature would be shared (and
        # silently modified) across every HeinSight created without a config.
        self.config = config if config is not None else HeinSightConfig()
        self.vial_model = YOLO(vial_model_path)
        self.contents_model = YOLO(contents_model_path)
        self.color_palette = self._register_colors([self.vial_model, self.contents_model])
        self.clear_cache()

    def _set_axes(self):
        """Place the four plot axes at fixed positions on the figure."""
        ax0, ax1, ax2, ax3 = self.axs.flat
        # Each position is [left, bottom, width, height] in figure fractions.
        ax0.set_position([0.21, 0.45, 0.22, 0.43])
        ax1.set_position([0.47, 0.45, 0.45, 0.43])
        ax2.set_position([0.12, 0.12, 0.35, 0.27])
        ax3.set_position([0.56, 0.12, 0.35, 0.27])
        self.fig.canvas.draw_idle()

    def clear_cache(self):
        """Resets the state of the HeinSight system."""
        default_location = self.config.DEFAULT_VIAL_LOCATION
        # Copy into a new list so the configured default is never aliased;
        # list() also accepts a tuple, which has no .copy() method.
        self.vial_location = list(default_location) if default_location else None
        self.cap_rows = 0
        self.vial_height = self.config.DEFAULT_VIAL_HEIGHT
        # Misspelled historical name, kept so existing readers keep working.
        self.vial_heigh = self.vial_height
        self.vial_size = []
        self.content_info = None
        self.x_time = []
        self.turbidity_2d = []
        self.average_colors = []
        self.average_turbidity = []
        self.output = []
        self.stream_output = []
        self.status = {}
        self.status_queue = []
        self.output_dataframe = pd.DataFrame()
        self.output_frame = None
        self.turbidity = []

    @staticmethod
    def _register_colors(model_list):
        """
        Register default drawing colors for the class names of the given models.

        :param model_list: YOLO models list; falsy entries are skipped
        :return: dict mapping class name to a 3-tuple color. Known names get a
            fixed color (the tuples match their comments when read as BGR);
            any other name gets a random color.
        """
        name_color_dict = {
            "Empty": (19, 69, 139),  # Brown
            "Residue": (0, 165, 255),  # Orange
            "Hetero": (255, 0, 255),  # Purple
            "Homo": (0, 0, 255),  # Red
            "Solid": (255, 0, 0),  # Blue
        }
        names = set(chain.from_iterable(model.names.values() for model in model_list if model))
        for name in names:
            if name not in name_color_dict:
                name_color_dict[name] = (randint(0, 255), randint(0, 255), randint(0, 255))
        return name_color_dict

    def find_vial(self, frame):
        """
        Locate the vial in a frame and cache its bounding box.

        The vial model only runs when no location is cached yet; a cached (or
        configured default) location is reused as is. On success
        ``self.vial_location`` holds ``[x1, y1, x2, y2]`` and ``self.cap_rows``
        holds the number of pixel rows attributed to the cap.

        :param frame: raw input frame
        :return: True when a vial location is available, False otherwise
        """
        # vial location is not defined, use vial model to detect
        if not self.vial_location:
            results = self.vial_model(frame, conf=0.2, max_det=1)
            boxes = results[0].boxes.data.cpu().numpy()
            if boxes.size > 0:
                self.vial_location = [int(x) for x in boxes[0, :4]]
        if self.vial_location:
            vial_rows = self.vial_location[3] - self.vial_location[1]
            self.cap_rows = int(vial_rows * self.config.CAP_RATIO)
        return self.vial_location is not None

    def crop_rectangle(self, image, vial_location):
        """
        Crop the vial body (the bounding box minus the cap rows) out of an image.

        :param image: raw image capture
        :param vial_location: (x1, y1, x2, y2) bounding box of the whole vial
        :return: cropped vial-body frame (a slice of ``image``, not resized)
        """
        x1, y1, x2, y2 = vial_location
        # Skip the top CAP_RATIO share of the box so only the body remains.
        body_top = int(self.config.CAP_RATIO * (y2 - y1)) + y1
        return image[body_top:y2, x1:x2]

    def content_detection(self, vial_frame):
        """
        Detect content in a vial frame.

        :param vial_frame: (np.ndarray) Cropped vial frame.
        :return tuple: (all detection rows, liquid boxes sorted by descending
            top edge, space-joined names of the detected classes)
        """
        results = self.contents_model(vial_frame, max_det=4, agnostic_nms=False, conf=0.25, iou=0.25, verbose=False)
        bboxes = results[0].boxes.data.cpu().numpy()
        # Column 5 of each detection row is read as the class id.
        pred_classes = bboxes[:, 5]
        title = " ".join([self.contents_model.names[int(x)] for x in pred_classes])
        liquid_boxes = [
            bboxes[i][:4]
            for i, cls in enumerate(pred_classes)
            if self.contents_model.names[int(cls)] in self.config.LIQUID_CONTENT
        ]
        return bboxes, sorted(liquid_boxes, key=lambda x: x[1], reverse=True), title

    def process_vial_frame(self, vial_frame, update_od: bool = False):
        """
        Process a single vial frame: detect content, draw bounding boxes and
        calculate turbidity and color.

        :param vial_frame: vial frame image
        :param update_od: update object detection, True: run YOLO for this frame,
            False: use previous YOLO results (YOLO still runs when none are cached)
        :return: (frame_image, bboxes, raw_turbidity, phase_data); frame_image is
            the rendered dashboard when SAVE_PLOT_VIDEO is set, otherwise the
            annotated vial frame
        """
        if update_od or self.content_info is None:
            self.content_info = self.content_detection(vial_frame)
        bboxes, liquid_boxes, title = self.content_info
        phase_data, raw_turbidity = self.calculate_value_color(vial_frame, liquid_boxes)
        frame_image = self.draw_bounding_boxes(vial_frame, bboxes, self.contents_model.names, text_right=False)

        if self.config.SAVE_PLOT_VIDEO:
            self.display_frame(raw_turbidity, frame_image, title)
            self.fig.canvas.draw()
            # Read the rendered figure back as an image and convert to BGR.
            frame_image = np.array(self.fig.canvas.renderer.buffer_rgba())
            frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGBA2BGR)
        return frame_image, bboxes, raw_turbidity, phase_data

    def calculate_value_color(self, vial_frame, liquid_boxes):
        """
        Calculate the value and color for a given vial image and bounding boxes.

        Appends the whole-frame color and turbidity to ``self.average_colors``
        and ``self.average_turbidity``. Requires ``self.x_time`` to be
        non-empty, because the latest time stamp is stored in the output.

        :param vial_frame: the vial image (BGR)
        :param liquid_boxes: the liquid boxes (["Homo", "Hetero"])
        :return: the output dict and raw turbidity per row
        """
        height, _, _ = vial_frame.shape
        hsv_image = cv2.cvtColor(vial_frame, cv2.COLOR_BGR2HSV)
        # HSV channel 0 (hue) is reported as color, channel 2 (value) as turbidity.
        output = {
            'time': self.x_time[-1],
            'color': np.mean(hsv_image[:, :, 0]),
            'turbidity': np.mean(hsv_image[:, :, 2])
        }
        raw_value = np.mean(hsv_image[:, :, 2], axis=1)
        for i, bbox in enumerate(liquid_boxes):
            _, top, _, bottom = map(int, bbox)
            # Each liquid phase spans the full width between its top and bottom rows.
            roi = hsv_image[top:bottom, :]
            output[f'volume_{i + 1}'] = (bottom - top) / height
            output[f'color_{i + 1}'] = np.mean(roi[:, :, 0])
            output[f'turbidity_{i + 1}'] = np.mean(roi[:, :, 2])
        self.average_colors.append(output['color'])
        self.average_turbidity.append(output['turbidity'])
        return output, raw_value

    @staticmethod
    def _get_dynamic_font_params(img_height, base_height=200, base_font_scale=0.5, base_thickness=1):
        """
        Scale the font size and stroke thickness linearly with an image dimension.

        :return: (font_scale, thickness); thickness is never below 1
        """
        scale_factor = img_height / base_height
        font_scale = base_font_scale * scale_factor
        thickness = max(1, int(base_thickness * scale_factor))
        return font_scale, thickness

    def draw_bounding_boxes(self, image, bboxes, class_names, thickness=None, text_right=False, on_raw=False):
        """
        Draw labelled bounding boxes on a copy of the image.

        :param image: image to annotate; it is not modified
        :param bboxes: detection rows of six values (x1, y1, x2, y2, <unused>, class id)
        :param class_names: mapping from class id to class name
        :param thickness: box line thickness; derived from the image size when omitted
        :param text_right: put labels at the right edge of the box instead of the left
            (the side is swapped for the "Solid" class)
        :param on_raw: the image is the full frame, so boxes are shifted from
            vial-body coordinates by the cached vial location and cap rows
        :return: annotated copy of the image
        """
        output_image = image.copy()
        # NOTE(review): shape[1] is the image width although the helper's
        # parameter is named img_height; kept as is so label sizes do not
        # change - confirm which axis was intended.
        reference_size = image.shape[1]
        font_scale, text_thickness = self._get_dynamic_font_params(reference_size)
        margin = 2
        thickness = thickness or max(1, int(reference_size / 200))
        for rect in bboxes:
            x1, y1, x2, y2, _, class_id = map(int, rect)
            class_name = class_names[class_id]
            color = self.color_palette.get(class_name, (255, 255, 255))
            if on_raw and self.vial_location:
                x_shift = self.vial_location[0]
                y_shift = self.vial_location[1] + self.cap_rows
                x1, y1 = x1 + x_shift, y1 + y_shift
                x2, y2 = x2 + x_shift, y2 + y_shift
            cv2.rectangle(output_image, (x1, y1), (x2, y2), color, thickness)
            (text_width, text_height), baseline = cv2.getTextSize(
                class_name, cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_thickness
            )
            text_location = (
                x2 - text_width - margin if text_right ^ (class_name == "Solid") else x1 + margin,
                y1 + text_height + margin
            )
            cv2.putText(output_image, class_name, text_location, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color,
                        text_thickness)
        return output_image

    def display_frame(self, y_values, image, title=None):
        """
        Display the image (top-left) and its turbidity values per row (top-right),
        turbidity over time (bottom-left) and color over time (bottom-right).

        Requires ``self.x_time`` to be non-empty (its first and last entries
        become the x ticks of the two time plots).

        :param y_values: the turbidity value per row
        :param image: vial image frame to display (BGR)
        :param title: title of the image frame
        """
        # init plot
        for ax in self.axs.flat:
            ax.clear()
        ax0, ax1, ax2, ax3 = self.axs.flat

        # top left - vial frame and bounding boxes
        image_copy = image.copy()
        image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
        ax0.imshow(np.flipud(image_copy), origin='lower')
        if title:
            ax0.set_title(title)

        # top right - turbidity per row, reversed to pair with the flipped image.
        # fill_betweenx was chosen by the original author for speed
        # (their note: 154.9857677 -> 68.15193, units not recorded).
        x_values = np.arange(len(y_values))
        ax1.fill_betweenx(x_values, 0, y_values[::-1], color='green', alpha=0.5)
        ax1.set_ylim(0, len(y_values))
        ax1.set_xlim(0, 255)
        ax1.xaxis.set_label_position('top')
        ax1.set_xlabel('Turbidity per row')

        realtime_tick_label = None

        # bottom left - turbidity
        ax2.set_ylabel('Turbidity')
        ax2.set_xlabel('Time / min')
        ax2.plot(self.x_time, self.average_turbidity)
        ax2.set_xticks([self.x_time[0], self.x_time[-1]], realtime_tick_label)

        # bottom right - color
        ax3.set_ylabel('Color (hue)')
        ax3.set_xlabel('Time / min')
        ax3.plot(self.x_time, self.average_colors)
        ax3.set_xticks([self.x_time[0], self.x_time[-1]], realtime_tick_label)

    def image_demo(self, pil_image, cap_ratio=0):
        """
        Run the full pipeline on one still image (entry point of the Gradio demo).

        Resets the cached state, stores ``cap_ratio`` on the config, locates the
        vial, analyses the vial body and returns the frame with the background
        dimmed, the cap tinted and the detected contents drawn on the body.

        :param pil_image: input image as a PIL image in RGB order, or None
        :param cap_ratio: fraction (0-1) of the vial height treated as cap
        :return: (annotated PIL image, phase data dict). The input frame is
            returned unannotated with an empty dict when no vial is found or
            the cap ratio leaves no body rows; (None, {}) when no image is given.
        """
        # The demo can be submitted without an image; nothing to analyse then.
        if pil_image is None:
            return None, {}

        self.clear_cache()
        frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)  # PIL -> OpenCV
        phase_data = {}
        result = frame
        self.config.CAP_RATIO = cap_ratio
        if self.find_vial(frame):
            vial_frame = self.crop_rectangle(frame, self.vial_location)
            # A cap ratio of 1.0 (or a degenerate box) leaves an empty crop,
            # which the detector and colour conversion cannot process.
            if vial_frame.size > 0:
                x1, y1, x2, y2 = self.vial_location
                self.x_time.append(0)
                _, bboxes, _, phase_data = self.process_vial_frame(vial_frame)
                boxes_on_vial = self.draw_bounding_boxes(
                    vial_frame, bboxes, self.contents_model.names, on_raw=False
                )
                masked_frame = highlight_vial_body(frame, self.vial_location, cap_ratio=cap_ratio)
                # Paste the annotated body back over the same rows it was cropped from.
                masked_frame[y1 + self.cap_rows:y2, x1:x2] = boxes_on_vial
                result = masked_frame
        result_rgb = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)  # OpenCV -> RGB
        return Image.fromarray(result_rgb), phase_data
301
+
302
+
303
+
models/best_content.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00b0be2dd8eec4aedd5d56da0fa196b95454d420a0df7eae453178ab9fbcc485
3
+ size 52009878
models/best_vessel.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93a49e5f80434b35f9244ab67fb17f5440be84ec0427834dff28ad98fa83bc58
3
+ size 22496110
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Ultralytics
2
+ Pillow
3
+ gradio