chenchangliu committed
Commit 722bb9c · verified · 1 Parent(s): 27c6faa

Upload 2 files

Files changed (2)
  1. app.py +119 -0
  2. pipeline.py +311 -0
app.py ADDED
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
"""
Gradio app for HuggingFace Spaces.
Wraps the LTN localize-and-classify pipeline with a simple web UI.
"""
import tempfile
from pathlib import Path

import torch
import gradio as gr
from PIL import Image
from torchvision.io import read_image

from pipeline import (
    TAXON_NAMES, STATE_NAMES,
    DET_CONF, YOLO_WEIGHTS, CLF_WEIGHTS,
    load_classifier, classify_crops, annotate_image,
)
from ultralytics import YOLO


# ---------------------------------------------------------------------------
# Load models once at startup
# ---------------------------------------------------------------------------

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
yolo = YOLO(str(YOLO_WEIGHTS))
classifier = load_classifier(CLF_WEIGHTS, DEVICE)


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def predict(image: Image.Image, conf: float):
    if image is None:
        return None, "No image provided."

    # Save PIL image to a temp file — YOLO and read_image both need a path.
    # Convert to RGB first so RGBA uploads (e.g. PNGs with alpha) can be
    # written as JPEG.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
        tmp_in = Path(f.name)
    image.convert("RGB").save(tmp_in)
    tmp_out = tmp_in.with_name(tmp_in.stem + "_out.jpg")

    try:
        # 1 — Detect
        det = yolo.predict(str(tmp_in), conf=conf, verbose=False)[0]
        boxes = det.boxes.xyxy.cpu().tolist()
        det_confs = det.boxes.conf.cpu().tolist()

        if not boxes:
            return image, "No cells detected. Try lowering the confidence threshold."

        # 2 — Crop
        img_tensor = read_image(str(tmp_in))
        if img_tensor.shape[0] == 4:  # drop alpha channel if present
            img_tensor = img_tensor[:3]

        crops = [
            img_tensor[:, int(y1):int(y2), int(x1):int(x2)]
            for x1, y1, x2, y2 in boxes
        ]

        # 3 — Classify
        predictions = classify_crops(crops, classifier, DEVICE)

        # 4 — Annotate
        annotate_image(tmp_in, boxes, predictions, det_confs, tmp_out)
        result_img = Image.open(tmp_out).copy()

        # Build the results text, one line per detected cell
        lines = [f"{len(boxes)} cell(s) detected\n"]
        for i, (taxon_idx, state_idx, tx_conf, st_conf) in enumerate(predictions):
            lines.append(
                f"[{i + 1}] {TAXON_NAMES[taxon_idx]} ({tx_conf:.0%})"
                f" — {STATE_NAMES[state_idx]} ({st_conf:.0%})"
            )

        return result_img, "\n".join(lines)

    finally:
        tmp_in.unlink(missing_ok=True)
        tmp_out.unlink(missing_ok=True)


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="LTN Brood Cell Classifier") as demo:
    gr.Markdown(
        "# LTN Brood Cell Classifier\n"
        "Upload a Layer Trap Nest image. "
        "YOLOv8 localizes each brood cell; EfficientNet classifies its **taxon** and **state**."
    )

    with gr.Row():
        with gr.Column():
            inp_image = gr.Image(type="pil", label="Input image")
            conf_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=DET_CONF, step=0.05,
                label="Detection confidence threshold",
                info="Raise to keep only high-confidence detections.",
            )
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column():
            out_image = gr.Image(type="pil", label="Annotated output")
            out_text = gr.Textbox(label="Predictions", lines=12)

    run_btn.click(
        fn=predict,
        inputs=[inp_image, conf_slider],
        outputs=[out_image, out_text],
    )


if __name__ == "__main__":
    demo.launch()
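
The app can also be smoke-tested without the UI by calling predict directly. A minimal sketch, assuming the two weight files are present (importing app loads both models at module level); "test.jpg" is a hypothetical local image, not part of the commit:

# Hypothetical smoke test for app.py; "test.jpg" is a placeholder path.
from PIL import Image
from app import predict, DET_CONF

img = Image.open("test.jpg")
annotated, report = predict(img, DET_CONF)  # same entry point the Run button uses
print(report)
if annotated is not None:
    annotated.save("test_annotated.jpg")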
pipeline.py ADDED
@@ -0,0 +1,311 @@
#!/usr/bin/env python3
"""
LTN Pipeline: YOLOv8 localization → EfficientNet two-head classification.

Detects brood cells in Layer Trap Nest images, classifies each crop by
taxon and state, and saves annotated output images.

Usage:
    python pipeline.py image.jpg
    python pipeline.py images/                      # process a whole directory
    python pipeline.py a.jpg b.jpg --out results/ --conf 0.3
"""
from __future__ import annotations

import argparse
from pathlib import Path

import torch

# ---------------------------------------------------------------------------
# CONFIG — edit these instead of passing CLI flags every time
# ---------------------------------------------------------------------------

YOLO_WEIGHTS = Path("yolov8_localizer.pt")
CLF_WEIGHTS = Path("effnet_two_head_classifier.pt")
OUTPUT_DIR = Path("pipeline_out")
DET_CONF = 0.5    # YOLO detection confidence threshold (0–1); raise to be more strict
BATCH_SIZE = 32   # classifier batch size
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------

import torch.nn as nn
import torchvision.transforms.functional as TF
from torchvision.models import efficientnet_b0
from torchvision.io import read_image
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO


# ---------------------------------------------------------------------------
# Class labels (display names only; indices must match the trained model)
# ---------------------------------------------------------------------------

TAXON_NAMES = [
    "Anthidium", "Cacoxenus indagator", "Chelostoma campanularum",
    "Chelostoma florisomne", "Chelostoma rapunculi", "Coelopencyrtus",
    "Eumenidae", "Heriades", "Hylaeus", "Ichneumonidae", "Isodontia mexicana",
    "Megachile", "Osmia bicornis", "Osmia brevicornis", "Osmia cornuta",
    "Passaloecus", "Pemphredon", "Psenulus", "Trichodes", "Trypoxylon",
]

STATE_NAMES = ["DauLv", "DeadLv", "Hatched", "Lv", "OldFood"]

# One distinct colour per state (RGB)
STATE_COLORS = [
    (230, 130,   0),  # DauLv   - amber
    (210,  30,  45),  # DeadLv  - crimson
    ( 40, 180,  60),  # Hatched - green
    ( 30, 140, 240),  # Lv      - blue
    (150,  50, 220),  # OldFood - purple
]


# ---------------------------------------------------------------------------
# Preprocessing (must match training)
# ---------------------------------------------------------------------------

class Letterbox:
    """Resize so the longer side = `size`, pad the shorter side to square."""
    def __init__(self, size: int = 224, fill: int = 0):
        self.size = size
        self.fill = fill

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        _, h, w = x.shape
        scale = self.size / max(h, w)
        new_h, new_w = int(round(h * scale)), int(round(w * scale))
        x = TF.resize(x, [new_h, new_w], antialias=True)
        pad_h, pad_w = self.size - new_h, self.size - new_w
        pad_top, pad_left = pad_h // 2, pad_w // 2
        x = TF.pad(x, [pad_left, pad_top, pad_w - pad_left, pad_h - pad_top], fill=self.fill)
        return x


_letterbox = Letterbox(224, fill=0)
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def preprocess(crop: torch.Tensor) -> torch.Tensor:
    """CHW uint8 RGB tensor → normalized 224×224 float tensor."""
    return _normalize(_letterbox(crop.float() / 255.0))


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with separate species and state heads."""

    def __init__(self, num_species: int, num_states: int):
        super().__init__()
        base = efficientnet_b0(weights=None)
        self.features = base.features
        self.pool = base.avgpool
        c = base.classifier[1].in_features
        self.drop = nn.Dropout(0.3)
        self.head_species = nn.Linear(c, num_species)
        self.head_state = nn.Linear(c, num_states)

    def forward(self, x: torch.Tensor):
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        return self.head_species(x), self.head_state(x)


def load_classifier(ckpt_path: Path, device: str) -> EffNetTwoHead:
    # Checkpoint is a dict with keys "model" (state_dict), "num_species", "num_states"
    ckpt = torch.load(ckpt_path, map_location=device)
    model = EffNetTwoHead(int(ckpt["num_species"]), int(ckpt["num_states"])).to(device)
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    return model


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

@torch.no_grad()
def classify_crops(
    crops: list[torch.Tensor],
    model: EffNetTwoHead,
    device: str,
    batch_size: int = 32,
) -> list[tuple[int, int, float, float]]:
    """
    Args:
        crops: list of CHW uint8 tensors (RGB)
    Returns:
        list of (taxon_idx, state_idx, taxon_conf, state_conf)
    """
    results = []
    for i in range(0, len(crops), batch_size):
        batch = torch.stack([preprocess(c) for c in crops[i : i + batch_size]]).to(device)
        lsp, lst = model(batch)
        sp_conf, sp_idx = lsp.softmax(1).max(1)
        st_conf, st_idx = lst.softmax(1).max(1)
        for k in range(len(sp_idx)):
            results.append((sp_idx[k].item(), st_idx[k].item(), sp_conf[k].item(), st_conf[k].item()))
    return results


# ---------------------------------------------------------------------------
# Visualisation
# ---------------------------------------------------------------------------

def _load_font(size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    for name in ["Arial.ttf", "DejaVuSans.ttf", "LiberationSans-Regular.ttf", "Helvetica.ttc"]:
        try:
            return ImageFont.truetype(name, size)
        except Exception:
            pass
    return ImageFont.load_default()


def annotate_image(
    img_path: Path,
    boxes: list[list[float]],
    predictions: list[tuple[int, int, float, float]],
    det_confs: list[float],
    out_path: Path,
) -> None:
    img = Image.open(img_path).convert("RGB")
    draw = ImageDraw.Draw(img)

    # Scale line width and font size with image resolution
    ref = max(img.width, img.height)
    lw = max(2, ref // 500)
    font_size = max(12, ref // 70)
    font = _load_font(font_size)
    pad = max(3, font_size // 4)

    for box, (taxon_idx, state_idx, tx_conf, st_conf), det_conf in zip(boxes, predictions, det_confs):
        x1, y1, x2, y2 = (int(v) for v in box)
        color = STATE_COLORS[state_idx % len(STATE_COLORS)]

        # Bounding box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=lw)

        line1 = f"{TAXON_NAMES[taxon_idx]} {tx_conf:.0%}"
        line2 = f"{STATE_NAMES[state_idx]} {st_conf:.0%}"

        # Measure both lines
        bb1 = draw.textbbox((0, 0), line1, font=font)
        bb2 = draw.textbbox((0, 0), line2, font=font)
        tw = max(bb1[2] - bb1[0], bb2[2] - bb2[0])
        th = bb1[3] - bb1[1]  # assume same line height

        label_h = 2 * th + 3 * pad  # height of label block

        # Place label above box; if not enough room, place it inside the box top
        if y1 >= label_h:
            lx1, ly1 = x1, y1 - label_h
        else:
            lx1, ly1 = x1, y1 + lw

        draw.rectangle([lx1, ly1, lx1 + tw + 2 * pad, ly1 + label_h], fill=color)
        draw.text((lx1 + pad, ly1 + pad), line1, fill=(255, 255, 255), font=font)
        draw.text((lx1 + pad, ly1 + pad + th + pad), line2, fill=(255, 255, 255), font=font)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(out_path)


# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------

def run_pipeline(
    img_path: Path,
    yolo: YOLO,
    classifier: EffNetTwoHead,
    device: str,
    conf: float,
    out_dir: Path,
    batch_size: int = BATCH_SIZE,
) -> None:
    print(f"\n{img_path.name}")

    # 1 — Detect cells
    det = yolo.predict(str(img_path), conf=conf, verbose=False)[0]
    boxes = det.boxes.xyxy.cpu().tolist()
    det_confs = det.boxes.conf.cpu().tolist()

    if not boxes:
        print("  No detections.")
        return

    print(f"  {len(boxes)} cell(s) detected")

    # 2 — Crop each detection from the original image
    img_tensor = read_image(str(img_path))
    if img_tensor.shape[0] == 4:  # drop alpha channel if present
        img_tensor = img_tensor[:3]

    crops = [
        img_tensor[:, int(y1):int(y2), int(x1):int(x2)]
        for x1, y1, x2, y2 in boxes
    ]

    # 3 — Classify all crops
    predictions = classify_crops(crops, classifier, device, batch_size=batch_size)

    # 4 — Annotate and save
    out_path = out_dir / (img_path.stem + "_annotated" + img_path.suffix)
    annotate_image(img_path, boxes, predictions, det_confs, out_path)

    for i, (taxon_idx, state_idx, tx_conf, st_conf) in enumerate(predictions):
        print(f"  [{i + 1}] {TAXON_NAMES[taxon_idx]} ({tx_conf:.0%}) — {STATE_NAMES[state_idx]} ({st_conf:.0%})")

    print(f"  → {out_path}")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"}


def main() -> None:
    ap = argparse.ArgumentParser(description="LTN localize-and-classify pipeline")
    ap.add_argument("input", type=Path, nargs="+", help="Image file(s) or directory(ies)")
    ap.add_argument("--yolo", type=Path, default=YOLO_WEIGHTS, help="YOLOv8 weights")
    ap.add_argument("--clf", type=Path, default=CLF_WEIGHTS, help="Classifier checkpoint")
    ap.add_argument("--out", type=Path, default=OUTPUT_DIR, help="Output directory")
    ap.add_argument("--conf", type=float, default=DET_CONF, help="YOLO detection confidence threshold")
    ap.add_argument("--device", type=str, default=DEVICE)
    ap.add_argument("--batch", type=int, default=BATCH_SIZE, help="Classifier batch size")
    args = ap.parse_args()

    # Collect all image paths
    img_paths: list[Path] = []
    for p in args.input:
        if p.is_dir():
            img_paths.extend(f for f in sorted(p.iterdir()) if f.suffix.lower() in IMG_EXTS)
        elif p.suffix.lower() in IMG_EXTS:
            img_paths.append(p)
        else:
            print(f"Warning: skipping {p} (not a recognised image or directory)")

    if not img_paths:
        raise SystemExit("No valid image files found.")

    print(f"Device : {args.device}")
    print(f"Images : {len(img_paths)}")
    print(f"Loading YOLOv8 from {args.yolo}")
    yolo = YOLO(str(args.yolo))

    print(f"Loading classifier from {args.clf}")
    classifier = load_classifier(args.clf, args.device)

    for img_path in img_paths:
        run_pipeline(img_path, yolo=yolo, classifier=classifier,
                     device=args.device, conf=args.conf,
                     out_dir=args.out, batch_size=args.batch)

    print("\nDone. Results saved to:", args.out)


if __name__ == "__main__":
    main()
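
The pipeline can also be driven from Python instead of the CLI. A minimal sketch, assuming the repo's two weight files sit in the working directory; "nest.jpg" is a hypothetical input image:

# Hypothetical programmatic use of pipeline.py; "nest.jpg" is a placeholder.
from pathlib import Path

from ultralytics import YOLO

from pipeline import (
    YOLO_WEIGHTS, CLF_WEIGHTS, DET_CONF, OUTPUT_DIR,
    load_classifier, run_pipeline,
)

yolo = YOLO(str(YOLO_WEIGHTS))
classifier = load_classifier(CLF_WEIGHTS, "cpu")  # or "cuda" if available

# Annotated output lands in OUTPUT_DIR (pipeline_out/ by default)
run_pipeline(Path("nest.jpg"), yolo=yolo, classifier=classifier,
             device="cpu", conf=DET_CONF, out_dir=OUTPUT_DIR)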