Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.transforms.functional import pil_to_tensor | |
| import gradio as gr | |
| from gradio.utils import get_upload_folder | |
| from huggingface_hub import hf_hub_download | |
| from external_models import EfficientNet, MobileNet, ResNet, Swin | |
| from utils import get_preprocessing | |
| from pathlib import Path | |
| from PIL import Image | |
| from tempfile import NamedTemporaryFile | |
| import json | |
| import os | |
| import cv2 | |
| import pandas as pd | |
| import numpy as np | |
# Device string handed to ``torch.load(map_location=...)`` and ``.to(...)``.
device = "cpu"

# Checkpoint ``config["model_type"]`` value -> class used to build the network.
models = {
    "mbnet": MobileNet,
    "effnet": EfficientNet,
    "resnet": ResNet,
    "swin": Swin,
}

# Display name (as offered in the UI dropdown) -> checkpoint filename on the Hub.
model_filenames = {
    "EfficientNetV2-S": "efficientnetv2s.pth",
    "MobileNetV3-L": "mobilenetv3l.pth",
    "ResNet101": "resnet101.pth",
    "Swin V2-B": "swinv2b.pth",
}

# Checkpoint ``config["model_type"]`` value -> display name shown to the user.
model_names = {
    "effnet": "EfficientNetV2-S",
    "mbnet": "MobileNetV3-L",
    "resnet": "ResNet101",
    "swin": "Swin V2-B",
}
def _expand_span(lo: int, hi: int, amount: int) -> tuple[int, int]:
    """Grow the span [lo, hi) by ``2 * amount``, pushing growth to ``hi`` when ``lo`` would go below 0."""
    if lo < amount:
        return 0, hi + 2 * amount - lo
    return lo - amount, hi + amount


def cropped_img(img: np.ndarray | Image.Image | str):
    """
    Locate the largest edge-connected object in an image (the nematode) and
    return a roughly square bounding box around it.

    Parameters
    ----------
    img : np.ndarray | PIL.Image.Image | str
        RGB image array, PIL image, or path to an image file. Paths are
        opened and converted to RGB.

    Returns
    -------
    tuple[int, int, int, int]
        ``(x1, y1, x2, y2)`` crop box in pixel coordinates, clamped to the
        image bounds. When no contour is detected, the box of the whole
        image is returned.
    """
    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    if isinstance(img, Image.Image):
        img = np.array(img)
    rgb = img
    height, width = rgb.shape[:2]
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    # Edge detection.
    edges = cv2.Canny(gray, 25, 25, apertureSize=3, L2gradient=True)

    # Dilate then erode with the same kernel so nearby edge fragments merge
    # into one solid blob before contours are extracted.
    kernel = np.ones((11, 11), np.uint8)
    edges_dilate = cv2.dilate(edges, kernel, iterations=3)
    edges_erode = cv2.erode(edges_dilate, kernel, iterations=3)

    cnts, _ = cv2.findContours(edges_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(cnts) == 0:
        # Nothing detected (e.g. a blank image): fall back to the full frame
        # instead of letting max() raise on an empty sequence.
        return 0, 0, int(width), int(height)

    # Keep only the largest contour and take its axis-aligned extent.
    cnt = max(cnts, key=cv2.contourArea)
    bx, by, bw, bh = cv2.boundingRect(cnt)
    x1, y1, x2, y2 = bx, by, bx + bw, by + bh

    # Pad the shorter side so the box becomes (approximately) square.
    box_w, box_h = x2 - x1, y2 - y1
    if box_h < box_w:
        y1, y2 = _expand_span(y1, y2, (box_w - box_h) // 2)
    else:
        x1, x2 = _expand_span(x1, x2, (box_h - box_w) // 2)

    # Clamp to the image bounds.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(x2, width), min(y2, height)
    return int(x1), int(y1), int(x2), int(y2)
# --- Mutable module-level state shared by the callbacks below ---

# Populated by load_model().
model = None
preprocessing = None
class_to_idx = None
idx_to_class = None
current_model_type = None

# Image file name -> {"Distribution": {...}, "Classification": {...}}.
results_cache: dict[str, dict] = {}

# Path of the image currently selected in the upload list (None if none).
current_image = None

# Whether automatic cropping is enabled.
autocrop = False

# Files created by this app (previews, JSON exports); removed by clear().
temp_files: list[str] = []

# Every uploaded image; consumed by predict_all() and removed by clear().
all_images: list[str] = []
def load_model(model_name: str = "EfficientNetV2-S"):
    """
    Download the checkpoint for ``model_name`` and install it as the active model.

    Rebinds the module globals ``model``, ``preprocessing``, ``class_to_idx``,
    ``idx_to_class`` and ``current_model_type``. Does nothing when
    ``model_name`` is None.

    Parameters
    ----------
    model_name : str
        A key of ``model_filenames`` (the display name of the architecture).
    """
    global model, preprocessing, class_to_idx, idx_to_class, current_model_type

    # NOTE(review): the flattened source does not show which statements sat
    # inside the ``if model_name is not None`` guard; this assumes all of them.
    if model_name is None:
        return

    checkpoint_path = hf_hub_download(
        repo_id="VikramR/NematodeClassification",
        filename=model_filenames[model_name],
    )
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The checkpoint is a 9-item sequence; only the weights, the config and
    # the class mapping are used here.
    (model_state, _, _, _, _, _, config, class_to_idx, _) = checkpoint

    current_model_type = config["model_type"]
    model = models[current_model_type](config).to(device)
    model.load_state_dict(model_state)
    model = model.eval()

    idx_to_class = {index: name for name, index in class_to_idx.items()}
    preprocessing = get_preprocessing(current_model_type)
def display_model():
    """
    Build the status message naming the currently loaded model.

    Looks up ``current_model_type`` in ``model_names``.
    """
    global current_model_type
    return f"Current Model Type: {model_names[current_model_type]}. Use dropdown on the right to change it."
def clear():
    """
    Reset global state and delete every file tracked by the app.

    Empties ``results_cache``, unsets ``current_image``, removes all paths in
    ``all_images`` and ``temp_files`` from disk, and then empties both lists.
    Safe to call repeatedly: it is wired to both the file-clear and the
    unload events, so a second call must not fail on files already removed.
    """
    global results_cache, current_image, all_images, temp_files
    results_cache = {}
    current_image = None
    for path in (*(all_images or ()), *(temp_files or ())):
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already deleted (by an earlier clear() or by the host); the
            # goal is only that the file is gone, so keep cleaning the rest.
            pass
    # Forget the deleted paths so the next call does not retry them.
    all_images = []
    temp_files = []
def run_image(img: Image.Image):
    """
    Classify one image with the currently loaded model.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to classify.

    Returns
    -------
    tuple[dict, tuple[float, str]]
        ``results`` with two parallel lists, ``"Probability"`` and ``"Class"``
        (one entry per class index), followed by ``(prob, label)`` for the
        highest-probability class.
    """
    global preprocessing, device, model, class_to_idx, idx_to_class
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        batch = pil_to_tensor(img)[None].to(device)  # [None] adds the batch axis
        batch = preprocessing(batch)
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        prob, label = torch.max(probs, dim=0)

    n_classes = len(class_to_idx)
    results = {
        "Probability": [float(probs[i].item()) for i in range(n_classes)],
        "Class": [idx_to_class[i] for i in range(n_classes)],
    }
    return results, (prob.item(), idx_to_class[label.item()])
| def prev_crop_preview() -> str: | |
| """ | |
| Preview for the current cropped image | |
| """ | |
| global autocrop, current_image, temp_files | |
| if current_image is None: | |
| return None | |
| img = Image.open(current_image).convert("RGB") | |
| if autocrop: | |
| box = cropped_img(img) | |
| img = img.crop(box) | |
| with NamedTemporaryFile( | |
| mode="wb", dir=get_upload_folder(), suffix=".png", delete=False | |
| ) as f: | |
| pth = f.name | |
| img.save(f) | |
| temp_files.append(f.name) | |
| return pth | |
def predict(img: str) -> gr.BarPlot:
    """
    Classify the previewed image and record the result for the selected file.

    The result is stored in ``results_cache`` under the file name of
    ``current_image``.

    Parameters
    ----------
    img : str
        Path of the image shown in the preview component.

    Returns
    -------
    gr.BarPlot
        Bar plot of the per-class probabilities.

    Raises
    ------
    gr.Error
        If there is no preview image or no image has been selected.
    """
    global results_cache
    if img is None or current_image is None:
        # Classify was clicked before an image was selected; report it to the
        # user instead of failing inside Image.open / Path.
        raise gr.Error("Select an image before clicking Classify.")

    picture = Image.open(img).convert("RGB")
    result, (prob, label) = run_image(picture)
    df = pd.DataFrame(result)

    current_image_name = Path(current_image).name
    results_cache[current_image_name] = {
        "Distribution": dict(zip(result["Class"], result["Probability"])),
        "Classification": {"Probability": prob, "Label": label},
    }
    # NOTE(review): ``tooltip`` receives the class names here; confirm whether
    # BarPlot expects DataFrame column names instead.
    return gr.BarPlot(
        df, x="Class", y="Probability", tooltip=class_to_idx.keys(), y_lim=(0, 1)
    )
def predict_all(progress_bar=gr.Progress()):
    """
    Classify every uploaded image and store each result in ``results_cache``.

    Images are cropped with ``cropped_img`` first when ``autocrop`` is on.
    Results are keyed by image file name.

    Parameters
    ----------
    progress_bar : gr.Progress
        Progress tracker; its ``tqdm`` wrapper drives the iteration.

    Returns
    -------
    str
        Completion message.
    """
    global all_images, results_cache
    for image_path in progress_bar.tqdm(all_images, desc="Running images"):
        file_name = Path(image_path).name
        picture = Image.open(image_path).convert("RGB")
        if autocrop:
            picture = picture.crop(cropped_img(picture))
        scores, (prob, label) = run_image(picture)
        results_cache[file_name] = {
            "Distribution": dict(zip(scores["Class"], scores["Probability"])),
            "Classification": {"Probability": prob, "Label": label},
        }
    return "All images predicted successfully."
def get_results_cache():
    """Return the module-level ``results_cache`` dictionary (not a copy)."""
    return results_cache
def save_results():
    """
    Dump ``results_cache`` as indented JSON into a new temporary file.

    The file is recorded in ``temp_files``.

    Returns
    -------
    str
        Path of the written JSON file.
    """
    global results_cache
    # delete=False: the file has to survive closing so it can be downloaded.
    handle = NamedTemporaryFile(
        "w",
        delete=False,
        prefix="model_predictions_",
        suffix=".json",
    )
    with handle:
        json.dump(results_cache, handle, indent=4)
    out_path = handle.name
    temp_files.append(out_path)
    return out_path
def select_image(files, sd: gr.SelectData):
    """
    Remember the file clicked in the upload list and return its name.

    Parameters
    ----------
    files
        The uploaded files; each entry exposes a ``name`` attribute.
    sd : gr.SelectData
        Selection event data; ``sd.index`` picks the clicked entry.
    """
    global current_image
    current_image = files[sd.index].name
    return current_image
def show_crop_panel():
    """Return the currently selected image so the cropper can be refreshed with it."""
    return current_image
def upload_files(files):
    """Store the uploaded files in the global ``all_images`` (replacing any previous value)."""
    global all_images
    all_images = files
def toggle_autocrop(res):
    """Set the global ``autocrop`` flag to ``res`` (the checkbox value)."""
    global autocrop
    autocrop = res
def show_preview(x):
    """Return the ``"composite"`` entry of the cropper's value; used to update the preview when Crop is clicked."""
    composite = x["composite"]
    return composite
def show_current_filename():
    """Return the crop-panel instructions followed by the selected image's file name."""
    file_name = Path(current_image).name
    header = "Crop Image Here (Optional), then click Run to Predict"
    return f"{header}\n\nCurrent File: {file_name}"
# Gradio UI: layout first, then event wiring.
# NOTE(review): the source was flattened, so the Row/Column nesting below is
# reconstructed from the component order -- confirm, in particular, that
# ``plot`` belongs in the third column.
with gr.Blocks() as demo:
    # Run load_model (with its default argument) on the Blocks load event.
    demo.load(load_model)
    with gr.Row():
        gr.Textbox(
            "Only use this application on the following classes of nematodes: "
            + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
            "Pratylenchus, Trichodorus, and Tylenchorhynchus.\n\n"
            + "Only use images containing a single nematode.\n\n"
            + "SCROLL DOWN TO DOWNLOAD THE PREDICTIONS FOR YOUR IMAGES!",
            text_align="center",
            label="DISCLAIMER",
        )
    # Model status text and model picker.
    with gr.Row():
        model_text = gr.Textbox(
            "Default model: EfficientNetV2-S. To choose a different model, choose one from the dropdown on the right",
            label="Current Model",
        )
        model_select = gr.Dropdown(
            choices=["EfficientNetV2-S", "MobileNetV3-L", "ResNet101", "Swin V2-B"],
            value="EfficientNetV2-S",
            label="Select Model Architecture (May take a few moments, check text on the left to confirm your model has loaded)",
        )
    with gr.Row():
        # Upload list and batch prediction.
        with gr.Column():
            gr.Textbox(
                "Upload Images, then Select Each One to Crop & Run",
                show_label=False,
            )
            files = gr.File(file_types=["image"], file_count="multiple")
            batch_predict = gr.Button("Predict All", variant="stop")
            prediction_progress = gr.Textbox(
                "Prediction Progress Bar", show_label=False
            )
        # Manual / automatic cropping of the selected image.
        with gr.Column():
            mid_col_text = gr.Textbox(
                "Crop Image Here (Optional), then Click Run to Predict",
                show_label=False,
            )
            autocrop_toggle = gr.Checkbox(value=False, label="Automatic Cropping")
            cropper = gr.ImageEditor(
                type="filepath",
                sources=None,
                layers=False,
                brush=False,
                mirror_webcam=False,
            )
            crop = gr.Button("Crop")
        # Preview of the image that will be classified, plus the result plot.
        with gr.Column():
            gr.Textbox(
                "Image Preview (What will be run through network)",
                show_label=False,
            )
            preview = gr.Image(
                sources=None,
                type="filepath",
                height=250,
                interactive=False,
                mirror_webcam=False,
            )
            classify = gr.Button("Classify", variant="stop")
            plot = gr.BarPlot()
    with gr.Row():
        gr.Textbox(
            "Here are the predicted labels for your images in JSON format",
            label="Predictions",
        )
    # Accumulated predictions and the download button for them.
    with gr.Row():
        with gr.Column():
            json_results = gr.JSON()
        with gr.Column():
            download = gr.DownloadButton("Download Predictions", variant="primary")

    # --- Event wiring ---
    download.click(save_results, outputs=download)
    # Changing the dropdown reloads the model, then refreshes the status text.
    model_select.change(load_model, inputs=model_select).then(
        display_model, outputs=model_text
    )
    files.upload(upload_files, inputs=files)
    # Selecting a file loads it into the cropper, updates the caption, then
    # regenerates the preview.
    files.select(select_image, inputs=files, outputs=cropper).then(
        show_current_filename,
        outputs=mid_col_text,
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    # Toggling autocrop stores the flag, reloads the cropper, then regenerates
    # the preview.
    autocrop_toggle.change(toggle_autocrop, inputs=autocrop_toggle).then(
        show_crop_panel, outputs=cropper
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    # After a batch run, refresh the JSON view and the downloadable file.
    batch_predict.click(predict_all, outputs=prediction_progress).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    # Clearing the upload list resets state, then refreshes JSON and download.
    files.clear(clear).then(get_results_cache, outputs=json_results).then(
        save_results, outputs=download
    )
    crop.click(show_preview, inputs=cropper, outputs=preview)
    # After a single classification, refresh the JSON view and the download.
    classify.click(predict, inputs=preview, outputs=plot).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    # clear() is also registered on the unload event.
    demo.unload(clear)

if __name__ == "__main__":
    demo.launch()