import json
import os
from pathlib import Path
from tempfile import NamedTemporaryFile

import torch
from torchvision.transforms.functional import pil_to_tensor
import gradio as gr
from gradio.utils import get_upload_folder
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2
import pandas as pd
import numpy as np

from external_models import EfficientNet, MobileNet, ResNet, Swin
from utils import get_preprocessing

# All inference runs on the CPU.
device = "cpu"

# Maps the "model_type" stored in a checkpoint's config to its model class.
models = {
    "mbnet": MobileNet,
    "effnet": EfficientNet,
    "resnet": ResNet,
    "swin": Swin,
}

# Maps the display name shown in the dropdown to the checkpoint file name
# inside the Hugging Face repository.
model_filenames = {
    "EfficientNetV2-S": "efficientnetv2s.pth",
    "MobileNetV3-L": "mobilenetv3l.pth",
    "ResNet101": "resnet101.pth",
    "Swin V2-B": "swinv2b.pth",
}

# Maps a checkpoint's "model_type" back to its display name.
model_names = {
    "effnet": "EfficientNetV2-S",
    "mbnet": "MobileNetV3-L",
    "resnet": "ResNet101",
    "swin": "Swin V2-B",
}


def cropped_img(img: np.ndarray | Image.Image | str) -> tuple[int, int, int, int]:
    """
    Locates the largest edge-connected object in an image (expected to be the
    nematode) and returns a bounding box around it, padded towards a square.

    Parameters
    ----------
    img : np.ndarray | Image.Image | str
        RGB image array, PIL image, or path to an image file

    Returns
    -------
    tuple[int, int, int, int]
        Bounding box ``(x1, y1, x2, y2)`` in pixel coordinates, clamped to the
        image bounds. If no contour is detected, the box covers the whole
        image.
    """
    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    if isinstance(img, Image.Image):
        img = np.array(img)
    rgb = img
    height, width = rgb.shape[:2]
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    # Edge detection.
    edges = cv2.Canny(gray, 25, 25, apertureSize=3, L2gradient=True)

    # Dilating then eroding with the same kernel closes the gaps between
    # nearby edges so the object forms one connected region.
    kernel = np.ones((11, 11), np.uint8)
    edges_dilate = cv2.dilate(edges, kernel, iterations=3)
    edges_erode = cv2.erode(edges_dilate, kernel, iterations=3)

    cnts, _ = cv2.findContours(edges_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(cnts) == 0:
        # Nothing detected (e.g. a blank image): max() below would raise on
        # an empty sequence, so fall back to the uncropped image.
        return 0, 0, width, height

    # Keep only the largest contour and rasterise it as a filled mask.
    cnt = max(cnts, key=cv2.contourArea)
    fill = np.zeros(edges.shape, np.uint8)
    cv2.drawContours(fill, [cnt], -1, 255, cv2.FILLED)

    # Tight box around the white pixels: first/last non-empty column and row.
    col_has_mask = fill.max(0)
    row_has_mask = fill.max(1)
    x1 = int(np.argmax(col_has_mask))
    y1 = int(np.argmax(row_has_mask))
    x2 = int(fill.shape[1] - np.argmax(np.flip(col_has_mask)))
    y2 = int(fill.shape[0] - np.argmax(np.flip(row_has_mask)))

    # Grow the shorter side so the box becomes (approximately) square. When
    # growing would push the near edge below 0, that edge is pinned to 0 and
    # the far edge is extended instead.
    if y2 - y1 < x2 - x1:
        delta = ((x2 - x1) - (y2 - y1)) // 2
        if y1 < delta:
            y2 += 2 * delta - y1
            y1 = 0
        else:
            y1 -= delta
            y2 += delta
    else:
        delta = ((y2 - y1) - (x2 - x1)) // 2
        if x1 < delta:
            x2 += 2 * delta - x1
            x1 = 0
        else:
            x1 -= delta
            x2 += delta

    # Clamp to the image; near a border the box may therefore not be square.
    x2 = min(x2, width)
    y2 = min(y2, height)
    x1 = max(0, x1)
    y1 = max(0, y1)
    return x1, y1, x2, y2


# Module-level application state, mutated by the callbacks in this file.
model, preprocessing, class_to_idx, idx_to_class = None, None, None, None
current_model_type = None
# Image file name -> {"Distribution": {...}, "Classification": {...}}.
results_cache: dict[str, dict] = {}
# Path of the image currently selected in the upload widget, if any.
current_image = None
# Whether images are automatically cropped before preview/prediction.
autocrop = False
# Files created by this app that must be deleted on reset.
temp_files: list[str] = []
# Paths of every uploaded image.
all_images: list[str] = []


def load_model(model_name: str = "EfficientNetV2-S"):
    """
    Downloads the checkpoint for ``model_name`` and installs the model, its
    preprocessing and its class mappings into the module-level state.

    Parameters
    ----------
    model_name : str
        Display name of the architecture; must be a key of
        ``model_filenames``. ``None`` leaves the current model untouched.
    """
    global model, preprocessing, class_to_idx, idx_to_class, current_model_type
    if model_name is None:
        # No selection: nothing to load.
        return

    filepath = hf_hub_download(
        repo_id="VikramR/NematodeClassification",
        filename=model_filenames[model_name],
    )
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; this is only safe while the repository is trusted.
    (model_state, _, _, _, _, _, config, class_to_idx, _) = torch.load(
        filepath, map_location=device
    )
    current_model_type = config["model_type"]
    model = models[current_model_type](config).to(device)
    model.load_state_dict(model_state)
    model = model.eval()
    idx_to_class = {idx: img_cls for img_cls, idx in class_to_idx.items()}
    preprocessing = get_preprocessing(current_model_type)


def display_model():
    """
    Returns the status text naming the currently loaded model.
    """
    model_name = model_names[current_model_type]
    return f"Current Model Type: {model_name}. Use dropdown on the right to change it."
def clear():
    """
    Resets global state and deletes uploaded images and generated files.

    Safe to call repeatedly: it is wired to both the upload widget's clear
    event and the session unload event, so it may run more than once.
    """
    global results_cache, current_image
    results_cache = {}
    current_image = None
    for file in (*all_images, *temp_files):
        try:
            os.remove(file)
        except FileNotFoundError:
            # Already deleted (e.g. by an earlier reset); nothing left to do.
            pass
    # Forget the paths so a later reset does not try to delete them again.
    all_images.clear()
    temp_files.clear()


@torch.no_grad()
def run_image(img: Image.Image):
    """
    Runs the currently loaded model on a single image.

    Parameters
    ----------
    img : Image.Image
        Image to classify

    Returns
    -------
    tuple[dict, tuple[float, str]]
        A dict with parallel "Probability" and "Class" lists covering every
        class, and the ``(probability, label)`` of the most likely class.
    """
    batch = pil_to_tensor(img)[None].to(device)  # [None] adds a batch axis
    batch = preprocessing(batch)
    logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    top_prob, top_idx = torch.max(probs, dim=0)

    n_classes = len(class_to_idx)
    results = {
        "Probability": [float(probs[i].item()) for i in range(n_classes)],
        "Class": [idx_to_class[i] for i in range(n_classes)],
    }
    return results, (top_prob.item(), idx_to_class[top_idx.item()])


def _store_prediction(image_name: str, result: dict, prob: float, label: str) -> None:
    """
    Records one image's prediction in ``results_cache`` under its file name.
    """
    results_cache[image_name] = {
        "Distribution": dict(zip(result["Class"], result["Probability"])),
        "Classification": {"Probability": prob, "Label": label},
    }


def prev_crop_preview() -> str | None:
    """
    Writes the preview of the currently selected image (auto-cropped when
    enabled) to a temporary PNG and returns its path, or None when no image
    is selected.
    """
    if current_image is None:
        return None

    img = Image.open(current_image).convert("RGB")
    if autocrop:
        img = img.crop(cropped_img(img))

    # delete=False: the file must outlive this function so the UI can read
    # it; it is tracked in temp_files and removed by clear().
    with NamedTemporaryFile(
        mode="wb", dir=get_upload_folder(), suffix=".png", delete=False
    ) as f:
        img.save(f, format="PNG")
        temp_files.append(f.name)
    return f.name


def predict(img: str) -> gr.BarPlot:
    """
    Classifies the preview image, caches the result under the selected
    image's file name and returns a bar plot of the class probabilities.

    Parameters
    ----------
    img : str
        Path of the preview image

    Raises
    ------
    gr.Error
        If there is no preview image or no image is currently selected.
    """
    if img is None or current_image is None:
        raise gr.Error("Select an image before classifying.")

    result, (prob, label) = run_image(Image.open(img).convert("RGB"))
    df = pd.DataFrame(result)
    _store_prediction(Path(current_image).name, result, prob, label)
    return gr.BarPlot(
        df, x="Class", y="Probability", tooltip=class_to_idx.keys(), y_lim=(0, 1)
    )


def predict_all(progress_bar=gr.Progress()):
    """
    Classifies every uploaded image (auto-cropped when enabled) and caches
    each result under the image's file name.

    The ``gr.Progress()`` default is how Gradio injects its progress tracker.
    """
    for path in progress_bar.tqdm(all_images, desc="Running images"):
        image = Image.open(path).convert("RGB")
        if autocrop:
            image = image.crop(cropped_img(image))
        result, (prob, label) = run_image(image)
        _store_prediction(Path(path).name, result, prob, label)
    return "All images predicted successfully."


def get_results_cache():
    """
    Returns the cached predictions for display.
    """
    return results_cache


def save_results():
    """
    Dumps the cached predictions to a temporary JSON file and returns its
    path. The file is tracked in temp_files so clear() can delete it.
    """
    with NamedTemporaryFile(
        "w",
        delete=False,
        prefix="model_predictions_",
        suffix=".json",
    ) as f:
        json.dump(results_cache, f, indent=4)
        temp_files.append(f.name)
    return f.name


def select_image(files, sd: gr.SelectData):
    """
    Remembers and returns the path of the file clicked in the upload widget.
    """
    global current_image
    current_image = files[sd.index].name
    return current_image


def show_crop_panel():
    """
    Returns the currently selected image path for the cropping panel.
    """
    return current_image


def upload_files(files):
    """
    Stores the uploaded files as the set of images to predict.
    """
    global all_images
    # Copy so that clear() emptying the list cannot affect the caller's list.
    all_images = list(files or [])


def toggle_autocrop(res):
    """
    Enables or disables automatic cropping.
    """
    global autocrop
    autocrop = res


def show_preview(x):
    """
    Returns the composite image from the image editor's value; used to
    update the preview when the crop button is clicked.
    """
    return x["composite"]


def show_current_filename():
    """
    Returns the crop-panel instructions followed by the selected file name.
    """
    orig_msg = "Crop Image Here (Optional), then click Run to Predict"
    current_img_name = Path(current_image).name
    return f"{orig_msg}\n\nCurrent File: {current_img_name}"


# NOTE(review): the nesting of the layout below was reconstructed from a
# source whose indentation was lost; confirm widget placement (in particular
# the bar plot's column) against the deployed app.
with gr.Blocks() as demo:
    demo.load(load_model)

    with gr.Row():
        gr.Textbox(
            "Only use this application on the following classes of nematodes: "
            + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
            "Pratylenchus, Trichodorus, and Tylenchorhynchus.\n\n"
            + "Only use images containing a single nematode.\n\n"
            + "SCROLL DOWN TO DOWNLOAD THE PREDICTIONS FOR YOUR IMAGES!",
            text_align="center",
            label="DISCLAIMER",
        )

    with gr.Row():
        model_text = gr.Textbox(
            "Default model: EfficientNetV2-S. To choose a different model, choose one from the dropdown on the right",
            label="Current Model",
        )
        model_select = gr.Dropdown(
            choices=["EfficientNetV2-S", "MobileNetV3-L", "ResNet101", "Swin V2-B"],
            value="EfficientNetV2-S",
            label="Select Model Architecture (May take a few moments, check text on the left to confirm your model has loaded)",
        )

    with gr.Row():
        # Left column: upload and batch prediction.
        with gr.Column():
            gr.Textbox(
                "Upload Images, then Select Each One to Crop & Run",
                show_label=False,
            )
            files = gr.File(file_types=["image"], file_count="multiple")
            batch_predict = gr.Button("Predict All", variant="stop")
            prediction_progress = gr.Textbox(
                "Prediction Progress Bar", show_label=False
            )
        # Middle column: manual / automatic cropping.
        with gr.Column():
            mid_col_text = gr.Textbox(
                "Crop Image Here (Optional), then Click Run to Predict",
                show_label=False,
            )
            autocrop_toggle = gr.Checkbox(value=False, label="Automatic Cropping")
            cropper = gr.ImageEditor(
                type="filepath",
                sources=None,
                layers=False,
                brush=False,
                mirror_webcam=False,
            )
            crop = gr.Button("Crop")
        # Right column: preview of the network input and single prediction.
        with gr.Column():
            gr.Textbox(
                "Image Preview (What will be run through network)",
                show_label=False,
            )
            preview = gr.Image(
                sources=None,
                type="filepath",
                height=250,
                interactive=False,
                mirror_webcam=False,
            )
            classify = gr.Button("Classify", variant="stop")
            plot = gr.BarPlot()

    with gr.Row():
        gr.Textbox(
            "Here are the predicted labels for your images in JSON format",
            label="Predictions",
        )

    with gr.Row():
        with gr.Column():
            json_results = gr.JSON()
        with gr.Column():
            download = gr.DownloadButton("Download Predictions", variant="primary")

    # Event wiring.
    download.click(save_results, outputs=download)
    model_select.change(load_model, inputs=model_select).then(
        display_model, outputs=model_text
    )
    files.upload(upload_files, inputs=files)
    files.select(select_image, inputs=files, outputs=cropper).then(
        show_current_filename,
        outputs=mid_col_text,
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    autocrop_toggle.change(toggle_autocrop, inputs=autocrop_toggle).then(
        show_crop_panel, outputs=cropper
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    batch_predict.click(predict_all, outputs=prediction_progress).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    files.clear(clear).then(get_results_cache, outputs=json_results).then(
        save_results, outputs=download
    )
    crop.click(show_preview, inputs=cropper, outputs=preview)
    classify.click(predict, inputs=preview, outputs=plot).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    demo.unload(clear)


if __name__ == "__main__":
    demo.launch()