# Source: Hugging Face Space app.py by chenchangliu (commit e6d2a2e, verified)
import json
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b0
from PIL import Image
import gradio as gr
# ---------- CONFIG ----------
# Paths are relative to the app's working directory (Space repo root).
CKPT_PATH = "best_effnet_twohead.pt" # training script saves best.pt / last.pt
LABELS_PATH = "labels.json" # optional; fallback to taxa.txt / states.txt
TAXA_PATH = "taxa.txt" # fallback
STATES_PATH = "states.txt" # fallback
IMG_SIZE = 224  # square input side length fed to the network
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
# ----------------------------
def load_lines(path: str) -> list[str]:
    """Read a UTF-8 text file and return its non-empty lines, stripped.

    Whitespace-only lines are dropped; surrounding whitespace is removed
    from each kept line.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    # Fix: dropped the redundant local alias (`p = path`) from the original.
    with open(path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]
# load labels (labels.json preferred; fallback to taxa.txt/states.txt)
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        labels = json.load(f)
    # labels.json is expected to hold {"species": [...], "state": [...]}
    SPECIES = labels["species"]
    STATE = labels["state"]
except FileNotFoundError:
    # fall back to plain one-label-per-line text files
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)
# model (must match training)
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two parallel classification heads.

    One linear head predicts the taxon (species), the other the
    developmental state. The layout (shared backbone, dropout 0.3,
    two heads) must match the training script so checkpoint keys line up.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # Random-init backbone; real weights arrive via load_state_dict.
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.pool = backbone.avgpool
        feat_dim = backbone.classifier[1].in_features
        # Dropout before both heads, mirroring training-time regularization.
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x):
        """Return (species_logits, state_logits) for a batch of images."""
        feats = self.pool(self.features(x))
        feats = self.drop(torch.flatten(feats, 1))
        return self.head_species(feats), self.head_state(feats)
# load model
# NOTE(review): torch.load deserializes via pickle — only load trusted
# checkpoints; an untrusted file could execute arbitrary code.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Prefer class counts stored in the checkpoint; fall back to label-list sizes.
num_species = int(ckpt.get("num_species", len(SPECIES)))
num_states = int(ckpt.get("num_states", len(STATE)))
# if checkpoint defines class counts, trust it; labels must match lengths
if len(SPECIES) != num_species:
    raise RuntimeError(
        f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
        f"Fix labels.json or taxa.txt."
    )
if len(STATE) != num_states:
    raise RuntimeError(
        f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
        f"Fix labels.json or states.txt."
    )
model = EffNetTwoHead(num_species, num_states)
model.load_state_dict(ckpt["model"], strict=True)  # strict: every key must match exactly
model.to(DEVICE).eval()  # eval() disables dropout for inference
# preprocessing (align with training: letterbox to 224x224 without cropping)
# This implementation NEVER crops the image: it resizes to fit within 224x224,
# then pads the remaining area (black) to reach exactly 224x224.
# Standard ImageNet channel statistics (EfficientNet pretraining convention).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Fit *im* inside a size x size square without cropping.

    The image is scaled to fit entirely within the square (aspect ratio
    preserved, bilinear resampling) and centered on a solid canvas whose
    gray level is *fill* (default black).
    """
    bg = (fill, fill, fill)
    w, h = im.size
    # Degenerate (zero-area) image: nothing to scale, return a blank canvas.
    if not (w and h):
        return Image.new("RGB", (size, size), bg)
    ratio = min(size / w, size / h)
    scaled_w = max(1, int(round(w * ratio)))
    scaled_h = max(1, int(round(h * ratio)))
    scaled = im.resize((scaled_w, scaled_h), resample=Image.BILINEAR)
    canvas = Image.new("RGB", (size, size), bg)
    # Center the scaled image; remaining border stays the fill color.
    canvas.paste(scaled, ((size - scaled_w) // 2, (size - scaled_h) // 2))
    return canvas
@torch.no_grad()
def predict(image: Image.Image):
    """Classify one brood-cell image.

    Returns a (taxon_text, state_text) pair of strings of the form
    "<label> (score=<prob>)", or ("No image", "No image") when no
    image was provided.
    """
    if image is None:
        return "No image", "No image"
    rgb = image.convert("RGB")
    # letterbox to IMG_SIZE x IMG_SIZE (no cropping), matching training
    boxed = letterbox_pil(rgb, size=IMG_SIZE, fill=0)
    batch = normalize(transforms.ToTensor()(boxed)).unsqueeze(0).to(DEVICE)
    species_logits, state_logits = model(batch)
    species_probs = torch.softmax(species_logits, dim=1)[0]
    state_probs = torch.softmax(state_logits, dim=1)[0]
    sp_idx = int(species_probs.argmax().item())
    st_idx = int(state_probs.argmax().item())
    sp_score = float(species_probs[sp_idx].item())
    st_score = float(state_probs[st_idx].item())
    return (
        f"{SPECIES[sp_idx]} (score={sp_score:.3f})",
        f"{STATE[st_idx]} (score={st_score:.3f})",
    )
# Markdown shown in the UI listing both label sets.
EXAMPLE_TEXT = """
**Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum, Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile, Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon, Psenulus, Trichodes, Trypoxylon
**States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
"""
# [state code, human-readable description] rows for the legend table in the UI.
STATUS_TABLE = [
    ["DauLv", "Visible alive prepupa that stopped feeding"],
    ["DeadLv", "Dead visible larva"],
    ["Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
    ["Lv", "Visible alive larva"],
    ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
]
theme = gr.themes.Soft()
# Build the Gradio UI: image input + label legend on the left,
# prediction text boxes and the Predict button on the right.
with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental state**.\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Input image", height=320)
            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)
            gr.Markdown("")
            # Read-only legend explaining each developmental-state code.
            legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["State", "Description"],
                datatype=["str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(2, "fixed"),
            )
        with gr.Column(scale=1):
            sp_out = gr.Textbox(label="Predicted taxon", lines=1)
            st_out = gr.Textbox(label="Predicted state", lines=1)
            btn = gr.Button("Predict", variant="primary")
    # Wire the button to the inference function defined above.
    btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])
    gr.Markdown("### Examples")
    # NOTE(review): assumes 0.jpg .. 22.jpg exist in the repo root — verify.
    gr.Examples(
        examples=[[f"./{i}.jpg"] for i in range(23)],
        inputs=img,
    )
# Launch the Gradio server when run directly (not when imported).
if __name__ == "__main__":
    demo.launch()