# Author: VikramR
# Commit 56eb931: Autocrop off by default, added progress bar for batch
# prediction, made colors nice, fixed download button.
import torch
from torchvision.transforms.functional import pil_to_tensor
import gradio as gr
from gradio.utils import get_upload_folder
from huggingface_hub import hf_hub_download
from external_models import EfficientNet, MobileNet, ResNet, Swin
from utils import get_preprocessing
from pathlib import Path
from PIL import Image
from tempfile import NamedTemporaryFile
import json
import os
import cv2
import pandas as pd
import numpy as np
# All inference runs on the CPU.
device = "cpu"

# Checkpoint config ``model_type`` -> architecture class to instantiate.
models = dict(
    mbnet=MobileNet,
    effnet=EfficientNet,
    resnet=ResNet,
    swin=Swin,
)

# Display name (as offered in the model dropdown) -> checkpoint filename in
# the ``VikramR/NematodeClassification`` Hugging Face repo.
model_filenames = {
    "EfficientNetV2-S": "efficientnetv2s.pth",
    "MobileNetV3-L": "mobilenetv3l.pth",
    "ResNet101": "resnet101.pth",
    "Swin V2-B": "swinv2b.pth",
}

# Reverse direction for display: checkpoint ``model_type`` -> display name.
model_names = dict(
    effnet="EfficientNetV2-S",
    mbnet="MobileNetV3-L",
    resnet="ResNet101",
    swin="Swin V2-B",
)
def _pad_span(lo: int, hi: int, target: int, limit: int) -> tuple[int, int]:
    """
    Grow the half-open span ``[lo, hi)`` to ``target`` pixels around its centre.

    The grown span is slid back inside ``[0, limit)`` rather than truncated, so
    it keeps its full length whenever ``target <= limit``; otherwise it is
    clamped to ``[0, limit)``.
    """
    extra = target - (hi - lo)
    lo -= extra // 2
    hi += extra - extra // 2
    # Shift (not cut) the window when it sticks out on either side.
    if lo < 0:
        hi -= lo
        lo = 0
    if hi > limit:
        lo -= hi - limit
        hi = limit
    return max(lo, 0), hi


def cropped_img(img: np.ndarray | Image.Image | str) -> tuple[int, int, int, int]:
    """
    Locate the nematode in an image and return a square crop box around it.

    Canny edges are closed (dilate, then erode) so the nematode outline merges
    into one solid blob; the largest external contour is taken to be the
    nematode, which ignores smaller debris blobs. Its bounding box is then
    padded along the shorter side to make it square, shifted as needed to stay
    inside the image (it is clamped only when the image itself is too small).

    Parameters
    ----------
    img : np.ndarray | PIL.Image.Image | str
        Image array (2-D grayscale, RGB or RGBA; presumably uint8 as required
        by ``cv2.Canny``), a PIL image, or a path to an image file.

    Returns
    -------
    tuple[int, int, int, int]
        ``(x1, y1, x2, y2)`` in pixel coordinates with exclusive right/bottom
        edges, suitable for ``PIL.Image.Image.crop``. If no edges are found,
        the box covers the whole image.
    """
    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    if isinstance(img, Image.Image):
        img = np.array(img)
    height, width = img.shape[:2]

    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 4:
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Edge detection.
    edges = cv2.Canny(gray, 25, 25, apertureSize=3, L2gradient=True)
    # Morphological closing "puffs up" the nematode so its edges join into
    # a single filled region.
    kernel = np.ones((11, 11), np.uint8)
    dilated = cv2.dilate(edges, kernel, iterations=3)
    closed = cv2.erode(dilated, kernel, iterations=3)
    contours, _ = cv2.findContours(
        closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if not contours:
        # Featureless image: nothing to crop to, so keep everything.
        return 0, 0, width, height

    largest = max(contours, key=cv2.contourArea)
    # Bounding box of the contour == extent of the filled contour.
    x, y, w, h = cv2.boundingRect(largest)

    side = max(w, h)
    x1, x2 = _pad_span(x, x + w, side, width)
    y1, y2 = _pad_span(y, y + h, side, height)
    return int(x1), int(y1), int(x2), int(y2)
model, preprocessing, class_to_idx, idx_to_class = None, None, None, None
current_model_type = None
results_cache: dict[str, str] = {}
current_image = None
autocrop = False
temp_files: list[str] = []
all_images: list[str] = []
def load_model(model_name: str | None = "EfficientNetV2-S"):
    """
    Download a checkpoint from the Hugging Face Hub and make it the active model.

    On success, replaces the module-level ``model``, ``preprocessing``,
    ``class_to_idx``, ``idx_to_class`` and ``current_model_type``. If any step
    fails, the previously loaded state is left untouched. Does nothing when
    ``model_name`` is ``None``.

    Parameters
    ----------
    model_name : str | None
        Display name of the architecture; must be a key of ``model_filenames``.

    Raises
    ------
    KeyError
        If ``model_name`` or the checkpoint's ``model_type`` is unknown.
    """
    global model, preprocessing, class_to_idx, idx_to_class, current_model_type
    if model_name is None:
        return
    filepath = hf_hub_download(
        repo_id="VikramR/NematodeClassification",
        filename=model_filenames[model_name],
    )
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because it comes from the project's own Hub repo; never point this at
    # user-supplied files.
    (state_dict, _, _, _, _, _, config, new_class_to_idx, _) = torch.load(
        filepath, map_location=device
    )
    new_model_type = config["model_type"]
    new_model = models[new_model_type](config).to(device)
    new_model.load_state_dict(state_dict)
    new_model.eval()
    new_preprocessing = get_preprocessing(new_model_type)

    # Publish everything together, only after all of the above succeeded, so
    # the labels/preprocessing always match the model actually in use.
    model = new_model
    class_to_idx = new_class_to_idx
    idx_to_class = {idx: img_cls for img_cls, idx in new_class_to_idx.items()}
    current_model_type = new_model_type
    preprocessing = new_preprocessing
def display_model():
    """Return the status-box text naming the architecture that is loaded now."""
    friendly_name = model_names[current_model_type]
    return (
        "Current Model Type: " + friendly_name + ". "
        "Use dropdown on the right to change it."
    )
def clear():
    """
    Reset the app state and delete every file the app is tracking.

    Deletes the uploaded images (``all_images``) and app-created temporary
    files (``temp_files``) from disk, then empties both lists so that a later
    call (e.g. ``files.clear`` followed by ``demo.unload``) does not try to
    delete them again. Files that are already gone are skipped.
    """
    global results_cache, current_image, all_images, temp_files
    results_cache = {}
    current_image = None
    for file in [*all_images, *temp_files]:
        Path(file).unlink(missing_ok=True)
    all_images = []
    temp_files = []
@torch.no_grad()
def run_image(img: Image.Image):
    """
    Classify a single PIL image with the currently loaded model.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to classify. It is converted to a tensor, given a leading batch
        dimension, and passed through ``preprocessing`` before inference.

    Returns
    -------
    tuple[dict, tuple[float, str]]
        ``({"Probability": [...], "Class": [...]}, (top_prob, top_label))``,
        where the two lists are aligned by class index.
    """
    batch = preprocessing(pil_to_tensor(img)[None].to(device))
    probs = torch.nn.functional.softmax(model(batch), dim=1)[0]
    top_prob, top_idx = torch.max(probs, dim=0)
    n_classes = len(class_to_idx)
    results = {
        # tolist() yields plain Python floats, which json.dump can serialise.
        "Probability": probs[:n_classes].tolist(),
        "Class": [idx_to_class[i] for i in range(n_classes)],
    }
    return results, (top_prob.item(), idx_to_class[top_idx.item()])
def prev_crop_preview() -> str | None:
    """
    Build the preview image for the currently selected file.

    Re-encodes ``current_image`` as an RGB PNG in Gradio's upload folder,
    cropping it first when automatic cropping is enabled. The new file is
    recorded in ``temp_files`` so ``clear`` can delete it.

    Returns
    -------
    str | None
        Path of the preview PNG, or ``None`` when no image is selected.
    """
    if current_image is None:
        return None
    # Close the source file promptly; convert() returns an independent copy.
    with Image.open(current_image) as src:
        img = src.convert("RGB")
    if autocrop:
        img = img.crop(cropped_img(img))
    with NamedTemporaryFile(
        mode="wb", dir=get_upload_folder(), suffix=".png", delete=False
    ) as f:
        # Explicit format rather than relying on inference from f.name.
        img.save(f, format="PNG")
    temp_files.append(f.name)
    return f.name
def predict(img: str) -> gr.BarPlot:
    """
    Classify the preview image, cache the result, and plot the distribution.

    Parameters
    ----------
    img : str
        Path of the preview image (the manual crop or the autocrop preview).

    Returns
    -------
    gr.BarPlot
        Bar chart of per-class probabilities on a 0-1 scale.

    Raises
    ------
    gr.Error
        If no image has been selected/previewed yet.
    """
    if img is None or current_image is None:
        raise gr.Error("Select an image before clicking Classify.")
    with Image.open(img) as src:
        rgb = src.convert("RGB")
    result, (prob, label) = run_image(rgb)
    df = pd.DataFrame(result)
    # Key by the selected upload's file name: the preview path is a generated
    # file whose name would be meaningless in the downloaded JSON.
    results_cache[Path(current_image).name] = {
        "Distribution": dict(zip(result["Class"], result["Probability"])),
        "Classification": {"Probability": prob, "Label": label},
    }
    return gr.BarPlot(
        df,
        x="Class",
        y="Probability",
        # tooltip expects DataFrame column names, not class labels.
        tooltip=["Class", "Probability"],
        y_lim=(0, 1),
    )
def predict_all(progress_bar=gr.Progress()):
    """
    Classify every uploaded image and store the results in ``results_cache``.

    Each image is cropped first when automatic cropping is enabled. The
    ``gr.Progress()`` default follows Gradio's convention for receiving a
    progress tracker; it reports per-image progress.

    Returns
    -------
    str
        Status message shown in the progress textbox.
    """
    for path in progress_bar.tqdm(all_images, desc="Running images"):
        # Close each upload promptly; clear() deletes these files later.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if autocrop:
            img = img.crop(cropped_img(img))
        result, (prob, label) = run_image(img)
        results_cache[Path(path).name] = {
            "Distribution": dict(zip(result["Class"], result["Probability"])),
            "Classification": {"Probability": prob, "Label": label},
        }
    return "All images predicted successfully."
def get_results_cache():
    """Hand the prediction cache itself (not a copy) to the JSON viewer."""
    return results_cache
def save_results():
    """
    Write ``results_cache`` to a new ``model_predictions_*.json`` temp file.

    The file survives closing (``delete=False``) so its path can be handed to
    the download button, and it is recorded in ``temp_files`` so ``clear``
    removes it.

    Returns
    -------
    str
        Path of the written JSON file.
    """
    out = NamedTemporaryFile(
        "w", delete=False, prefix="model_predictions_", suffix=".json"
    )
    with out:
        json.dump(results_cache, out, indent=4)
        temp_files.append(out.name)
    return out.name
def select_image(files, sd: gr.SelectData):
    """
    Handle a click on an uploaded file by making it the current image.

    ``sd.index`` is the position of the clicked entry within ``files``; the
    returned path is loaded into the crop editor.
    """
    global current_image
    picked = files[sd.index]
    current_image = picked.name
    return current_image
def show_crop_panel():
    """Return the selected image path so the crop editor (re)loads it."""
    return current_image
def upload_files(files):
    """
    Remember the uploaded image paths for batch prediction and cleanup.

    Stores a copy, so later mutation of the caller's list does not leak in,
    and treats ``None`` as "no files" so the stored value is always a list
    that ``predict_all`` and ``clear`` can iterate.
    """
    global all_images
    all_images = list(files) if files else []
def toggle_autocrop(res):
    """Mirror the "Automatic Cropping" checkbox into the global ``autocrop`` flag."""
    global autocrop
    autocrop = res
def show_preview(x):
    """
    Return the edited (cropped) image from the image editor for the preview.

    Parameters
    ----------
    x : dict | None
        Image-editor value whose ``"composite"`` entry is the flattened,
        edited image. Presumably ``None`` when the editor holds no image.

    Returns
    -------
    The composite image, or ``None`` when ``x`` is ``None`` (which clears
    the preview instead of raising TypeError).
    """
    if x is None:
        return None
    return x["composite"]
def show_current_filename():
    """Return the crop-panel instructions followed by the selected file's name."""
    name = Path(current_image).name
    return (
        "Crop Image Here (Optional), then click Run to Predict"
        f"\n\nCurrent File: {name}"
    )
# NOTE(review): every callback wired below reads/writes module-level globals
# (current_image, results_cache, all_images, ...), so users of one running
# app appear to share that state; ``gr.State`` would isolate sessions.
with gr.Blocks() as demo:
    # Load the default checkpoint when the page is opened.
    demo.load(load_model)
    with gr.Row():
        gr.Textbox(
            "Only use this application on the following classes of nematodes: "
            + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
            "Pratylenchus, Trichodorus, and Tylenchorhynchus.\n\n"
            + "Only use images containing a single nematode.\n\n"
            + "SCROLL DOWN TO DOWNLOAD THE PREDICTIONS FOR YOUR IMAGES!",
            text_align="center",
            label="DISCLAIMER",
        )
    with gr.Row():
        model_text = gr.Textbox(
            "Default model: EfficientNetV2-S. To choose a different model, choose one from the dropdown on the right",
            label="Current Model",
        )
        model_select = gr.Dropdown(
            # Derived from model_filenames so the offered names can never
            # drift from the checkpoints load_model knows how to fetch.
            choices=list(model_filenames),
            value="EfficientNetV2-S",
            label="Select Model Architecture (May take a few moments, check text on the left to confirm your model has loaded)",
        )
    with gr.Row():
        # Left column: upload + batch prediction.
        with gr.Column():
            gr.Textbox(
                "Upload Images, then Select Each One to Crop & Run",
                show_label=False,
            )
            files = gr.File(file_types=["image"], file_count="multiple")
            batch_predict = gr.Button("Predict All", variant="stop")
            prediction_progress = gr.Textbox(
                "Prediction Progress Bar", show_label=False
            )
        # Middle column: manual / automatic cropping.
        with gr.Column():
            mid_col_text = gr.Textbox(
                "Crop Image Here (Optional), then Click Run to Predict",
                show_label=False,
            )
            autocrop_toggle = gr.Checkbox(value=False, label="Automatic Cropping")
            cropper = gr.ImageEditor(
                type="filepath",
                sources=None,
                layers=False,
                brush=False,
                mirror_webcam=False,
            )
            crop = gr.Button("Crop")
        # Right column: what the network will see, plus single-image results.
        with gr.Column():
            gr.Textbox(
                "Image Preview (What will be run through network)",
                show_label=False,
            )
            preview = gr.Image(
                sources=None,
                type="filepath",
                height=250,
                interactive=False,
                mirror_webcam=False,
            )
            classify = gr.Button("Classify", variant="stop")
            plot = gr.BarPlot()
    with gr.Row():
        gr.Textbox(
            "Here are the predicted labels for your images in JSON format",
            label="Predictions",
        )
    with gr.Row():
        with gr.Column():
            json_results = gr.JSON()
        with gr.Column():
            download = gr.DownloadButton("Download Predictions", variant="primary")
            # Regenerate the JSON export at click time.
            download.click(save_results, outputs=download)

    # Switching architecture: load it, then report it in the status box.
    model_select.change(load_model, inputs=model_select).then(
        display_model, outputs=model_text
    )
    files.upload(upload_files, inputs=files)
    # Selecting a file: show it in the editor, name it, refresh the preview.
    files.select(select_image, inputs=files, outputs=cropper).then(
        show_current_filename,
        outputs=mid_col_text,
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    # Toggling autocrop: reload the editor and rebuild the preview.
    autocrop_toggle.change(toggle_autocrop, inputs=autocrop_toggle).then(
        show_crop_panel, outputs=cropper
    ).then(
        prev_crop_preview,
        outputs=preview,
    )
    # Batch run, then refresh the JSON view and the download file.
    batch_predict.click(predict_all, outputs=prediction_progress).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    # Removing the uploads wipes state and files, then refreshes the outputs.
    files.clear(clear).then(get_results_cache, outputs=json_results).then(
        save_results, outputs=download
    )
    crop.click(show_preview, inputs=cropper, outputs=preview)
    # Single-image run, then refresh the JSON view and the download file.
    classify.click(predict, inputs=preview, outputs=plot).then(
        get_results_cache, outputs=json_results
    ).then(save_results, outputs=download)
    demo.unload(clear)

if __name__ == "__main__":
    demo.launch()