| import json |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from torchvision.models import efficientnet_b0 |
| from PIL import Image |
| import gradio as gr |
|
|
|
|
| |
# --- Configuration -------------------------------------------------------
CKPT_PATH = "best_effnet_twohead.pt"  # trained two-head EfficientNet checkpoint
LABELS_PATH = "labels.json"           # preferred label source (JSON bundle)
TAXA_PATH = "taxa.txt"                # fallback taxa labels, one per line
STATES_PATH = "states.txt"            # fallback state labels, one per line
IMG_SIZE = 224                        # square model input resolution
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
|
|
|
def load_lines(path: str) -> list[str]:
    """Read *path* (UTF-8) and return its non-empty lines, whitespace-stripped.

    Blank lines and whitespace-only lines are dropped; used as the fallback
    label loader for the taxa/states text files.
    """
    # The original bound `p = path` for no reason; open the path directly.
    with open(path, "r", encoding="utf-8") as f:
        return [stripped for ln in f if (stripped := ln.strip())]
|
|
|
|
| |
# Load label vocabularies: prefer the JSON bundle; if it is missing, fall
# back to the plain-text files (one label per line).
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        labels = json.load(f)
except FileNotFoundError:
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)
else:
    SPECIES = labels["species"]
    STATE = labels["state"]
|
|
|
|
| |
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two independent classification heads.

    The heads share the convolutional features and a single dropout layer;
    one predicts the taxon, the other the developmental state.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # Backbone is created without pretrained weights; the real weights
        # are restored from the checkpoint by the caller.
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.pool = backbone.avgpool
        feat_dim = backbone.classifier[1].in_features
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x):
        """Return ``(species_logits, state_logits)`` for an image batch."""
        pooled = torch.flatten(self.pool(self.features(x)), 1)
        shared = self.drop(pooled)
        return self.head_species(shared), self.head_state(shared)
|
|
|
|
| |
# --- Restore checkpoint and validate label vocabularies ------------------
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Head sizes come from the checkpoint when present; otherwise we trust the
# loaded label lists.
num_species = int(ckpt.get("num_species", len(SPECIES)))
num_states = int(ckpt.get("num_states", len(STATE)))

# Refuse to run with mismatched label files: predictions would be indexed
# against the wrong names.
if len(SPECIES) != num_species:
    raise RuntimeError(
        f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
        "Fix labels.json or taxa.txt."
    )
if len(STATE) != num_states:
    raise RuntimeError(
        f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
        "Fix labels.json or states.txt."
    )

# Build the network, load the trained weights, and switch to inference mode.
model = EffNetTwoHead(num_species, num_states)
model.load_state_dict(ckpt["model"], strict=True)
model.to(DEVICE).eval()
|
|
|
|
| |
| |
| |
# Standard ImageNet channel statistics, matching torchvision's
# EfficientNet preprocessing.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
|
|
|
|
def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Fit *im* inside a ``size``x``size`` square without distortion.

    The image is scaled to the largest size that fits while preserving its
    aspect ratio, then centered on a solid canvas of grey level *fill*.
    """
    background = (fill, fill, fill)
    w, h = im.size
    # Degenerate input: return a blank canvas rather than divide by zero.
    if 0 in (w, h):
        return Image.new("RGB", (size, size), background)

    ratio = min(size / w, size / h)
    scaled_w = max(1, int(round(w * ratio)))
    scaled_h = max(1, int(round(h * ratio)))
    resized = im.resize((scaled_w, scaled_h), resample=Image.BILINEAR)

    canvas = Image.new("RGB", (size, size), background)
    offset = ((size - scaled_w) // 2, (size - scaled_h) // 2)
    canvas.paste(resized, offset)
    return canvas
|
|
|
|
@torch.no_grad()
def predict(image: Image.Image):
    """Classify one brood-cell crop; return ``(taxon_text, state_text)``."""
    if image is None:
        return "No image", "No image"

    # Preprocess: force RGB, aspect-preserving letterbox to the model's
    # input size, tensorize, and normalize with ImageNet statistics.
    rgb = image.convert("RGB")
    boxed = letterbox_pil(rgb, size=IMG_SIZE, fill=0)
    batch = normalize(transforms.ToTensor()(boxed)).unsqueeze(0).to(DEVICE)

    species_logits, state_logits = model(batch)

    # Per-head softmax over the single batch item.
    species_probs = torch.softmax(species_logits, dim=1)[0]
    state_probs = torch.softmax(state_logits, dim=1)[0]

    sp_idx = int(species_probs.argmax().item())
    st_idx = int(state_probs.argmax().item())

    sp_conf = float(species_probs[sp_idx].item())
    st_conf = float(state_probs[st_idx].item())

    return (
        f"{SPECIES[sp_idx]} (score={sp_conf:.3f})",
        f"{STATE[st_idx]} (score={st_conf:.3f})",
    )
|
|
|
|
# Markdown shown next to the input describing the two label vocabularies.
# NOTE(review): "Cacoxnus indagator" may be a typo for "Cacoxenus indagator",
# but it must match the on-disk labels — verify before changing.
EXAMPLE_TEXT = """
**Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum, Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile, Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon, Psenulus, Trichodes, Trypoxylon

**States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
"""
|
|
|
|
# Legend rows: (state code, human-readable description) for the UI table.
_STATE_DESCRIPTIONS = (
    ("DauLv", "Visible alive prepupa that stopped feeding"),
    ("DeadLv", "Dead visible larva"),
    ("Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"),
    ("Lv", "Visible alive larva"),
    ("OldFood", "Brood cell with only unconsumed food, no larva will develop"),
)
STATUS_TABLE = [[code, description] for code, description in _STATE_DESCRIPTIONS]
|
|
|
|
theme = gr.themes.Soft()

# --- Gradio UI -----------------------------------------------------------
with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental state**.\n"
    )

    with gr.Row():
        # Left column: image input plus the label-set reference material.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input image", height=320)

            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)

            gr.Markdown("")
            state_legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["State", "Description"],
                datatype=["str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(2, "fixed"),
            )

        # Right column: prediction outputs and the trigger button.
        with gr.Column(scale=1):
            taxon_box = gr.Textbox(label="Predicted taxon", lines=1)
            state_box = gr.Textbox(label="Predicted state", lines=1)

            predict_btn = gr.Button("Predict", variant="primary")
            predict_btn.click(fn=predict, inputs=image_input, outputs=[taxon_box, state_box])

    gr.Markdown("### Examples")
    # Example images are expected as ./0.jpg … ./22.jpg next to this script.
    gr.Examples(
        examples=[[f"./{idx}.jpg"] for idx in range(23)],
        inputs=image_input,
    )


if __name__ == "__main__":
    demo.launch()
|
|