# app.py — Hugging Face Spaces file (author: chenchangliu, commit f9ebe65, ~3.91 kB).
# The original upload carried HF web-UI residue here ("raw / history blame"); kept as provenance.
import json
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b0
from PIL import Image
import gradio as gr
# ---------- CONFIG ----------
CKPT_PATH = "best_effnet_twohead.pt"   # trained checkpoint (state_dict under "model")
LABELS_PATH = "labels.json"            # label vocabularies saved at training time
IMG_SIZE = 224                         # EfficientNet-B0 input resolution
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------

# Load the label vocabularies. labels.json is expected to hold
# {"species": [...], "state": [...]} — TODO confirm against the training export.
with open(LABELS_PATH, "r") as fp:
    labels = json.load(fp)
SPECIES = labels["species"]
STATE = labels["state"]
# model (must match training)
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two independent linear classification heads.

    One head predicts the species, the other the nest state. The layer layout
    must mirror the training-time architecture so the saved state_dict loads.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        backbone = efficientnet_b0(weights=None)  # random init; real weights come from the checkpoint
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        feat_dim = backbone.classifier[1].in_features  # width of the pooled feature vector
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x):
        """Return a (species_logits, state_logits) pair for an image batch x."""
        feats = torch.flatten(self.avgpool(self.features(x)), 1)
        return self.head_species(feats), self.head_state(feats)
# Restore the trained weights. The checkpoint is a dict whose "model" key
# holds the state_dict saved at training time.
# NOTE(review): torch.load unpickles arbitrary objects — fine for this trusted
# local file, but never point CKPT_PATH at an untrusted download.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
model = EffNetTwoHead(len(SPECIES), len(STATE))
model.load_state_dict(ckpt["model"])
model = model.to(DEVICE)
model.eval()  # inference mode: freeze dropout / batch-norm statistics
# preprocessing (same as training)
# Inference preprocessing — must match the training pipeline exactly:
# fixed square resize, tensor conversion, ImageNet channel normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
@torch.no_grad()
def predict(image: Image.Image):
    """Classify a trap-nest photo.

    Returns a (species_text, state_text) pair of human-readable strings with
    the predicted class, its index, and the softmax confidence. Both entries
    are "No image" when the input is missing.
    """
    if image is None:
        return "No image", "No image"

    batch = tfm(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    species_logits, state_logits = model(batch)

    def _top1(logits, names):
        # Softmax over the single batch item, then argmax index + confidence.
        probs = torch.softmax(logits, dim=1)[0]
        idx = int(probs.argmax().item())
        conf = float(probs[idx].item())
        return f"{names[idx]} (id={idx}, conf={conf:.3f})"

    return _top1(species_logits, SPECIES), _top1(state_logits, STATE)
# Per-example class legend (species - state) keyed by example image index.
# BUG FIX: the original passed this list as `descriptions=`, which is not a
# gr.Interface parameter and raises TypeError on modern Gradio. The legend is
# now rendered through the supported `article=` argument instead.
CLASS_LEGEND = [
    "0: Cacoxnus indagator - Lv",
    "1: Chelostoma florisomne - DauLv",
    "2: Chelostoma florisomne - OldFood",
    "3: Coeliopencyrtus - DauLv",
    "4: Eumenidae - DauLv",
    "5: Eumenidae - OldFood",
    "6: Heriades - DeadLv",
    "7: Heriades - Lv",
    "8: Heriades - OldFood",
    "9: Hylaeus - Lv",
    "10: Hylaeus - OldFood",
    "11: Megachile - DauLv",
    "12: Osmia bicornis - DauLv",
    "13: Osmia bicornis - DeadLv",
    "14: Osmia bicornis - OldFood",
    "15: Osmia cornuta - DauLv",
    "16: Osmia cornuta - DeadLv",
    "17: Osmia cornuta - OldFood",
    "18: Passaloecus - Lv",
    "19: Passaloecus - OldFood",
    "20: Psenulus - DauLv",
    "21: Trypoxylon - DauLv",
    "22: Trypoxylon - OldFood",
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Predicted species"),
        gr.Textbox(label="Predicted state"),
    ],
    title="EfficientNet Two-Head Layer Trap Nest (LTN) Classifier",
    description="12 species, 4 states supported.",
    # One bundled sample image per class, named 0.jpg ... 22.jpg.
    examples=[[f"./{i}.jpg"] for i in range(len(CLASS_LEGEND))],
    # Markdown shown below the interface; doubles as the example legend.
    article="\n\n".join(CLASS_LEGEND),
)

if __name__ == "__main__":
    demo.launch()