dnth commited on
Commit
4e1708d
·
1 Parent(s): 7be1947

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -18
app.py CHANGED
@@ -1,39 +1,101 @@
 
1
  import gradio as gr
2
  from fastai.vision.all import *
3
 
4
- learn = load_learner("models/learner.pkl")
 
5
 
6
- labels = learn.dls.vocab
 
 
 
7
 
8
- def predict(img):
9
- img = PILImage.create(img)
10
- pred, pred_idx, probs = learn.predict(img)
11
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
12
 
 
 
 
 
 
 
13
 
14
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- for root, dirs, files in os.walk(r"sample_images/"):
17
- for filename in files:
18
- print(filename)
19
 
20
- title = "NSWX Bare Electrode Classifier"
21
- description = "This model (MobilenetV3) can classify bare electrodes into 5 classes from the Ok, KIV, NG group"
22
  interpretation = "default"
23
  examples = ["sample_images/" + file for file in files]
24
  article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
25
- enable_queue = True
26
 
27
  gr.Interface(
28
  fn=predict,
29
- inputs=gr.inputs.Image(shape=(640, 640)),
30
- outputs=gr.outputs.Label(num_top_classes=5),
 
 
 
31
  title=title,
32
  description=description,
33
  article=article,
34
  examples=examples,
35
  interpretation=interpretation,
36
  enable_queue=enable_queue,
37
- capture_session=True,
38
- theme="grass"
39
- ).launch()
 
 
 
1
+ import os
2
  import gradio as gr
3
  from fastai.vision.all import *
4
 
5
+ from icevision.all import *
6
+ from icevision.models.checkpoint import *
7
 
8
+ print("Loading images")
9
+ for root, dirs, files in os.walk(r"sample_images/"):
10
+ for filename in files:
11
+ print(filename)
12
 
13
+ print("Loading classifier")
14
+ classifier = load_learner("models/learner.pkl")
15
+ classifier_labels = classifier.dls.vocab
 
16
 
17
+ print("Loading detector")
18
+ checkpoint_path = "eye_detection/models/model_checkpoint.pth"
19
+ checkpoint_and_model = model_from_checkpoint(checkpoint_path)
20
+ model = checkpoint_and_model["model"]
21
+ model_type = checkpoint_and_model["model_type"]
22
+ class_map = checkpoint_and_model["class_map"]
23
 
24
+ img_size = checkpoint_and_model["img_size"]
25
+ valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
26
+
27
+
28
+ def draw_eyes(img):
29
+ pred_dict = model_type.end2end_detect(
30
+ img,
31
+ valid_tfms,
32
+ model,
33
+ class_map=class_map,
34
+ detection_threshold=0.5,
35
+ display_label=False,
36
+ color_map={"in": "#FF4040", "out": "#FFC71E"},
37
+ )
38
+
39
+ # Draw bbox with cv
40
+ for i, bbox in enumerate(pred_dict["detection"]["bboxes"]):
41
+ x, y, w, h = pred_dict["detection"]["bboxes"][i].xywh
42
+ xmin, ymin, xmax, ymax = pred_dict["detection"]["bboxes"][i].xyxy
43
+ center = (int((xmin + xmax) / 2), int((ymin + ymax) / 2))
44
+
45
+ if pred_dict["detection"]["labels"][i] == "out":
46
+ color_value = (255, 0, 0)
47
+ else:
48
+ color_value = (8, 39, 245)
49
+
50
+ image = cv2.rectangle(np.array(img), (x, y), (x + w, y + h), color_value, 2)
51
+ image = cv2.circle(image, center, 5, color_value, -1)
52
+ image = cv2.putText(
53
+ image,
54
+ f"w:{w} h:{h}",
55
+ (x, y - 10),
56
+ cv2.FONT_HERSHEY_SIMPLEX,
57
+ 1,
58
+ color_value,
59
+ 2,
60
+ cv2.LINE_AA,
61
+ )
62
+
63
+ img = Image.fromarray(image)
64
+
65
+ return img
66
+
67
+
68
+ def predict(img):
69
+ img = PILImage.create(img)
70
+ pred, pred_idx, probs = classifier.predict(img)
71
+ img = draw_eyes(img)
72
+ return {
73
+ classifier_labels[i]: float(probs[i]) for i in range(len(classifier_labels))
74
+ }, img
75
 
 
 
 
76
 
77
+ title = "NSWX Electrode Classifier"
78
+ description = "Upload an image of a bare electrode or select from the examples below"
79
  interpretation = "default"
80
  examples = ["sample_images/" + file for file in files]
81
  article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
82
+ enable_queue = False
83
 
84
  gr.Interface(
85
  fn=predict,
86
+ inputs=gr.inputs.Image( label="Input image"),
87
+ outputs=[
88
+ gr.outputs.Label(num_top_classes=5, label="Electrode Class"),
89
+ gr.outputs.Image(type="pil", label="WE Dimensions"),
90
+ ],
91
  title=title,
92
  description=description,
93
  article=article,
94
  examples=examples,
95
  interpretation=interpretation,
96
  enable_queue=enable_queue,
97
+ allow_flagging="manual",
98
+ flagging_options=["This should be OK", "This should be KIV_COL", "This should be KIV_CMT", "This should be NG_DIM", "This should be NG_MSA"],
99
+ theme="grass",
100
+ css = ".output-image, .input-image, .image-preview {height: 600px !important} ",
101
+ ).launch(server_name="0.0.0.0")