chenchangliu committed on
Commit
8d30696
·
verified ·
1 Parent(s): f4bf3a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -40
app.py CHANGED
@@ -58,6 +58,7 @@ tfm = transforms.Compose([
58
  ),
59
  ])
60
 
 
61
  @torch.no_grad()
62
  def predict(image: Image.Image):
63
  if image is None:
@@ -68,8 +69,8 @@ def predict(image: Image.Image):
68
 
69
  log_sp, log_st = model(x)
70
 
71
- prob_sp = torch.softmax(log_sp, dim=1)[0] # [num_species]
72
- prob_st = torch.softmax(log_st, dim=1)[0] # [num_states]
73
 
74
  sp_id = int(prob_sp.argmax().item())
75
  st_id = int(prob_st.argmax().item())
@@ -82,48 +83,65 @@ def predict(image: Image.Image):
82
 
83
  return sp_text, st_text
84
 
 
85
  EXAMPLE_TEXT = """
86
- 12 taxa: Cacoxnus indagator, Chelostoma florisomne, Coeliopencyrtus, Eumenidae,
87
- Heriades, Hylaeus, Megachile, Osmia bicornis, Osmia cornuta, Passaloecus, Psenulus, Trypoxylon
88
 
89
- 4 states: DauLv, DeadLv, Lv, OldFood
90
  """
91
 
92
- demo = gr.Interface(
93
- fn=predict,
94
- inputs=gr.Image(type="pil"),
95
- outputs=[
96
- gr.Textbox(label="Predicted taxon"),
97
- gr.Textbox(label="Predicted state"),
98
- ],
99
- title="EfficientNet Two-Head Layer Trap Nest (LTN) Classifier",
100
- description=EXAMPLE_TEXT,
101
- examples=[
102
- ["./0.jpg"],
103
- ["./1.jpg"],
104
- ["./2.jpg"],
105
- ["./3.jpg"],
106
- ["./4.jpg"],
107
- ["./5.jpg"],
108
- ["./6.jpg"],
109
- ["./7.jpg"],
110
- ["./8.jpg"],
111
- ["./9.jpg"],
112
- ["./10.jpg"],
113
- ["./11.jpg"],
114
- ["./12.jpg"],
115
- ["./13.jpg"],
116
- ["./14.jpg"],
117
- ["./15.jpg"],
118
- ["./16.jpg"],
119
- ["./17.jpg"],
120
- ["./18.jpg"],
121
- ["./19.jpg"],
122
- ["./20.jpg"],
123
- ["./21.jpg"],
124
- ["./22.jpg"],
125
- ],
126
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  if __name__ == "__main__":
129
  demo.launch()
 
58
  ),
59
  ])
60
 
61
+
62
  @torch.no_grad()
63
  def predict(image: Image.Image):
64
  if image is None:
 
69
 
70
  log_sp, log_st = model(x)
71
 
72
+ prob_sp = torch.softmax(log_sp, dim=1)[0]
73
+ prob_st = torch.softmax(log_st, dim=1)[0]
74
 
75
  sp_id = int(prob_sp.argmax().item())
76
  st_id = int(prob_st.argmax().item())
 
83
 
84
  return sp_text, st_text
85
 
86
+
87
# Markdown blurb rendered in the UI listing the label sets this model predicts.
# NOTE(review): these names must match the model's training labels exactly —
# "Cacoxnus" looks like a typo for the genus "Cacoxenus" (Cacoxenus indagator),
# but confirm against labels.json before changing the string contents.
EXAMPLE_TEXT = """
**Taxa (12):** Cacoxnus indagator, Chelostoma florisomne, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Megachile, Osmia bicornis, Osmia cornuta, Passaloecus, Psenulus, Trypoxylon

**States (your model labels):** DauLv, DeadLv, Lv, OldFood
"""
92
 
93
# Human-readable legend for developmental statuses, shown as a table in the UI.
# Columns: display name, single-letter shorthand, model label code, description.
# NOTE(review): this is documentation for humans; the model's actual output
# labels come from labels.json and need not cover every row here.
_STATUS_ROWS = (
    ("Dead", "D", "DeadLv", "Dead visible larva"),
    ("Food", "F", "OldFood", "Brood cell with only unconsumed food, no larva will develop"),
    ("Hatched", "H", "Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"),
    ("Larva", "L", "Lv", "Visible alive larva"),
    ("Prepupa", "P", "Prepupa", "Visible alive prepupa that stopped feeding"),
)
# gr.Dataframe expects a list of row lists.
STATUS_TABLE = [list(row) for row in _STATUS_ROWS]
100
+
101
+
102
# Build the Gradio UI: a two-column layout with the image input, label-set
# reference, and status legend on the left; predictions, the predict button,
# and example thumbnails on the right.
theme = gr.themes.Soft()

with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as demo:
    gr.Markdown(
        "# EfficientNet Two-Head Layer Trap Nest (LTN) Classifier\n"
        "Upload a brood-cell image crop to predict **taxon** and **developmental status**.\n"
    )

    with gr.Row():
        with gr.Column(scale=1):
            # PIL image matches predict()'s expected input type.
            img = gr.Image(type="pil", label="Input image", height=320)

            gr.Markdown("### Label sets")
            gr.Markdown(EXAMPLE_TEXT)

            gr.Markdown("### Status legend")
            # Read-only reference table; row/col counts pinned so users
            # cannot add or remove rows in the UI.
            legend = gr.Dataframe(
                value=STATUS_TABLE,
                headers=["Status", "Letter", "Code", "Description"],
                datatype=["str", "str", "str", "str"],
                interactive=False,
                wrap=True,
                row_count=(len(STATUS_TABLE), "fixed"),
                col_count=(4, "fixed"),
            )

        with gr.Column(scale=1):
            sp_out = gr.Textbox(label="Predicted taxon", lines=1)
            st_out = gr.Textbox(label="Predicted state", lines=1)

            btn = gr.Button("Predict", variant="primary")
            # predict() returns (sp_text, st_text), filling the two textboxes.
            btn.click(fn=predict, inputs=img, outputs=[sp_out, st_out])

            gr.Markdown("### Examples")
            # Example images 0.jpg..22.jpg are expected in the app's
            # working directory — presumably shipped with the Space.
            gr.Examples(
                examples=[[f"./{i}.jpg"] for i in range(23)],
                inputs=img,
            )

    # NOTE(review): nesting of this closing note was inferred from a flattened
    # diff — it is placed at Blocks level (below the row); confirm intent.
    gr.Markdown(
        "<small>Note: The status legend is a human-readable mapping. "
        "Your model output labels come from `labels.json`.</small>"
    )
145
 
146
  if __name__ == "__main__":
147
  demo.launch()