HarryEslick commited on
Commit
f6accfe
·
verified ·
1 Parent(s): 553fd63

create app.py

Browse files
Files changed (1) hide show
  1. app.py +390 -0
app.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio web application for detecting and measuring objects in images.
3
+
4
+ Key features:
5
+ - Image scaling tool to set measurement reference
6
+ - Object detection using YOLOv8 model for scallop/spat detection
7
+ - Interactive annotation of detected objects
8
+ - Size measurements in mm based on scale reference
9
+ - Statistics and histogram visualization of object sizes
10
+ - Export results to CSV
11
+
12
+ ## TODO:
13
+ - [ ] Load annotations from T-Rex
14
+
15
+ """
16
+
17
+ # %% #|> Imports |
18
+ from pathlib import Path
19
+ import cv2
20
+ import gradio as gr
21
+ from gradio_image_annotation import image_annotator
22
+ import numpy as np
23
+ import pandas as pd
24
+ import supervision as sv
25
+
26
+ import plotly.express as px
27
+
28
+ from spatstatapp.inference import inference_large
29
+ from spatstatapp.plotting import coco_to_detections
30
+ from spatstatapp.tile_training_data import load_bboxes
31
+
32
+ import gradio as gr
33
+ import numpy as np
34
+ from PIL import Image, ImageDraw
35
+ import cv2
36
+
37
+ # %% load image and detections |
38
+ model_path = Path("models/best.pt")
39
+
40
+ data_dir=Path("img")
41
+ data_dir.exists()
42
+
43
+ train_images = list(data_dir.glob('shells.png'))
44
+
45
+ default_images = {Path(img).stem: str(img) for img in train_images}
46
+
47
+ class PointSelector:
48
+ def __init__(self, image=None):
49
+ self.points = []
50
+ self.og_img = image
51
+ self.image_path = image
52
+ self.line_len_px = None
53
+ self.line_len_mm = None
54
+
55
+ def reset(self):
56
+ self.points = []
57
+ return self.og_img, "Points cleared"
58
+
59
+ def reset_og_img(self, image):
60
+ # self.og_img = image.copy()
61
+ # raise Exception(image)
62
+ self.og_img = None
63
+ self.points = []
64
+ self.image_path = image
65
+
66
+ def add_point(self, image, evt: gr.SelectData):
67
+ img_draw = cv2.imread(image)#[:,:,::-1]
68
+ if (len(self.points) == 0):# & (self.og_img is None):
69
+ self.image_path = image
70
+ # self.og_img = cv2.imread(image)
71
+ # img_draw = self.og_img.copy()
72
+
73
+ if len(self.points) >= 2:
74
+ self.points = []
75
+ img_draw = cv2.imread(self.image_path)
76
+ # img_draw = self.og_img.copy()
77
+
78
+ self.points.append((evt.index[0], evt.index[1]))
79
+
80
+ # Draw on image
81
+ # img_draw = image.copy()
82
+ if len(self.points) > 0:
83
+ for pt in self.points:
84
+ cv2.circle(img_draw, (int(pt[0]), int(pt[1])), 5, (255,0,0), -1)
85
+
86
+ if len(self.points) == 2:
87
+ cv2.line(img_draw,
88
+ (int(self.points[0][0]), int(self.points[0][1])),
89
+ (int(self.points[1][0]), int(self.points[1][1])),
90
+ (0,255,0), 3)
91
+
92
+ # Calculate distance
93
+ dist = np.sqrt((self.points[1][0] - self.points[0][0])**2 +
94
+ (self.points[1][1] - self.points[0][1])**2)
95
+ msg = f"Distance: {dist:.1f} pixels"
96
+ self.line_len_px = dist
97
+ else:
98
+ msg = f"Click point {len(self.points)+1}"
99
+
100
+ return img_draw[:,:,::-1], msg
101
+
102
+ def set_line_length(self, line_len_mm, button):
103
+ self.line_len_mm = line_len_mm
104
+ return self.check_scale_set(button)
105
+
106
+ def check_scale_set(self, button):
107
+ if (self.line_len_mm is not None) & (self.line_len_px is not None):
108
+ # if True:
109
+ return gr.update(visible=True)
110
+ else:
111
+ return gr.update(visible=False)
112
+
113
+ def save_scaled_boxes(self, annotator):
114
+ try:
115
+ json_data = annotator["boxes"]
116
+ if len(json_data)==0:
117
+ return None
118
+ else:
119
+ df = pd.DataFrame(json_data).drop(columns=["color"], errors='ignore')
120
+ df["xrange"] = ((df["xmax"] - df["xmin"])*(self.line_len_mm/self.line_len_px)).round(2)
121
+ df["yrange"] = ((df["ymax"] - df["ymin"])*(self.line_len_mm/self.line_len_px)).round(2)
122
+ df["mean_daimeter_mm"] = ((df["yrange"]+df["xrange"])/2).round(2)
123
+
124
+ return df
125
+ except Exception as e:
126
+ return None
127
+
128
+
129
+
130
+
131
+ def detections_to_json(detections:sv.Detections, image:np.ndarray):
132
+ """Add predictions to canvas"""
133
+ boxes = []
134
+
135
+ for xyxy, mask, confidence, class_id, tracker_id, data in detections:
136
+ xmin, ymin, xmax, ymax = xyxy
137
+ obj = {
138
+ "xmin": float(xmin),
139
+ "ymin": float(ymin),
140
+ "xmax": float(xmax),
141
+ "ymax": float(ymax),
142
+ "label": "",# data["class_name"],
143
+ "color": (255, 0, 0)
144
+ }
145
+ boxes.append(obj)
146
+
147
+ annotation = {
148
+ "image": image,
149
+ "boxes": boxes
150
+ }
151
+
152
+ return annotation
153
+
154
+
155
+
156
+
157
+ def create_histogram(df):
158
+ # print(type(df))
159
+ # print(len(df))
160
+ print()
161
+ if df is None or len(df) == 0 or df.iloc[0,0]=="":
162
+ return None
163
+ fig = px.histogram(df, x="mean_daimeter_mm",
164
+ title="Distribution of Shell Sizes",
165
+ labels={"mean_daimeter_mm": "Mean Diameter (mm)"},
166
+ nbins=30)
167
+ return fig
168
+
169
+
170
+ # def get_boxes_table(annotator):
171
+ # json_data = annotator["boxes"]
172
+ # if len(json_data)==0:
173
+ # return pd.DataFrame()
174
+ # else:
175
+ # df = pd.DataFrame(json_data).drop(columns=["color"], errors='ignore')
176
+ # return df
177
+ from ultralytics.utils.ops import xywhn2xyxy
178
+
179
+
180
+
181
+ def find_boxes_json(image_path):
182
+ # print(annotator)
183
+ img = cv2.imread(image_path)
184
+ detections = inference_large(img, model_path, sam_path=None, edge_pct=0.01, conf_threshold=0.4, overlap_px=100, tile_px=400)
185
+ annotations = detections_to_json(detections, image_path)
186
+ annotations["image"] = image_path
187
+ # annotator.update(annotations)
188
+ annotator = image_annotator(
189
+ annotations,
190
+ boxes_alpha=0.02,
191
+ handle_size=4,
192
+ show_label=False,
193
+ )
194
+ return annotator, annotations["boxes"]
195
+
196
+ def load_coco_boxes(image_path, coco_file, class_labels="scallop"):
197
+ image = cv2.imread(image_path)[:,:,::-1]
198
+ detections = coco_to_detections(coco_file, image)
199
+ annotations = detections_to_json(detections, image)
200
+ annotations["image"] = image_path
201
+ # annotator.update(annotations)
202
+ annotator = image_annotator(
203
+ annotations,
204
+ boxes_alpha=0.02,
205
+ handle_size=4,
206
+ show_label=False,
207
+ )
208
+ return annotator, annotations["boxes"]
209
+
210
+
211
+ selector = PointSelector()
212
+
213
+ # %% Tab 1 |
214
+ with gr.Blocks() as demo:
215
+ with gr.Tabs() as tabs:
216
+ with gr.TabItem("Scale ", id=0):
217
+ with gr.Row():
218
+ with gr.Column(scale=10):
219
+ image_input = gr.Image(label="Click two points to measure distance", type="filepath", interactive=True)
220
+
221
+ default_image = gr.Dropdown(
222
+ choices=["None"] + list(default_images.keys()),
223
+ label="Use default image?",
224
+ # value=list(default_images.keys())[0] if default_images else None
225
+ )
226
+ with gr.Column(scale=1, min_width=200):
227
+ filename = gr.Textbox(label="Filename")
228
+ output_text = gr.Textbox(label="Status", value="Click two points to measure distance")
229
+ line_length_mm = gr.Number(label="line length in mm")
230
+ target_select = gr.Radio(label ="select target:", visible=True, choices=["scallop", "spat"])
231
+ button_find = gr.Button("find bounding boxes", visible=False)
232
+ load_annot_btn = gr.Button("Load Existing boxes (Optional)", visible=False)
233
+ load_annot = gr.File(label="Load Existing boxes (Optional)",file_types=[".txt"], file_count="single", visible=False, height=500)
234
+ # test_text = gr.Textbox(label="test")
235
+
236
+
237
+
238
+ # %% T1: event handlers |
239
+ # image_input.upload()
240
+ default_image.change(
241
+ lambda x: default_images[x] if x in default_images.keys() else None,
242
+ inputs=[default_image],
243
+ outputs=[image_input]
244
+ )
245
+
246
+ image_input.upload(
247
+ selector.reset_og_img,
248
+ inputs=[image_input],
249
+ )
250
+ image_input.upload(
251
+ lambda x: Path(x).name,
252
+ inputs=[image_input],
253
+ outputs=[filename]
254
+ )
255
+
256
+ # Event handlers
257
+ image_input.select(
258
+ selector.add_point,
259
+ inputs=[image_input],
260
+ outputs=[image_input, output_text]
261
+ )
262
+
263
+ line_length_mm.change(
264
+ selector.set_line_length,
265
+ inputs=[line_length_mm, button_find],
266
+ outputs=[button_find]
267
+ )
268
+
269
+ line_length_mm.change(
270
+ selector.check_scale_set,
271
+ inputs = load_annot_btn,
272
+ outputs=load_annot_btn,
273
+ )
274
+
275
+ load_annot_btn.click(
276
+ lambda: gr.update(visible=True),
277
+ outputs=[load_annot]
278
+ )
279
+
280
+ # load_annot.upload(
281
+ # load_bboxes,
282
+ # inputs=[load_annot],
283
+ # outputs=[test_text]
284
+ # )
285
+
286
+ # %% Tab2 |
287
+ with gr.TabItem("Object annotation", id=1, visible=True):
288
+ annotator = image_annotator(
289
+ boxes_alpha=0.02,
290
+ handle_size=4,
291
+ show_label=False,
292
+ label_list=["scallop", "spat"],
293
+ label_colors=[(255, 0, 0), (255, 200, 0)]
294
+ )
295
+
296
+ # button_get = gr.Button("Get bounding boxes")
297
+ download_file = gr.File(
298
+ label="Download CSV",
299
+ visible=True,
300
+ # interactive=True
301
+ )
302
+ with gr.Row():
303
+ with gr.Column(scale=1):
304
+ obj_count = gr.Textbox(label="Object count")
305
+ # button_save = gr.Button("save bounding boxes")
306
+ with gr.Column(scale=1):
307
+ obj_size = gr.Textbox(value = "Has the scale size been set?" ,label="Mean size")
308
+
309
+ histogram = gr.Plot()
310
+ table = gr.DataFrame(
311
+ max_height=500,
312
+ )
313
+ # table = gr.Textbox(label="Status", value=1)
314
+
315
+ json_data = gr.JSON(value={}, visible=False)
316
+
317
+
318
+ # %% T2: event handlers |
319
+ json_boxes = button_find.click(
320
+ fn=find_boxes_json,
321
+ inputs=[image_input],
322
+ outputs=[annotator, json_data]
323
+ )
324
+
325
+ button_find.click(
326
+ fn=lambda: gr.Tabs(selected=1),
327
+ outputs=tabs
328
+ )
329
+
330
+ json_boxes = load_annot.upload(
331
+ fn=load_coco_boxes,
332
+ inputs=[image_input, load_annot],
333
+ outputs=[annotator, json_data]
334
+ )
335
+
336
+ # button_find.click(
337
+ # fn=change_tab,
338
+ # # inputs=[annotator, image_input],
339
+ # outputs=tabs
340
+ # )
341
+
342
+ # annotator.change(
343
+ # json_boxes = button_get.click(
344
+ json_boxes = annotator.change(
345
+ fn=selector.save_scaled_boxes,
346
+ inputs= [annotator],
347
+ outputs= table
348
+ )
349
+
350
+ table.change(
351
+ fn=create_histogram,
352
+ inputs=[table],
353
+ outputs=[histogram]
354
+ )
355
+
356
+ def df_mean_count(df):
357
+ try:
358
+ mean = df["mean_daimeter_mm"].mean().round(2)
359
+ count = len(df)
360
+ return mean, count
361
+ except Exception as e:
362
+ return "Has the scale size been set?", None
363
+
364
+ table.change(
365
+ fn=df_mean_count,
366
+ inputs=[table],
367
+ outputs=[obj_size, obj_count]
368
+ )
369
+
370
+ def save_and_download_table(df, img_name):
371
+ try:
372
+ # Create temporary file with .csv extension
373
+ # with NamedTemporaryFile(delete=False, suffix='.csv') as tmp_file:
374
+ # csv_path = tmp_file.name
375
+ csv_path = Path(img_name).stem +"_boxes.csv"
376
+ df.to_csv(csv_path, index=False)
377
+ return csv_path
378
+ except Exception as e:
379
+ return None
380
+
381
+ table.change(
382
+ fn=save_and_download_table,
383
+ inputs=[table, filename],
384
+ outputs=[download_file]
385
+ )
386
+
387
+
388
+ if __name__ == "__main__":
389
+ demo.launch()
390
+