import json
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b0
from PIL import Image
import gradio as gr
# ---------- CONFIG ----------
CKPT_PATH = "best_effnet_twohead.pt" # training script saves best.pt / last.pt
LABELS_PATH = "labels.json" # optional; fallback to taxa.txt / states.txt
TAXA_PATH = "taxa.txt" # fallback list of taxon labels, one per line
STATES_PATH = "states.txt" # fallback list of state labels, one per line
IMG_SIZE = 224 # square network input side in pixels (images are letterboxed to this)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # prefer GPU when present
# ----------------------------
def load_lines(path: str) -> list[str]:
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Fallback label loader for taxa.txt / states.txt when labels.json
    is absent.

    Args:
        path: path to the text file, one label per line.

    Returns:
        List of stripped, non-empty lines in file order.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    # (removed a pointless `p = path` alias — open the argument directly)
    with open(path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]
# load labels (labels.json preferred; fallback to taxa.txt/states.txt)
# SPECIES / STATE index positions must match the class indices the model
# was trained with — the checks after checkpoint loading enforce the lengths.
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        labels = json.load(f)
    # Only FileNotFoundError is caught: a malformed labels.json (bad JSON or
    # missing "species"/"state" keys) raises at import time — fail fast.
    SPECIES = labels["species"]
    STATE = labels["state"]
except FileNotFoundError:
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)
# model (must match training)
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two independent classification heads.

    One linear head scores species/taxon classes, the other scores
    developmental states; both share the pooled backbone features.
    Attribute names (`features`, `pool`, `drop`, `head_species`,
    `head_state`) must stay exactly as in the training script so the
    checkpoint's state_dict keys line up.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # weights=None: architecture only — real weights come from the checkpoint
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.pool = backbone.avgpool
        feat_dim = backbone.classifier[1].in_features
        # training script uses dropout before heads
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x):
        """Return (species_logits, state_logits) for a batch of images."""
        pooled = self.pool(self.features(x))
        shared = self.drop(torch.flatten(pooled, 1))
        return self.head_species(shared), self.head_state(shared)
# load model
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints; consider weights_only=True if the torch version supports it.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# class counts: prefer the values stored in the checkpoint, fall back to label lengths
num_species = int(ckpt.get("num_species", len(SPECIES)))
num_states = int(ckpt.get("num_states", len(STATE)))
# if checkpoint defines class counts, trust it; labels must match lengths
if len(SPECIES) != num_species:
    raise RuntimeError(
        f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
        f"Fix labels.json or taxa.txt."
    )
if len(STATE) != num_states:
    raise RuntimeError(
        f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
        f"Fix labels.json or states.txt."
    )
model = EffNetTwoHead(num_species, num_states)
# strict=True: any missing/unexpected state_dict key is a hard error
model.load_state_dict(ckpt["model"], strict=True)
model.to(DEVICE).eval()
# preprocessing (align with training: letterbox to 224x224 without cropping)
# This implementation NEVER crops the image: it resizes to fit within 224x224,
# then pads the remaining area (black) to reach exactly 224x224.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet channel means
    std=[0.229, 0.224, 0.225],   # ImageNet channel stds
)
def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Fit *im* inside a size x size square without cropping.

    The image is scaled down (or up) to fit entirely within the square,
    then centred on a solid RGB canvas of grey level *fill*.

    Args:
        im: source PIL image.
        size: side length of the square output, in pixels.
        fill: grey level (0-255) used for the padding background.

    Returns:
        A new size x size RGB image; *im* is not modified.
    """
    src_w, src_h = im.size
    # degenerate (zero-area) input: just return a blank canvas
    if src_w == 0 or src_h == 0:
        return Image.new("RGB", (size, size), (fill, fill, fill))
    ratio = min(size / src_w, size / src_h)  # fit entirely inside size x size
    fit_w = max(1, int(round(src_w * ratio)))
    fit_h = max(1, int(round(src_h * ratio)))
    scaled = im.resize((fit_w, fit_h), resample=Image.BILINEAR)
    board = Image.new("RGB", (size, size), (fill, fill, fill))
    offset = ((size - fit_w) // 2, (size - fit_h) // 2)
    board.paste(scaled, offset)
    return board
@torch.no_grad()
def predict(image: Image.Image):
    """Classify one brood-cell crop.

    Returns a ``(taxon_text, state_text)`` pair of display strings, each
    carrying the top-1 label and its softmax score, or placeholder text
    when no image was supplied.
    """
    if image is None:
        return "No image", "No image"
    rgb = image.convert("RGB")
    # letterbox to 224x224 (no cropping)
    rgb = letterbox_pil(rgb, size=IMG_SIZE, fill=0)
    tensor = transforms.ToTensor()(rgb)
    batch = normalize(tensor).unsqueeze(0).to(DEVICE)
    species_logits, state_logits = model(batch)
    species_probs = torch.softmax(species_logits, dim=1)[0]
    state_probs = torch.softmax(state_logits, dim=1)[0]
    # top-1 confidence and index for each head in one pass
    sp_vals = species_probs.max(dim=0)
    st_vals = state_probs.max(dim=0)
    sp_id, sp_conf = int(sp_vals.indices.item()), float(sp_vals.values.item())
    st_id, st_conf = int(st_vals.indices.item()), float(st_vals.values.item())
    sp_text = f"{SPECIES[sp_id]} (score={sp_conf:.3f})"
    st_text = f"{STATE[st_id]} (score={st_conf:.3f})"
    return sp_text, st_text
# Markdown shown in the UI describing the full label vocabulary.
# NOTE(review): "Cacoxnus indagator" looks like a typo for "Cacoxenus
# indagator" — confirm against taxa.txt before changing this display text,
# since it must mirror the actual label file.
EXAMPLE_TEXT = """
**Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum, Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile, Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon, Psenulus, Trichodes, Trypoxylon
**States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
"""
# (state code, human-readable description) rows for the legend dataframe
STATUS_TABLE = [
    ["DauLv", "Visible alive prepupa that stopped feeding"],
    ["DeadLv", "Dead visible larva"],
    ["Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
    ["Lv", "Visible alive larva"],
    ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
]
theme = gr.themes.Soft()
# Gradio UI: image input + label legend on the left, predictions on the right.
with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental state**.\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Input image", height=320)
            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)
            gr.Markdown("")
            # read-only legend explaining each state code
            legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["State", "Description"],
                datatype=["str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(2, "fixed"),
            )
        with gr.Column(scale=1):
            sp_out = gr.Textbox(label="Predicted taxon", lines=1)
            st_out = gr.Textbox(label="Predicted state", lines=1)
            btn = gr.Button("Predict", variant="primary")
            btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])
    gr.Markdown("### Examples")
    # NOTE(review): assumes example files ./0.jpg .. ./22.jpg exist next to
    # the app — missing files will show as broken examples; verify on deploy.
    gr.Examples(
        examples=[[f"./{i}.jpg"] for i in range(23)],
        inputs=img,
    )
if __name__ == "__main__":
    demo.launch()
|