chenchangliu committed on
Commit
e6d2a2e
·
verified ·
1 Parent(s): 8cb227d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -32
app.py CHANGED
@@ -8,19 +8,30 @@ import gradio as gr
8
 
9
 
10
  # ---------- CONFIG ----------
11
- CKPT_PATH = "best_effnet_twohead.pt"
12
- LABELS_PATH = "labels.json"
 
 
13
  IMG_SIZE = 224
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  # ----------------------------
16
 
17
 
18
- # load labels
19
- with open(LABELS_PATH, "r") as f:
20
- labels = json.load(f)
 
21
 
22
- SPECIES = labels["species"]
23
- STATE = labels["state"]
 
 
 
 
 
 
 
 
24
 
25
 
26
  # model (must match training)
@@ -29,34 +40,69 @@ class EffNetTwoHead(nn.Module):
29
  super().__init__()
30
  base = efficientnet_b0(weights=None)
31
  self.features = base.features
32
- self.avgpool = base.avgpool
33
  c = base.classifier[1].in_features
 
 
 
34
  self.head_species = nn.Linear(c, num_species)
35
  self.head_state = nn.Linear(c, num_states)
36
 
37
  def forward(self, x):
38
  x = self.features(x)
39
- x = self.avgpool(x)
40
  x = torch.flatten(x, 1)
 
41
  return self.head_species(x), self.head_state(x)
42
 
43
 
44
  # load model
45
  ckpt = torch.load(CKPT_PATH, map_location="cpu")
46
- model = EffNetTwoHead(len(SPECIES), len(STATE))
47
- model.load_state_dict(ckpt["model"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  model.to(DEVICE).eval()
49
 
50
 
51
- # preprocessing (same as training)
52
- tfm = transforms.Compose([
53
- transforms.Resize((IMG_SIZE, IMG_SIZE)),
54
- transforms.ToTensor(),
55
- transforms.Normalize(
56
- mean=[0.485, 0.456, 0.406],
57
- std=[0.229, 0.224, 0.225],
58
- ),
59
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  @torch.no_grad()
@@ -65,7 +111,12 @@ def predict(image: Image.Image):
65
  return "No image", "No image"
66
 
67
  image = image.convert("RGB")
68
- x = tfm(image).unsqueeze(0).to(DEVICE)
 
 
 
 
 
69
 
70
  log_sp, log_st = model(x)
71
 
@@ -78,9 +129,6 @@ def predict(image: Image.Image):
78
  sp_conf = float(prob_sp[sp_id].item())
79
  st_conf = float(prob_st[st_id].item())
80
 
81
- #sp_text = f"{SPECIES[sp_id]} (id={sp_id}, score={sp_conf:.3f})"
82
- #st_text = f"{STATE[st_id]} (id={st_id}, score={st_conf:.3f})"
83
-
84
  sp_text = f"{SPECIES[sp_id]} (score={sp_conf:.3f})"
85
  st_text = f"{STATE[st_id]} (score={st_conf:.3f})"
86
 
@@ -88,14 +136,16 @@ def predict(image: Image.Image):
88
 
89
 
90
  EXAMPLE_TEXT = """
91
- **Taxa (12):** Cacoxnus indagator, Chelostoma florisomne, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Megachile, Osmia bicornis, Osmia cornuta, Passaloecus, Psenulus, Trypoxylon
92
 
93
- **States (4):** DauLv, DeadLv, Lv, OldFood
94
  """
95
 
 
96
  STATUS_TABLE = [
97
  ["DauLv", "Visible alive prepupa that stopped feeding"],
98
  ["DeadLv", "Dead visible larva"],
 
99
  ["Lv", "Visible alive larva"],
100
  ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
101
  ]
@@ -140,10 +190,5 @@ with gr.Blocks(theme=theme, title="LTN EfficientNet Two-Head Classifier") as dem
140
  inputs=img,
141
  )
142
 
143
- #gr.Markdown(
144
- # "<small>Note: The status legend is a human-readable mapping. "
145
- # "Your model output labels come from `labels.json`.</small>"
146
- #)
147
-
148
  if __name__ == "__main__":
149
- demo.launch()
 
8
 
9
 
10
  # ---------- CONFIG ----------
11
+ CKPT_PATH = "best_effnet_twohead.pt" # training script saves best.pt / last.pt
12
+ LABELS_PATH = "labels.json" # optional; fallback to taxa.txt / states.txt
13
+ TAXA_PATH = "taxa.txt" # fallback
14
+ STATES_PATH = "states.txt" # fallback
15
  IMG_SIZE = 224
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  # ----------------------------
18
 
19
 
20
def load_lines(path: str) -> list[str]:
    """Read *path* as UTF-8 text and return its non-blank lines, stripped.

    Serves as the fallback label source (taxa.txt / states.txt) when
    labels.json is not present.

    Args:
        path: Path to a plain-text file with one label per line.

    Returns:
        The stripped, non-empty lines in file order.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    # The original bound an unused alias (p = path); read the file directly.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]
24
 
25
+
26
# Label vocabularies: prefer labels.json; if it is absent, fall back to the
# plain-text lists (taxa.txt / states.txt) written by the training script.
try:
    with open(LABELS_PATH, "r", encoding="utf-8") as fh:
        labels = json.load(fh)
except FileNotFoundError:
    # No labels.json — build the vocabularies from the one-label-per-line files.
    SPECIES = load_lines(TAXA_PATH)
    STATE = load_lines(STATES_PATH)
else:
    # labels.json found; a missing "species"/"state" key still raises KeyError,
    # exactly as in the single-try form.
    SPECIES = labels["species"]
    STATE = labels["state"]
35
 
36
 
37
  # model (must match training)
 
40
  super().__init__()
41
  base = efficientnet_b0(weights=None)
42
  self.features = base.features
43
+ self.pool = base.avgpool
44
  c = base.classifier[1].in_features
45
+
46
+ # training script uses dropout before heads
47
+ self.drop = nn.Dropout(0.3)
48
  self.head_species = nn.Linear(c, num_species)
49
  self.head_state = nn.Linear(c, num_states)
50
 
51
  def forward(self, x):
52
  x = self.features(x)
53
+ x = self.pool(x)
54
  x = torch.flatten(x, 1)
55
+ x = self.drop(x)
56
  return self.head_species(x), self.head_state(x)
57
 
58
 
59
  # load model
60
  ckpt = torch.load(CKPT_PATH, map_location="cpu")
61
+ num_species = int(ckpt.get("num_species", len(SPECIES)))
62
+ num_states = int(ckpt.get("num_states", len(STATE)))
63
+
64
+ # if checkpoint defines class counts, trust it; labels must match lengths
65
+ if len(SPECIES) != num_species:
66
+ raise RuntimeError(
67
+ f"Label mismatch: len(SPECIES)={len(SPECIES)} but ckpt num_species={num_species}. "
68
+ f"Fix labels.json or taxa.txt."
69
+ )
70
+ if len(STATE) != num_states:
71
+ raise RuntimeError(
72
+ f"Label mismatch: len(STATE)={len(STATE)} but ckpt num_states={num_states}. "
73
+ f"Fix labels.json or states.txt."
74
+ )
75
+
76
+ model = EffNetTwoHead(num_species, num_states)
77
+ model.load_state_dict(ckpt["model"], strict=True)
78
  model.to(DEVICE).eval()
79
 
80
 
81
+ # preprocessing (align with training: letterbox to 224x224 without cropping)
82
+ # This implementation NEVER crops the image: it resizes to fit within 224x224,
83
+ # then pads the remaining area (black) to reach exactly 224x224.
84
+ normalize = transforms.Normalize(
85
+ mean=[0.485, 0.456, 0.406],
86
+ std=[0.229, 0.224, 0.225],
87
+ )
88
+
89
+
90
def letterbox_pil(im: Image.Image, size: int = 224, fill: int = 0) -> Image.Image:
    """Letterbox *im* into a ``size`` x ``size`` square without cropping.

    The image is scaled uniformly so it fits entirely inside the square
    (aspect ratio preserved), then centred on a solid-colour canvas; the
    leftover border is padded with ``fill``.

    Args:
        im: Source PIL image.
        fill: Grey level (0-255) used for all three RGB padding channels.
        size: Side length of the square output, in pixels.

    Returns:
        A new ``size`` x ``size`` RGB image; *im* itself is not modified.
    """
    width, height = im.size
    background = Image.new("RGB", (size, size), (fill, fill, fill))

    # Degenerate input: nothing to draw, hand back the blank canvas.
    if width == 0 or height == 0:
        return background

    # One scale factor for both axes so the whole image fits inside the square.
    ratio = min(size / width, size / height)
    fitted_w = max(1, int(round(width * ratio)))
    fitted_h = max(1, int(round(height * ratio)))
    fitted = im.resize((fitted_w, fitted_h), resample=Image.BILINEAR)

    # Paste centred; integer floor division matches the original placement.
    offset = ((size - fitted_w) // 2, (size - fitted_h) // 2)
    background.paste(fitted, offset)
    return background
106
 
107
 
108
  @torch.no_grad()
 
111
  return "No image", "No image"
112
 
113
  image = image.convert("RGB")
114
+
115
+ # letterbox to 224x224 (no cropping)
116
+ image = letterbox_pil(image, size=IMG_SIZE, fill=0)
117
+
118
+ x = transforms.ToTensor()(image)
119
+ x = normalize(x).unsqueeze(0).to(DEVICE)
120
 
121
  log_sp, log_st = model(x)
122
 
 
129
  sp_conf = float(prob_sp[sp_id].item())
130
  st_conf = float(prob_st[st_id].item())
131
 
 
 
 
132
  sp_text = f"{SPECIES[sp_id]} (score={sp_conf:.3f})"
133
  st_text = f"{STATE[st_id]} (score={st_conf:.3f})"
134
 
 
136
 
137
 
138
  EXAMPLE_TEXT = """
139
+ **Taxa (20):** Anthidium, Cacoxnus indagator, Chelostoma campanularum, Chelostoma florisomne, Chelostoma rapunculi, Coeliopencyrtus, Eumenidae, Heriades, Hylaeus, Ichneumonidae, Isodontia mexicana, Megachile, Osmia bicornis, Osmia brevicornis, Osmia cornuta, Passaloecus, Pemphredon, Psenulus, Trichodes, Trypoxylon
140
 
141
+ **States (5):** DauLv, DeadLv, Hatched, Lv, OldFood
142
  """
143
 
144
+
145
  STATUS_TABLE = [
146
  ["DauLv", "Visible alive prepupa that stopped feeding"],
147
  ["DeadLv", "Dead visible larva"],
148
+ ["Hatched", "Brood cell with hatched bee or wasp traces (cocoon or exuvia)"],
149
  ["Lv", "Visible alive larva"],
150
  ["OldFood", "Brood cell with only unconsumed food, no larva will develop"],
151
  ]
 
190
  inputs=img,
191
  )
192
 
 
 
 
 
 
193
  if __name__ == "__main__":
194
+ demo.launch()