"""Gradio demo for a two-head EfficientNet-B0 classifier of Layer Trap Nest
(LTN) brood-cell image crops.

One head predicts the taxon, the other the developmental state. The model
architecture and preprocessing must match the training script exactly.
"""

import json

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b0
from PIL import Image
import gradio as gr

# ---------- CONFIG ----------
CKPT_PATH = "best_effnet_twohead.pt"  # training script saves best.pt / last.pt
LABELS_PATH = "labels.json"  # optional; fallback to taxa.txt / states.txt
TAXA_PATH = "taxa.txt"  # fallback
STATES_PATH = "states.txt"  # fallback
IMG_SIZE = 224
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------


def load_lines(path: str) -> list[str]:
    """Read a UTF-8 text file and return its non-empty, stripped lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]


# load labels (labels.json preferred; fallback to taxa.txt/states.txt)
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        labels = json.load(f)
    SPECIES = labels["species"]
    STATE = labels["state"]
except FileNotFoundError:
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)


class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two linear classification heads.

    The layer layout (including the dropout before the heads) must match the
    training script, or ``load_state_dict(strict=True)`` below will fail.
    """

    def __init__(self, num_species: int, num_states: int):
        super().__init__()
        base = efficientnet_b0(weights=None)
        self.features = base.features
        self.pool = base.avgpool
        # width of the backbone's final feature vector
        c = base.classifier[1].in_features
        # training script uses dropout before heads
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(c, num_species)
        self.head_state = nn.Linear(c, num_states)

    def forward(self, x):
        """Return (species_logits, state_logits) for a batch of images."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        return self.head_species(x), self.head_state(x)


# Load checkpoint. torch.load is pickle-based and can execute arbitrary code
# from an untrusted file, so prefer weights_only=True (torch >= 1.13); fall
# back for older torch versions that lack the parameter.
try:
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=True)
except TypeError:
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
num_species = int(ckpt.get("num_species", len(SPECIES)))
num_states = int(ckpt.get("num_states", len(STATE)))
# if checkpoint defines class counts, trust it; labels must match lengths
if len(SPECIES) != num_species:
    raise RuntimeError(
        f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
        f"Fix labels.json or taxa.txt."
    )
if len(STATE) != num_states:
    raise RuntimeError(
        f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
        f"Fix labels.json or states.txt."
    )

model = EffNetTwoHead(num_species, num_states)
model.load_state_dict(ckpt["model"], strict=True)
model.to(DEVICE).eval()

# preprocessing (align with training: letterbox to 224x224 without cropping)
# This implementation NEVER crops the image: it resizes to fit within 224x224,
# then pads the remaining area (black) to reach exactly 224x224.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Resize *im* to fit inside ``size``x``size`` and pad the rest.

    Aspect ratio is preserved; the resized image is centered on a square
    canvas filled with the grayscale value *fill* (black by default).
    A degenerate zero-area input yields a plain filled canvas.
    """
    w, h = im.size
    if w == 0 or h == 0:
        return Image.new("RGB", (size, size), (fill, fill, fill))
    scale = min(size / w, size / h)  # fit entirely inside size x size
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    im_resized = im.resize((new_w, new_h), resample=Image.BILINEAR)
    canvas = Image.new("RGB", (size, size), (fill, fill, fill))
    left = (size - new_w) // 2
    top = (size - new_h) // 2
    canvas.paste(im_resized, (left, top))
    return canvas


@torch.no_grad()
def predict(image: Image.Image) -> tuple[str, str]:
    """Classify one PIL image; return (taxon_text, state_text) strings.

    Each string is ``"<label> (score=<softmax prob>)"``; a ``None`` input
    yields the placeholder ``"No image"`` for both outputs.
    """
    if image is None:
        return "No image", "No image"
    image = image.convert("RGB")
    # letterbox to 224x224 (no cropping)
    image = letterbox_pil(image, size=IMG_SIZE, fill=0)
    x = transforms.ToTensor()(image)
    x = normalize(x).unsqueeze(0).to(DEVICE)
    log_sp, log_st = model(x)
    prob_sp = torch.softmax(log_sp, dim=1)[0]
    prob_st = torch.softmax(log_st, dim=1)[0]
    sp_id = int(prob_sp.argmax().item())
    st_id = int(prob_st.argmax().item())
    sp_conf = prob_sp[sp_id].item()  # .item() already yields a Python float
    st_conf = prob_st[st_id].item()
    sp_text = f"{SPECIES[sp_id]} (score={sp_conf:.3f})"
    st_text = f"{STATE[st_id]} (score={st_conf:.3f})"
    return sp_text, st_text


# NOTE(review): "Cacoxnus indagator" looks like a typo for "Cacoxenus
# indagator" — confirm against taxa.txt before changing, since this text must
# mirror the actual label set.
EXAMPLE_TEXT = """
**Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum,
Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae,
Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile,
Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon,
Psenulus, Trichodes, Trypoxylon

**States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
"""

STATUS_TABLE = [
    ["DauLv", "Visible alive prepupa that stopped feeding"],
    ["DeadLv", "Dead visible larva"],
    ["Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
    ["Lv", "Visible alive larva"],
    ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
]

theme = gr.themes.Soft()

with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental state**.\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Input image", height=320)
            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)
            gr.Markdown("")
            legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["State", "Description"],
                datatype=["str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(2, "fixed"),
            )
        with gr.Column(scale=1):
            sp_out = gr.Textbox(label="Predicted taxon", lines=1)
            st_out = gr.Textbox(label="Predicted state", lines=1)
            btn = gr.Button("Predict", variant="primary")
    btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])
    gr.Markdown("### Examples")
    # assumes ./0.jpg .. ./22.jpg exist next to this script — TODO confirm
    gr.Examples(
        examples=[[f"./{i}.jpg"] for i in range(23)],
        inputs=img,
    )

if __name__ == "__main__":
    demo.launch()