import os

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from fastai.vision.all import *
from icevision.all import *
from icevision.models.checkpoint import *

# Gather example image paths while walking the directory. Building the list
# inside the walk (instead of reusing the leaked loop variable `files`
# afterwards) keeps images from every visited directory, joins each file with
# its own directory, and leaves `examples` defined (empty) if the directory is
# missing.
print("Loading images")
examples = []
for root, dirs, files in os.walk(r"sample_images/"):
    for filename in files:
        print(filename)
        examples.append(os.path.join(root, filename))

print("Loading classifier")
classifier = load_learner("models/learner.pkl")
classifier_labels = classifier.dls.vocab

print("Loading detector")
checkpoint_path = "models/model_checkpoint.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model = checkpoint_and_model["model"]
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]
img_size = checkpoint_and_model["img_size"]
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])


def draw_eyes(img):
    """Run the detector on *img* and annotate every detected box.

    For each detection a rectangle, a filled centre dot and a ``w:<w> h:<h>``
    label are drawn. Boxes labelled ``"out"`` use colour (255, 0, 0); all other
    labels use (8, 39, 245). Only detections above a 0.5 threshold are kept.

    Args:
        img: Input image as passed by ``predict`` (a fastai ``PILImage``).

    Returns:
        A PIL image with the annotations drawn, or *img* unchanged when nothing
        is detected.
    """
    pred_dict = model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=0.5,
        display_label=False,
        color_map={"in": "#FF4040", "out": "#FFC71E"},
    )
    detection = pred_dict["detection"]
    bboxes = detection["bboxes"]
    if len(bboxes) == 0:
        return img

    # Convert once, draw every box, convert back once (the original round-tripped
    # PIL <-> ndarray for each box).
    canvas = np.array(img)
    for bbox, label in zip(bboxes, detection["labels"]):
        # OpenCV drawing functions require integer pixel coordinates.
        x, y, w, h = (int(v) for v in bbox.xywh)
        xmin, ymin, xmax, ymax = bbox.xyxy
        center = (int((xmin + xmax) / 2), int((ymin + ymax) / 2))
        color_value = (255, 0, 0) if label == "out" else (8, 39, 245)

        cv2.rectangle(canvas, (x, y), (x + w, y + h), color_value, 2)
        cv2.circle(canvas, center, 5, color_value, -1)
        cv2.putText(
            canvas,
            f"w:{w} h:{h}",
            (x, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            color_value,
            2,
            cv2.LINE_AA,
        )
    return Image.fromarray(canvas)


def predict(img):
    """Classify an electrode image and return its annotated detection view.

    Args:
        img: Image supplied by the Gradio input component.

    Returns:
        A tuple ``(confidences, annotated)``. ``confidences`` maps each
        classifier label to its probability as a float. ``annotated`` is the
        image returned by ``draw_eyes``.
    """
    img = PILImage.create(img)
    _, _, probs = classifier.predict(img)
    annotated = draw_eyes(img)
    confidences = {
        label: float(prob) for label, prob in zip(classifier_labels, probs)
    }
    return confidences, annotated


title = "NSWX Electrode Classifier"
description = "Upload an image of a bare electrode or select from the examples below"
interpretation = "default"
# NOTE(review): the original literal held only a raw line break inside a normal
# string (a syntax error); restore the intended article text here if any.
article = ""
enable_queue = False

gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input image"),
    outputs=[
        gr.outputs.Label(num_top_classes=5, label="Electrode Class"),
        gr.outputs.Image(type="pil", label="WE Dimensions"),
    ],
    title=title,
    description=description,
    article=article,
    # Pass None rather than an empty list when no sample images were found.
    examples=examples or None,
    interpretation=interpretation,
    enable_queue=enable_queue,
    allow_flagging="manual",
    flagging_options=[
        "This should be OK",
        "This should be KIV_COL",
        "This should be KIV_CMT",
        "This should be NG_DIM",
        "This should be NG_MSA",
    ],
    theme="grass",
    css=".output-image, .input-image, .image-preview {height: 600px !important} ",
).launch(server_name="0.0.0.0")