dnth commited on
Commit
44260d0
·
1 Parent(s): e252420
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ print("Reinstalling mmcv")
4
+ subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "mmcv-full==1.3.17"])
5
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "mmcv-full==1.3.17", "-f", "https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html"])
6
+ print("mmcv install complete")
7
+
8
+ ## Only works if we reinstall mmcv here.
9
+
10
+ from gradio.outputs import Label
11
+ from icevision.all import *
12
+ from icevision.models.checkpoint import *
13
+ import PIL
14
+ import gradio as gr
15
+ import os
16
+
17
+ # Load model
18
+ checkpoint_path = "models/model_checkpoint.pth"
19
+ checkpoint_and_model = model_from_checkpoint(checkpoint_path)
20
+ model = checkpoint_and_model["model"]
21
+ model_type = checkpoint_and_model["model_type"]
22
+ class_map = checkpoint_and_model["class_map"]
23
+
24
+ # Transforms
25
+ img_size = checkpoint_and_model["img_size"]
26
+ valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
27
+
28
+ for root, dirs, files in os.walk(r"sample_images/"):
29
+ for filename in files:
30
+ print("Loading sample image:", filename)
31
+
32
+
33
+ # Populate examples in Gradio interface
34
+ example_images = [["sample_images/" + file] for file in files]
35
+ # Columns: Input Image | Label | Box | Detection Threshold
36
+ examples = [
37
+ [example_images[0], False, True, 0.5],
38
+ [example_images[1], True, True, 0.5],
39
+ [example_images[2], False, True, 0.7],
40
+ [example_images[3], True, True, 0.7],
41
+ [example_images[4], False, True, 0.5],
42
+ [example_images[5], False, True, 0.5],
43
+ [example_images[6], False, True, 0.6],
44
+ [example_images[7], False, True, 0.6],
45
+ ]
46
+
47
+
48
+ def show_preds(input_image, display_label, display_bbox, detection_threshold):
49
+ if detection_threshold == 0:
50
+ detection_threshold = 0.5
51
+ img = PIL.Image.fromarray(input_image, "RGB")
52
+ pred_dict = model_type.end2end_detect(
53
+ img,
54
+ valid_tfms,
55
+ model,
56
+ class_map=class_map,
57
+ detection_threshold=detection_threshold,
58
+ display_label=display_label,
59
+ display_bbox=display_bbox,
60
+ return_img=True,
61
+ font_size=16,
62
+ label_color="#FF59D6",
63
+ )
64
+ return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
65
+
66
+
67
+ # display_chkbox = gr.inputs.CheckboxGroup(["Label", "BBox"], label="Display", default=True)
68
+ display_chkbox_label = gr.inputs.Checkbox(label="Label", default=False)
69
+ display_chkbox_box = gr.inputs.Checkbox(label="Box", default=True)
70
+ detection_threshold_slider = gr.inputs.Slider(
71
+ minimum=0, maximum=1, step=0.1, default=0.5, label="Detection Threshold"
72
+ )
73
+ outputs = [
74
+ gr.outputs.Image(type="pil", label="RetinaNet Inference"),
75
+ gr.outputs.Textbox(type="number", label="Microalgae Count"),
76
+ ]
77
+
78
+ article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
79
+
80
+ # Option 1: Get an image from local drive
81
+ gr_interface = gr.Interface(
82
+ fn=show_preds,
83
+ inputs=[
84
+ "image",
85
+ display_chkbox_label,
86
+ display_chkbox_box,
87
+ detection_threshold_slider,
88
+ ],
89
+ outputs=outputs,
90
+ title="Microalgae Detector with RetinaNet",
91
+ description="This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use.",
92
+ article=article,
93
+ examples=examples,
94
+ )
95
+ # # Option 2: Grab an image from a webcam
96
+ # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=False)
97
+ # # Option 3: Continuous image stream from the webcam
98
+ # gr_interface = gr.Interface(fn=show_preds, inputs=["webcam", display_chkbox_label, display_chkbox_box, detection_threshold_slider], outputs=outputs, title='IceApp - COCO', live=True)
99
+ gr_interface.launch(inline=False, share=False, debug=True)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ mmdet==2.19.0
2
+ gradio
3
+ icevision[all]
4
+ mmcv-full==1.3.17 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html
5
+
sample_images/IMG_20191212_151351.jpg ADDED
sample_images/IMG_20191212_151438.jpg ADDED
sample_images/IMG_20191212_151559.jpg ADDED
sample_images/IMG_20191212_151714.jpg ADDED
sample_images/IMG_20191212_151844.jpg ADDED
sample_images/IMG_20191212_153339.jpg ADDED
sample_images/IMG_20191212_153420.jpg ADDED
sample_images/IMG_20191212_153457.jpg ADDED
sample_images/IMG_20191212_153614.jpg ADDED
sample_images/IMG_20191212_154100.jpg ADDED
sample_images/IMG_20191212_154209.jpg ADDED
sample_images/IMG_20191212_154330.jpg ADDED
sample_images/IMG_20191212_154452.jpg ADDED
sample_images/IMG_20191212_154600.jpg ADDED