Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -58,6 +58,7 @@ tfm = transforms.Compose([
|
|
| 58 |
),
|
| 59 |
])
|
| 60 |
|
|
|
|
| 61 |
@torch.no_grad()
|
| 62 |
def predict(image: Image.Image):
|
| 63 |
if image is None:
|
|
@@ -68,8 +69,8 @@ def predict(image: Image.Image):
|
|
| 68 |
|
| 69 |
log_sp, log_st = model(x)
|
| 70 |
|
| 71 |
-
prob_sp = torch.softmax(log_sp, dim=1)[0]
|
| 72 |
-
prob_st = torch.softmax(log_st, dim=1)[0]
|
| 73 |
|
| 74 |
sp_id = int(prob_sp.argmax().item())
|
| 75 |
st_id = int(prob_st.argmax().item())
|
|
@@ -82,48 +83,65 @@ def predict(image: Image.Image):
|
|
| 82 |
|
| 83 |
return sp_text, st_text
|
| 84 |
|
|
|
|
| 85 |
EXAMPLE_TEXT = """
|
| 86 |
-
|
| 87 |
-
Heriades, Hylaeus, Megachile, Osmia bicornis, Osmia cornuta, Passaloecus, Psenulus, Trypoxylon
|
| 88 |
|
| 89 |
-
|
| 90 |
"""
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
if __name__ == "__main__":
|
| 129 |
demo.launch()
|
|
|
|
| 58 |
),
|
| 59 |
])
|
| 60 |
|
| 61 |
+
|
| 62 |
@torch.no_grad()
|
| 63 |
def predict(image: Image.Image):
|
| 64 |
if image is None:
|
|
|
|
| 69 |
|
| 70 |
log_sp, log_st = model(x)
|
| 71 |
|
| 72 |
+
prob_sp = torch.softmax(log_sp, dim=1)[0]
|
| 73 |
+
prob_st = torch.softmax(log_st, dim=1)[0]
|
| 74 |
|
| 75 |
sp_id = int(prob_sp.argmax().item())
|
| 76 |
st_id = int(prob_st.argmax().item())
|
|
|
|
| 83 |
|
| 84 |
return sp_text, st_text
|
| 85 |
|
| 86 |
+
|
| 87 |
EXAMPLE_TEXT = """
|
| 88 |
+
**Taxa (12):** Cacoxnus indagator, Chelostoma florisomne, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Megachile, Osmia bicornis, Osmia cornuta, Passaloecus, Psenulus, Trypoxylon
|
|
|
|
| 89 |
|
| 90 |
+
**States (your model labels):** DauLv, DeadLv, Lv, OldFood
|
| 91 |
"""
|
| 92 |
|
| 93 |
+
STATUS_TABLE = [
|
| 94 |
+
["Dead", "D", "DeadLv", "Dead visible larva"],
|
| 95 |
+
["Food", "F", "OldFood", "Brood cell with only unconsumed food, no larva will develop"],
|
| 96 |
+
["Hatched", "H", "Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
|
| 97 |
+
["Larva", "L", "Lv", "Visible alive larva"],
|
| 98 |
+
["Prepupa", "P", "Prepupa", "Visible alive prepupa that stopped feeding"],
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
theme = gr.themes.Soft()
|
| 103 |
+
|
| 104 |
+
with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
|
| 105 |
+
gr.Markdown(
|
| 106 |
+
"# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
|
| 107 |
+
"Upload a brood-cell image crop to predict **taxon** and **developmental status**.\n"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
with gr.Row():
|
| 111 |
+
with gr.Column(scale=1):
|
| 112 |
+
img = gr.Image(type="pil", label="Input image", height=320)
|
| 113 |
+
|
| 114 |
+
gr.Markdown("### Label sets")
|
| 115 |
+
gr.Markdown(EXAMPLE_TEXT)
|
| 116 |
+
|
| 117 |
+
gr.Markdown("### Status legend")
|
| 118 |
+
legend = gr.Dataframe(
|
| 119 |
+
value=STATUS_TABLE,
|
| 120 |
+
headers=["Status", "Letter", "Code", "Description"],
|
| 121 |
+
datatype=["str", "str", "str", "str"],
|
| 122 |
+
interactive=False,
|
| 123 |
+
wrap=True,
|
| 124 |
+
row_count=(len(STATUS_TABLE), "fixed"),
|
| 125 |
+
col_count=(4, "fixed"),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
with gr.Column(scale=1):
|
| 129 |
+
sp_out = gr.Textbox(label="Predicted taxon", lines=1)
|
| 130 |
+
st_out = gr.Textbox(label="Predicted state", lines=1)
|
| 131 |
+
|
| 132 |
+
btn = gr.Button("Predict", variant="primary")
|
| 133 |
+
btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])
|
| 134 |
+
|
| 135 |
+
gr.Markdown("### Examples")
|
| 136 |
+
gr.Examples(
|
| 137 |
+
examples=[[f"./{i}.jpg"] for i in range(23)],
|
| 138 |
+
inputs=img,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
gr.Markdown(
|
| 142 |
+
"<small>Note: The status legend is a human-readable mapping. "
|
| 143 |
+
"Your model output labels come from `labels.json`.</small>"
|
| 144 |
+
)
|
| 145 |
|
| 146 |
if __name__ == "__main__":
|
| 147 |
demo.launch()
|