File size: 6,259 Bytes
cc470a4
 
 
 
 
 
 
 
 
 
e6d2a2e
 
 
 
cc470a4
 
 
 
 
e6d2a2e
 
 
 
cc470a4
e6d2a2e
 
 
 
 
 
 
 
 
 
cc470a4
 
 
 
 
 
 
 
e6d2a2e
cc470a4
e6d2a2e
 
 
cc470a4
 
 
 
 
e6d2a2e
cc470a4
e6d2a2e
cc470a4
 
 
 
 
e6d2a2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc470a4
 
 
e6d2a2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc470a4
8d30696
cc470a4
 
 
 
 
 
e6d2a2e
 
 
 
 
 
cc470a4
 
 
8d30696
 
0fc44d7
 
 
 
 
 
 
3ece5b1
 
cc470a4
0fc44d7
cc470a4
8d30696
d85ecf2
e6d2a2e
3ece5b1
e6d2a2e
d85ecf2
 
e6d2a2e
8d30696
df03416
0791474
e6d2a2e
0791474
 
8d30696
 
 
 
 
2ac1183
8d30696
 
1abff61
8d30696
 
 
 
 
 
 
 
 
ded23b6
8d30696
 
2b8d7aa
 
8d30696
 
 
2ac1183
8d30696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc470a4
e6d2a2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import json
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b0
from PIL import Image
import gradio as gr


# ---------- CONFIG ----------
# Paths are relative to the working directory of the app.
CKPT_PATH = "best_effnet_twohead.pt"          # training script saves best.pt / last.pt
LABELS_PATH = "labels.json"    # optional; fallback to taxa.txt / states.txt
TAXA_PATH = "taxa.txt"         # fallback: one species/taxon label per line
STATES_PATH = "states.txt"     # fallback: one developmental-state label per line
IMG_SIZE = 224                 # EfficientNet-B0 input resolution; must match training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------


def load_lines(path: str) -> list[str]:
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Used to load the fallback label files (one label per line).

    Args:
        path: Path to a UTF-8 text file.

    Returns:
        Stripped, non-empty lines in file order.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    # The redundant local alias in the original (`p = path`) added nothing.
    with open(path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]


# load labels (labels.json preferred; fallback to taxa.txt/states.txt)
# labels.json is expected to be a dict with "species" and "state" lists;
# a missing key would raise KeyError here (intentionally loud).
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        labels = json.load(f)
    SPECIES = labels["species"]
    STATE = labels["state"]
except FileNotFoundError:
    # labels.json absent: fall back to the plain-text label files,
    # one label per line, in training order.
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)


# model (must match training)
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two independent classification heads.

    One linear head predicts the taxon (species), the other the
    developmental state. Attribute names (``features``, ``pool``, ``drop``,
    ``head_species``, ``head_state``) are load-bearing: they are the keys
    the training checkpoint's state dict expects — do not rename them.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # No pretrained weights: the checkpoint supplies all parameters.
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.pool = backbone.avgpool
        in_feats = backbone.classifier[1].in_features

        # Dropout before both heads, mirroring the training script.
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(in_feats, num_species)
        self.head_state = nn.Linear(in_feats, num_states)

    def forward(self, x):
        """Return ``(species_logits, state_logits)`` for an image batch."""
        pooled = self.pool(self.features(x))
        shared = self.drop(torch.flatten(pooled, 1))
        return self.head_species(shared), self.head_state(shared)


# load model
# NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True); a plain
# dict-of-tensors checkpoint loads fine either way, but confirm against the
# torch version pinned for this app.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Class counts: prefer the values stored in the checkpoint, falling back
# to the lengths of the loaded label lists.
num_species = int(ckpt.get("num_species", len(SPECIES)))
num_states = int(ckpt.get("num_states", len(STATE)))

# if checkpoint defines class counts, trust it; labels must match lengths
if len(SPECIES) != num_species:
    raise RuntimeError(
        f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
        f"Fix labels.json or taxa.txt."
    )
if len(STATE) != num_states:
    raise RuntimeError(
        f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
        f"Fix labels.json or states.txt."
    )

# strict=True: every checkpoint key must match the model exactly.
model = EffNetTwoHead(num_species, num_states)
model.load_state_dict(ckpt["model"], strict=True)
model.to(DEVICE).eval()


# preprocessing (align with training: letterbox to 224x224 without cropping)
# This implementation NEVER crops the image: it resizes to fit within 224x224,
# then pads the remaining area (black) to reach exactly 224x224.
# Standard ImageNet channel statistics, matching torchvision pretraining.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Fit *im* inside a ``size`` x ``size`` square without cropping.

    The image is scaled (aspect ratio preserved) to the largest size that
    fits, then centered on a solid canvas of grey level *fill* (black by
    default). A degenerate zero-area input yields a blank canvas.
    """
    width, height = im.size
    bg = (fill, fill, fill)
    if not (width and height):
        # Avoid division by zero on an empty image.
        return Image.new("RGB", (size, size), bg)

    # Largest scale at which both dimensions still fit inside the square.
    ratio = min(size / width, size / height)
    fit_w = max(1, int(round(width * ratio)))
    fit_h = max(1, int(round(height * ratio)))

    scaled = im.resize((fit_w, fit_h), resample=Image.BILINEAR)

    canvas = Image.new("RGB", (size, size), bg)
    # Center the scaled image; integer division biases padding to the
    # right/bottom by at most one pixel.
    canvas.paste(scaled, ((size - fit_w) // 2, (size - fit_h) // 2))
    return canvas


@torch.no_grad()
def predict(image: Image.Image):
    """Classify one brood-cell crop.

    Returns a ``(taxon_text, state_text)`` pair of human-readable strings,
    each containing the top-1 label and its softmax score.
    """
    if image is None:
        return "No image", "No image"

    # Match the training pipeline: RGB, letterboxed to IMG_SIZE (no crop),
    # then ImageNet normalization.
    rgb = image.convert("RGB")
    boxed = letterbox_pil(rgb, size=IMG_SIZE, fill=0)
    x = normalize(transforms.ToTensor()(boxed)).unsqueeze(0).to(DEVICE)

    logits_sp, logits_st = model(x)

    # Per-class probabilities for the single image in the batch.
    probs_sp = torch.softmax(logits_sp, dim=1)[0]
    probs_st = torch.softmax(logits_st, dim=1)[0]

    sp_id = int(torch.argmax(probs_sp).item())
    st_id = int(torch.argmax(probs_st).item())
    sp_conf = float(probs_sp[sp_id].item())
    st_conf = float(probs_st[st_id].item())

    return (
        f"{SPECIES[sp_id]}  (score={sp_conf:.3f})",
        f"{STATE[st_id]}  (score={st_conf:.3f})",
    )


# Static markdown shown in the UI, listing both label sets.
# NOTE(review): "Cacoxnus indagator" looks like a typo for "Cacoxenus
# indagator" — but this text must mirror the spelling used in the actual
# taxa labels, so verify against labels.json/taxa.txt before changing.
EXAMPLE_TEXT = """
**Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum, Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile, Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon, Psenulus, Trichodes, Trypoxylon

**States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
"""


# [state label, human-readable description] rows for the UI legend table.
# The state labels must match the entries of STATE exactly.
STATUS_TABLE = [
    ["DauLv",  "Visible alive prepupa that stopped feeding"],
    ["DeadLv",  "Dead visible larva"],
    ["Hatched",  "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
    ["Lv",      "Visible alive larva"],
    ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
]


theme = gr.themes.Soft()

# Two-column layout: input image + label legend on the left,
# prediction outputs + examples on the right.
with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental state**.\n"
    )

    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Input image", height=320)

            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)

            gr.Markdown("")
            # Read-only legend explaining each developmental-state label.
            legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["State", "Description"],
                datatype=["str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(2, "fixed"),
            )

        with gr.Column(scale=1):
            sp_out = gr.Textbox(label="Predicted taxon", lines=1)
            st_out = gr.Textbox(label="Predicted state", lines=1)

            btn = gr.Button("Predict", variant="primary")
            btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])

            gr.Markdown("### Examples")
            # Expects example images 0.jpg .. 22.jpg in the working directory;
            # missing files would surface as broken examples in the UI.
            gr.Examples(
                examples=[[f"./{i}.jpg"] for i in range(23)],
                inputs=img,
            )

if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()