|
|
import os |
|
|
import gradio as gr |
|
|
from fastai.vision.all import * |
|
|
|
|
|
from icevision.all import * |
|
|
from icevision.models.checkpoint import * |
|
|
|
|
|
print("Loading images")
# Collect the example images that sit directly inside sample_images/.
# The names are later prefixed with "sample_images/" to build the Gradio
# examples, so only top-level files are valid here. os.walk would descend
# into sub-directories and leave just the last directory's listing in
# ``files``. Sorting gives a stable example order across runs.
try:
    with os.scandir("sample_images/") as entries:
        files = sorted(entry.name for entry in entries if entry.is_file())
except FileNotFoundError:
    # Keep ``files`` defined so the app can still start (without examples)
    # instead of failing later with a NameError.
    print("sample_images/ not found; no examples will be shown")
    files = []

for filename in files:
    print(filename)
|
|
|
|
|
print("Loading classifier")
classifier_path = "models/learner.pkl"
# NOTE: fastai's load_learner unpickles the exported Learner, so only point it
# at trusted files.
classifier = load_learner(classifier_path)
# Label names from the Learner's DataLoaders; predict() pairs these with the
# classifier's output probabilities.
classifier_labels = classifier.dls.vocab
|
|
|
|
|
print("Loading detector")
checkpoint_path = "models/model_checkpoint.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
# Pull the pieces draw_eyes() needs out of the checkpoint bundle, in the same
# order the original lookups were performed.
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)
# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalize (presumably matching the validation pipeline used in training).
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])
|
|
|
|
|
|
|
|
def draw_eyes(img):
    """Run the detector on ``img`` and return a copy annotated with every detection.

    For each detection above the 0.5 threshold this draws:

    - the bounding box;
    - a filled dot at the box centre;
    - a ``w:<width> h:<height>`` caption just above the box.

    Detections labelled ``"out"`` use colour ``(255, 0, 0)``; all others use
    ``(8, 39, 245)``. The tuples are in the canvas's channel order, which is RGB
    for an RGB PIL input (assumption; confirm for other image modes).

    Args:
        img: Image accepted by ``model_type.end2end_detect`` and convertible
            with ``np.array`` (``predict`` passes a fastai ``PILImage``).

    Returns:
        A PIL image with all detections drawn. With no detections this is
        simply a copy of the input.
    """
    pred_dict = model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=0.5,
        display_label=False,
        color_map={"in": "#FF4040", "out": "#FFC71E"},
    )
    # Only the raw detections are used; the drawing is done here with OpenCV.
    detection = pred_dict["detection"]

    # Draw every detection onto ONE shared canvas. Rebuilding the canvas from
    # ``img`` for each box would keep only the last box.
    canvas = np.array(img)
    for bbox, label in zip(detection["bboxes"], detection["labels"]):
        xmin, ymin, xmax, ymax = bbox.xyxy
        _, _, w, h = bbox.xywh  # raw width/height, shown in the caption

        # OpenCV drawing functions require integer pixel coordinates.
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        center = (int((xmin + xmax) / 2), int((ymin + ymax) / 2))

        color_value = (255, 0, 0) if label == "out" else (8, 39, 245)

        canvas = cv2.rectangle(canvas, top_left, bottom_right, color_value, 2)
        canvas = cv2.circle(canvas, center, 5, color_value, -1)
        canvas = cv2.putText(
            canvas,
            f"w:{w} h:{h}",
            (top_left[0], top_left[1] - 10),  # 10 px above the box
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            color_value,
            2,
            cv2.LINE_AA,
        )

    return Image.fromarray(canvas)
|
|
|
|
|
|
|
|
def predict(img):
    """Classify ``img`` and return class probabilities plus an annotated image.

    Args:
        img: Anything ``PILImage.create`` accepts, e.g. the image value the
            Gradio input component passes in.

    Returns:
        A ``(probabilities, annotated)`` tuple:

        - a dict mapping each label in ``classifier_labels`` to its
          probability as a plain ``float``;
        - the image returned by ``draw_eyes``.
    """
    img = PILImage.create(img)
    # Only the per-class probabilities are needed; the decoded label and its
    # index are discarded.
    _, _, probs = classifier.predict(img)
    annotated = draw_eyes(img)
    # float() converts each tensor element into a JSON-friendly Python float.
    confidences = {
        label: float(prob) for label, prob in zip(classifier_labels, probs)
    }
    return confidences, annotated
|
|
|
|
|
|
|
|
# ---- Gradio UI configuration -------------------------------------------------
title = "NSWX Electrode Classifier"
description = "Upload an image of a bare electrode or select from the examples below"
article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"
interpretation = "default"
enable_queue = False
# One example per sample image file name gathered at start-up (``files``).
examples = [f"sample_images/{name}" for name in files]

demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input image"),
    # Order matches predict()'s return value: (probabilities, annotated image).
    outputs=[
        gr.outputs.Label(num_top_classes=5, label="Electrode Class"),
        gr.outputs.Image(type="pil", label="WE Dimensions"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    interpretation=interpretation,
    enable_queue=enable_queue,
    allow_flagging="manual",
    flagging_options=[
        "This should be OK",
        "This should be KIV_COL",
        "This should be KIV_CMT",
        "This should be NG_DIM",
        "This should be NG_MSA",
    ],
    theme="grass",
    css=".output-image, .input-image, .image-preview {height: 600px !important} ",
)
# 0.0.0.0 listens on every network interface, not just localhost.
demo.launch(server_name="0.0.0.0")
|
|
|