rotsl commited on
Commit
e5a4e03
Β·
verified Β·
1 Parent(s): 6d9c79a

Add footer with developer attribution

Browse files
Files changed (1) hide show
  1. app.py +559 -1
app.py CHANGED
@@ -1 +1,559 @@
1
- app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ app.py β€” Gray Leaf Spot Colony Segmentation Pipeline
4
+ """
5
+
6
+ import csv, json, math, os, re, logging, tempfile, zipfile, io
7
+ import datetime as dt
8
+ from pathlib import Path
9
+
10
+ import cv2
11
+ import gradio as gr
12
+ import matplotlib
13
+ matplotlib.use("Agg")
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import pandas as pd
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from PIL import Image, ExifTags
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S")
24
+ log = logging.getLogger("app")
25
+
26
+ IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
27
+ THUMB_SIZE = (160, 160)
28
+ DATE_RE = re.compile(r"(20\d{2})(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])")
29
+ MAX_IMAGES = 50
30
+ MODEL_REPO = "rotsl/grayleafspot-segmentation"
31
+ MODEL_FILE = "best_area_w_0.7.pt"
32
+ DISH_MM = 90.0
33
+ MODEL_SZ = 256
34
+ HF_TOKEN = os.environ.get("HF_TOKEN")
35
+ CSS = ".gallery-wrap{max-height:65vh;overflow-y:auto}"
36
+
37
+ # ── SmallUNet β€” exact architecture from model_small_unet.py ──
38
+
39
+ class ConvBlock(nn.Module):
40
+ def __init__(self, in_channels: int, out_channels: int) -> None:
41
+ super().__init__()
42
+ self.block = nn.Sequential(
43
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
44
+ nn.ReLU(inplace=True),
45
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
46
+ nn.ReLU(inplace=True),
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
50
+ return self.block(x)
51
+
52
+
53
+ class DownBlock(nn.Module):
54
+ def __init__(self, in_channels: int, out_channels: int) -> None:
55
+ super().__init__()
56
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
57
+ self.conv = ConvBlock(in_channels, out_channels)
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ return self.conv(self.pool(x))
61
+
62
+
63
+ class UpBlock(nn.Module):
64
+ def __init__(self, in_channels: int, out_channels: int) -> None:
65
+ super().__init__()
66
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
67
+ self.conv = ConvBlock(in_channels, out_channels)
68
+
69
+ def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
70
+ x = self.up(x)
71
+ if x.shape[-2:] != skip.shape[-2:]:
72
+ x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
73
+ x = torch.cat([skip, x], dim=1)
74
+ return self.conv(x)
75
+
76
+
77
+ class SmallUNet(nn.Module):
78
+ def __init__(self, in_channels=3, out_channels=1, base_channels=16):
79
+ super().__init__()
80
+ c1 = base_channels
81
+ c2 = base_channels * 2
82
+ c3 = base_channels * 4
83
+ c4 = base_channels * 8
84
+ bottleneck = base_channels * 16
85
+ self.enc1 = ConvBlock(in_channels, c1)
86
+ self.enc2 = DownBlock(c1, c2)
87
+ self.enc3 = DownBlock(c2, c3)
88
+ self.enc4 = DownBlock(c3, c4)
89
+ self.bottleneck = DownBlock(c4, bottleneck)
90
+ self.up4 = UpBlock(bottleneck + c4, c4)
91
+ self.up3 = UpBlock(c4 + c3, c3)
92
+ self.up2 = UpBlock(c3 + c2, c2)
93
+ self.up1 = UpBlock(c2 + c1, c1)
94
+ self.head = nn.Conv2d(c1, out_channels, kernel_size=1)
95
+ self.activation = nn.Sigmoid()
96
+
97
+ def forward(self, x):
98
+ s1 = self.enc1(x)
99
+ s2 = self.enc2(s1)
100
+ s3 = self.enc3(s2)
101
+ s4 = self.enc4(s3)
102
+ b = self.bottleneck(s4)
103
+ x = self.up4(b, s4)
104
+ x = self.up3(x, s3)
105
+ x = self.up2(x, s2)
106
+ x = self.up1(x, s1)
107
+ x = self.head(x)
108
+ return self.activation(x)
109
+
110
+ # ── Model loading ──
111
+ _model = None
112
+ def load_model():
113
+ global _model
114
+ if _model is None:
115
+ p = hf_hub_download(MODEL_REPO, MODEL_FILE, token=HF_TOKEN)
116
+ _model = SmallUNet(in_channels=3, out_channels=1, base_channels=16)
117
+ ckpt = torch.load(p, map_location="cpu", weights_only=False)
118
+ sd = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
119
+ _model.load_state_dict(sd, strict=True); _model.eval()
120
+ log.info("Model loaded: SmallUNet (%s)", MODEL_FILE)
121
+ return _model
122
+
123
+ # ── Core inference ──
124
+ def infer_image(img_pil, threshold):
125
+ model = load_model()
126
+ img_arr = np.array(img_pil.convert("RGB"))
127
+ img_resized = cv2.resize(img_arr, (MODEL_SZ, MODEL_SZ))
128
+ x = torch.from_numpy(img_resized.transpose(2, 0, 1)).float() / 255.0
129
+ x = x.unsqueeze(0)
130
+ with torch.no_grad():
131
+ prob = model(x)[0, 0].detach().cpu().numpy()
132
+ log.info(" output range: [%.4f, %.4f] mean=%.4f >0.5:%d >0.3:%d >0.1:%d",
133
+ prob.min(), prob.max(), prob.mean(),
134
+ (prob > 0.5).sum(), (prob > 0.3).sum(), (prob > 0.1).sum())
135
+ mask = (prob > threshold).astype(np.uint8) * 255
136
+ mask = cv2.resize(mask, (img_pil.width, img_pil.height), interpolation=cv2.INTER_NEAREST)
137
+ overlay = img_arr.copy()
138
+ overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([255, 0, 0]) * 0.5).astype(np.uint8)
139
+ return Image.fromarray(overlay), Image.fromarray(mask)
140
+
141
+ # ── Helpers ──
142
+ def make_thumbnail(p):
143
+ try: im = Image.open(p); im.thumbnail(THUMB_SIZE, Image.LANCZOS); return im
144
+ except: return Image.new("RGB", THUMB_SIZE, (200, 200, 200))
145
+
146
+ def detect_image_date(p):
147
+ m = DATE_RE.search(Path(p).stem)
148
+ if m:
149
+ try: return dt.date(int(m[1]), int(m[2]), int(m[3])).isoformat()
150
+ except: pass
151
+ try:
152
+ im = Image.open(p); ex = im.getexif()
153
+ if ex:
154
+ for tid, tn in ExifTags.TAGS.items():
155
+ if tn == "DateTimeOriginal":
156
+ v = ex.get(tid)
157
+ if v: return dt.datetime.strptime(v, "%Y:%m:%d %H:%M:%S").date().isoformat()
158
+ except: pass
159
+ try: return dt.date.fromtimestamp(os.path.getmtime(p)).isoformat()
160
+ except: return dt.date.today().isoformat()
161
+
162
+ def day_code(img_d, exp_d):
163
+ try: d = (dt.date.fromisoformat(img_d) - dt.date.fromisoformat(exp_d)).days + 1; return f"d{max(d,1):02d}"
164
+ except: return "d??"
165
+
166
+ def write_ics(rems, path):
167
+ L = ["BEGIN:VCALENDAR","VERSION:2.0","PRODID:-//FungalPipeline//EN"]
168
+ for r in rems:
169
+ uid = r["image_path"].replace("/","_")
170
+ ds = r["remind_me"].replace("-","").replace(" ","T").replace(":","") + "00"
171
+ L += ["BEGIN:VEVENT",f"UID:{uid}@fp",f"DTSTART:{ds}",
172
+ f"SUMMARY:Reminder - {r['experiment_name']}: {Path(r['image_path']).name}","END:VEVENT"]
173
+ L.append("END:VCALENDAR")
174
+ with open(path,"w") as f: f.write("\r\n".join(L))
175
+
176
+ def fig_to_pil(fig):
177
+ buf = io.BytesIO(); fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="white")
178
+ buf.seek(0); img = Image.open(buf).copy(); buf.close(); plt.close(fig); return img
179
+
180
+ # ── Full pipeline helpers (lazy skimage) ──
181
+ def _load_skimage():
182
+ from skimage import filters, measure, morphology
183
+ from skimage.filters import frangi, meijering
184
+ from skimage.morphology import skeletonize, disk, opening, closing, erosion, dilation
185
+ return filters, measure, morphology, frangi, meijering, skeletonize, disk, opening, closing, erosion, dilation
186
+
187
+ def detect_dish(img_bgr):
188
+ try:
189
+ gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
190
+ blurred = cv2.GaussianBlur(gray, (9,9), 2); h, w = gray.shape
191
+ mn, mx = int(min(h,w)*0.25), int(min(h,w)*0.52)
192
+ circles = cv2.HoughCircles(blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min(h,w)//2,
193
+ param1=100, param2=40, minRadius=mn, maxRadius=mx)
194
+ if circles is None: return None
195
+ circles = np.round(circles[0]).astype(int); ic,jc = w/2, h/2; bi,bs = 0,-1
196
+ for i,(cx,cy,r) in enumerate(circles):
197
+ s = r / (1 + math.hypot(cx-ic, cy-jc)/100)
198
+ if s > bs: bs=s; bi=i
199
+ cx,cy,r = int(circles[bi][0]), int(circles[bi][1]), int(circles[bi][2])
200
+ return cx, cy, r, DISH_MM/(2*r)
201
+ except: return None
202
+
203
+ def detect_cracks(gray, colony_mask):
204
+ filters,measure,_,_,_,_,disk,opening,_,erosion,_ = _load_skimage()
205
+ if colony_mask.sum() < 100: return np.zeros_like(colony_mask, dtype=bool)
206
+ interior = gray.copy(); interior[~colony_mask] = 0; er = erosion(colony_mask, disk(5))
207
+ iu = (interior*255 if interior.max()<=1 else interior).astype(np.uint8)
208
+ lt = filters.threshold_local(iu, block_size=51, method="gaussian")
209
+ dk = (iu < (lt-15)) & er; dk = opening(dk, disk(1)); lb = measure.label(dk)
210
+ cm = np.zeros_like(dk, dtype=bool)
211
+ for rp in measure.regionprops(lb):
212
+ if rp.area < 10: continue
213
+ if rp.major_axis_length > 0 and rp.minor_axis_length > 0:
214
+ if rp.major_axis_length/rp.minor_axis_length > 2.5 or rp.eccentricity > 0.85:
215
+ cm[lb==rp.label] = True
216
+ return cm
217
+
218
+ def detect_hyphae(gray, colony_mask):
219
+ _,_,_,frangi,meijering,skeletonize,disk,_,_,_,dilation = _load_skimage()
220
+ if colony_mask.sum() < 100:
221
+ z = np.zeros_like(colony_mask, dtype=bool); return z, z.copy(), z.copy()
222
+ g = gray.astype(np.float64); ex = dilation(colony_mask, disk(20))
223
+ fr = frangi(g, sigmas=range(1,5), black_ridges=False); fr[~ex]=0
224
+ th = fr[ex].mean()+2*fr[ex].std() if ex.sum()>0 else .01; fs = skeletonize(fr>th)
225
+ mr = meijering(g, sigmas=range(1,5), black_ridges=False); mr[~ex]=0
226
+ th2 = mr[ex].mean()+2*mr[ex].std() if ex.sum()>0 else .01; ms = skeletonize(mr>th2)
227
+ return fs, ms, fs|ms
228
+
229
+ def compute_metrics(mask_bool, gray, px2mm, dcx, dcy, crack_mask, hyph_f, hyph_m, hyph_h):
230
+ filters,measure,morphology,_,_,_,_,_,_,_,_ = _load_skimage()
231
+ mm2 = px2mm**2
232
+ if mask_bool.sum() < 50:
233
+ return {k:0 for k in ["area_mm2","diameter_mm","perimeter_mm","eccentricity","edge_roughness",
234
+ "centre_delta_mm","entropy","texture_std","crack_px","crack_area_mm2",
235
+ "crack_coverage_pct","crack_count","hyph_frangi_mm","hyph_meijering_mm","hyph_hybrid_mm"]}
236
+ pr = measure.regionprops(mask_bool.astype(np.uint8))[0]; R = {}
237
+ R["area_mm2"]=round(pr.area*mm2,4); pm=measure.perimeter(mask_bool)
238
+ R["perimeter_mm"]=round(pm*px2mm,4); R["diameter_mm"]=round(pr.equivalent_diameter_area*px2mm,4)
239
+ R["eccentricity"]=round(pr.eccentricity,6); eq=math.pi*pr.equivalent_diameter_area
240
+ R["edge_roughness"]=round(pm/eq,6) if eq>0 else 0; cy,cx=pr.centroid
241
+ R["centre_delta_mm"]=round(math.hypot(cx-dcx,cy-dcy)*px2mm,4)
242
+ gu8=(gray*255).astype(np.uint8) if gray.max()<=1 else gray.astype(np.uint8)
243
+ R["entropy"]=round(float(filters.rank.entropy(gu8,morphology.disk(5),mask=mask_bool)[mask_bool].mean()),6) if pr.area>100 else 0
244
+ R["texture_std"]=round(float(gray[mask_bool].std()),6)
245
+ R["crack_px"]=int(crack_mask.sum()); R["crack_area_mm2"]=round(crack_mask.sum()*mm2,6)
246
+ R["crack_coverage_pct"]=round(100*crack_mask.sum()/pr.area,4) if pr.area>0 else 0
247
+ R["crack_count"]=int(measure.label(crack_mask).max())
248
+ R["hyph_frangi_mm"]=round(int(hyph_f.sum())*px2mm,4)
249
+ R["hyph_meijering_mm"]=round(int(hyph_m.sum())*px2mm,4)
250
+ R["hyph_hybrid_mm"]=round(int(hyph_h.sum())*px2mm,4)
251
+ return R
252
+
253
+ def create_full_overlays(img_bgr, colony_mask, crack_mask, hyph_hybrid, dish_info, fname):
254
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB); h,w = img_bgr.shape[:2]
255
+ dcx,dcy,dr = (dish_info[0],dish_info[1],dish_info[2]) if dish_info else (w//2,h//2,min(h,w)//2)
256
+ p1=img_rgb.copy()
257
+ if dish_info: cv2.circle(p1,(dcx,dcy),dr,(0,255,0),3)
258
+ cts,_ = cv2.findContours(colony_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
259
+ cv2.drawContours(p1,cts,-1,(255,0,0),2)
260
+ p2=np.zeros_like(img_rgb); p2[colony_mask]=[255,255,255]
261
+ p3=img_rgb.copy()
262
+ if colony_mask.sum()>0: p3[colony_mask]=(p3[colony_mask].astype(np.float32)*0.5+np.array([255,0,0],dtype=np.float32)*0.5).astype(np.uint8)
263
+ if dish_info: cv2.circle(p3,(dcx,dcy),dr,(0,255,0),2)
264
+ p4=img_rgb.copy()
265
+ if crack_mask.sum()>0:
266
+ ck=cv2.dilate(crack_mask.astype(np.uint8),cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))>0
267
+ p4[ck]=(p4[ck].astype(np.float32)*0.3+np.array([255,255,0],dtype=np.float32)*0.7).astype(np.uint8)
268
+ if dish_info: cv2.circle(p4,(dcx,dcy),dr,(0,255,0),2); cv2.drawContours(p4,cts,-1,(255,0,0),1)
269
+ p5=img_rgb.copy()
270
+ if hyph_hybrid.sum()>0:
271
+ hy=cv2.dilate(hyph_hybrid.astype(np.uint8),cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))>0
272
+ p5[hy]=(p5[hy].astype(np.float32)*0.3+np.array([0,255,255],dtype=np.float32)*0.7).astype(np.uint8)
273
+ if dish_info: cv2.circle(p5,(dcx,dcy),dr,(0,255,0),2); cv2.drawContours(p5,cts,-1,(255,0,0),1)
274
+ p6=img_rgb.copy()
275
+ if colony_mask.sum()>0: p6[colony_mask]=(p6[colony_mask].astype(np.float32)*0.6+np.array([255,0,0],dtype=np.float32)*0.4).astype(np.uint8)
276
+ if crack_mask.sum()>0:
277
+ ck2=cv2.dilate(crack_mask.astype(np.uint8),cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))>0; p6[ck2]=[255,255,0]
278
+ if hyph_hybrid.sum()>0:
279
+ hy2=cv2.dilate(hyph_hybrid.astype(np.uint8),cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))>0; p6[hy2]=[0,255,255]
280
+ if dish_info: cv2.circle(p6,(dcx,dcy),dr,(0,255,0),2)
281
+ return [(Image.fromarray(p1),f"{fname} β€” Raw+Dish"),(Image.fromarray(p2),f"{fname} β€” Mask"),
282
+ (Image.fromarray(p3),f"{fname} β€” Colony"),(Image.fromarray(p4),f"{fname} β€” Cracks"),
283
+ (Image.fromarray(p5),f"{fname} β€” Hyphae"),(Image.fromarray(p6),f"{fname} β€” Combined")]
284
+
285
+ def make_growth_charts(results):
286
+ """Generate time-series charts for every morphometric parameter.
287
+
288
+ All spatial metrics are already in mm (or mmΒ²) via the per-image
289
+ px_to_mm calibration computed from dish detection, so images of
290
+ different resolutions are correctly comparable.
291
+ """
292
+ if len(results) < 2:
293
+ return []
294
+ df = pd.DataFrame(results)
295
+ if "error" in df.columns:
296
+ df = df[df["error"].fillna("").astype(str).str.strip() == ""].copy()
297
+ if len(df) < 2:
298
+ return []
299
+
300
+ # Coerce every plottable column to numeric
301
+ numeric_cols = [
302
+ "days_since_start", "area_mm2", "diameter_mm", "perimeter_mm",
303
+ "eccentricity", "edge_roughness", "centre_delta_mm",
304
+ "entropy", "texture_std",
305
+ "crack_area_mm2", "crack_coverage_pct", "crack_count",
306
+ "hyph_frangi_mm", "hyph_meijering_mm", "hyph_hybrid_mm",
307
+ "rgr_per_day", "relative_growth_per_day",
308
+ ]
309
+ for c in numeric_cols:
310
+ if c in df.columns:
311
+ df[c] = pd.to_numeric(df[c], errors="coerce")
312
+
313
+ df = df.sort_values("days_since_start").reset_index(drop=True)
314
+ charts = []
315
+
316
+ # ── Chart definitions ──────────────────────────────────────────────
317
+ # (column, y-label, title, colour, marker, fill_under_curve)
318
+ chart_defs = [
319
+ # Colony geometry
320
+ ("area_mm2", "Area (mmΒ²)", "Colony Area", "#e74c3c", "o", True),
321
+ ("diameter_mm", "Diameter (mm)", "Colony Diameter", "#2980b9", "s", False),
322
+ ("perimeter_mm", "Perimeter (mm)", "Colony Perimeter", "#8e44ad", "^", False),
323
+ # Shape descriptors
324
+ ("eccentricity", "Eccentricity", "Colony Eccentricity", "#e67e22", "D", False),
325
+ ("edge_roughness", "Edge Roughness", "Edge Roughness (P / Ο€d)", "#16a085", "v", False),
326
+ ("centre_delta_mm", "Centre Offset (mm)", "Colony Centre Offset", "#d35400", "p", False),
327
+ # Texture
328
+ ("entropy", "Entropy", "Colony Texture Entropy", "#7f8c8d", "h", False),
329
+ ("texture_std", "Texture Std Dev", "Colony Texture Std Dev", "#2c3e50", "*", False),
330
+ # Cracks
331
+ ("crack_area_mm2", "Crack Area (mmΒ²)", "Crack Area", "#f1c40f", "o", True),
332
+ ("crack_coverage_pct", "Crack Coverage (%)", "Crack Coverage", "#d4ac0d", "s", False),
333
+ ("crack_count", "Crack Count", "Number of Cracks", "#b7950b", "^", False),
334
+ # Hyphae
335
+ ("hyph_frangi_mm", "Length (mm)", "Hyphae Length β€” Frangi", "#1abc9c", "o", False),
336
+ ("hyph_meijering_mm", "Length (mm)", "Hyphae Length β€” Meijering", "#3498db", "s", False),
337
+ ("hyph_hybrid_mm", "Length (mm)", "Hyphae Length β€” Hybrid", "#2ecc71", "D", False),
338
+ # Growth rates (only present from image 2 onward)
339
+ ("rgr_per_day", "RGR (ln mmΒ² / day)", "Relative Growth Rate", "#c0392b", "o", False),
340
+ ("relative_growth_per_day", "Growth (mmΒ² / day)", "Absolute Growth Rate", "#e74c3c", "s", False),
341
+ ]
342
+
343
+ for col, ylabel, title, color, marker, fill in chart_defs:
344
+ if col not in df.columns:
345
+ continue
346
+ valid = df[col].notna()
347
+ # Also drop rows where the value was left as empty string
348
+ valid = valid & (df[col].astype(str).str.strip() != "")
349
+ if valid.sum() < 2:
350
+ continue
351
+ sub = df.loc[valid].copy()
352
+ fig, ax = plt.subplots(figsize=(8, 4))
353
+ ax.plot(sub["days_since_start"], sub[col], f"{marker}-",
354
+ color=color, lw=2, ms=8)
355
+ if fill:
356
+ ax.fill_between(sub["days_since_start"], 0, sub[col],
357
+ alpha=0.15, color=color)
358
+ ax.set(xlabel="Days", ylabel=ylabel, title=title)
359
+ ax.grid(True, alpha=0.3)
360
+ charts.append((fig_to_pil(fig), title))
361
+
362
+ return charts
363
+
364
+ # ═══════════════════════════════════════════════════════════════════════════
365
+ # Gradio UI
366
+ # ═══════════════════════════════════════════════════════════════════════════
367
+ with gr.Blocks(title="Gray Leaf Spot Segmentation", css=CSS) as demo:
368
+ paths_st=gr.State([]); dates_st=gr.State({}); rems_st=gr.State({}); cur_idx=gr.State(-1); results_st=gr.State([])
369
+
370
+ gr.Markdown("# πŸ”¬ Gray Leaf Spot Colony Segmentation\n"
371
+ "Upload β†’ **Run Inference** β†’ instant results | Toggle *Full Pipeline* for morphometrics\n\n"
372
+ "Model: [`rotsl/grayleafspot-segmentation/best_area_w_0.7.pt`]"
373
+ "(https://huggingface.co/rotsl/grayleafspot-segmentation) Β· SmallUNet (area-consistency w=0.7)")
374
+
375
+ with gr.Accordion("πŸ“‚ Step 1 β€” Upload Images", open=True):
376
+ upload = gr.File(label="Drag & drop petri dish images", file_count="multiple",
377
+ file_types=[".jpg",".jpeg",".png",".tif",".tiff",".bmp",".webp"])
378
+ up_st = gr.Markdown("")
379
+
380
+ with gr.Accordion("βš™οΈ Step 2 β€” Settings", open=True):
381
+ with gr.Row():
382
+ threshold_slider = gr.Slider(label="Mask confidence threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
383
+ full_pipeline_cb = gr.Checkbox(label="Full Pipeline (slower: dish, cracks, hyphae, morphometrics)", value=False)
384
+ with gr.Row():
385
+ exp_name = gr.Textbox(label="Experiment Name", placeholder="MagExp01")
386
+ exp_date = gr.Textbox(label="Experiment Date", placeholder="2025-04-01")
387
+ user_name = gr.Textbox(label="User Name", placeholder="Your name")
388
+ plates_count = gr.Number(label="Plates", value=1, minimum=1, maximum=200, precision=0)
389
+
390
+ with gr.Accordion("πŸ–ΌοΈ Step 3 β€” Review & Edit Dates", open=False):
391
+ gr.Markdown("*Click thumbnail β†’ edit date β†’ Save*")
392
+ with gr.Row():
393
+ with gr.Column(scale=2):
394
+ gallery = gr.Gallery(label="Images", columns=4, height=400, object_fit="contain", allow_preview=False, interactive=False)
395
+ with gr.Column(scale=1):
396
+ sel_img=gr.Image(label="Selected",height=200,interactive=False)
397
+ sel_fn=gr.Textbox(label="Filename",interactive=False)
398
+ sel_dt=gr.Textbox(label="Image Date",interactive=True)
399
+ sel_dc=gr.Textbox(label="Day Code",interactive=False)
400
+ sel_rm=gr.Textbox(label="Remind Me",placeholder="YYYY-MM-DD HH:MM",interactive=True)
401
+ sv_btn=gr.Button("πŸ’Ύ Save Date",variant="primary"); sv_st=gr.Markdown("")
402
+
403
+ with gr.Accordion("πŸ“₯ Step 4 β€” Export Metadata", open=False):
404
+ exp_btn=gr.Button("πŸ“₯ Export CSV / JSON / ICS",variant="primary"); exp_st=gr.Markdown("")
405
+ meta_preview=gr.Dataframe(label="image_metadata.csv",interactive=False,wrap=True)
406
+ meta_dl=gr.File(label="⬇️ Download metadata zip",interactive=False)
407
+
408
+ with gr.Accordion("πŸš€ Step 5 β€” Run Inference", open=True):
409
+ run_btn=gr.Button("πŸš€ Run Inference",variant="primary",size="lg"); run_st=gr.Markdown("")
410
+ gr.Markdown("### Results")
411
+ overlay_gallery=gr.Gallery(label="Segmentation results",columns=3,height=500,object_fit="contain",allow_preview=True)
412
+ gr.Markdown("### Growth Charts (full pipeline, β‰₯2 images)")
413
+ chart_gallery=gr.Gallery(label="Growth curves",columns=3,height=400,object_fit="contain",allow_preview=True)
414
+ gr.Markdown("### Results Table (full pipeline)")
415
+ results_df=gr.Dataframe(label="analysis_full.csv",interactive=False,wrap=True)
416
+ results_dl=gr.File(label="⬇️ Download analysis zip",interactive=False)
417
+
418
+ # ── Footer ──
419
+ gr.Markdown("---\nDeveloped by [Rohan R](https://rotsl.github.io/)")
420
+
421
+ # ── Handlers ──
422
+ def on_upload(files):
423
+ if not files: return [],[],{},[],"",-1
424
+ paths=[str(f) for f in files if Path(str(f)).suffix.lower() in IMAGE_EXTS][:MAX_IMAGES]
425
+ if not paths: return [],[],{},[],"",-1
426
+ dates={p:detect_image_date(p) for p in paths}; rems={p:"" for p in paths}
427
+ return paths,dates,rems,[(p,Path(p).name) for p in paths],f"βœ… **{len(paths)}** images loaded.",-1
428
+ upload.upload(on_upload,[upload],[paths_st,dates_st,rems_st,gallery,up_st,cur_idx])
429
+
430
+ def on_sel(paths,dates,rems,ed,evt:gr.SelectData):
431
+ i=evt.index
432
+ if i<0 or i>=len(paths): return -1,None,"","","",""
433
+ p=paths[i]; return i,make_thumbnail(p),Path(p).name,dates.get(p,""),day_code(dates.get(p,""),ed) if ed else "",rems.get(p,"")
434
+ gallery.select(on_sel,[paths_st,dates_st,rems_st,exp_date],[cur_idx,sel_img,sel_fn,sel_dt,sel_dc,sel_rm])
435
+
436
+ def on_save(paths,dates,rems,i,nd,nr,ed):
437
+ if i<0 or i>=len(paths): return dates,rems,"","⚠️ Select image."
438
+ p=paths[i]; dates=dict(dates); rems=dict(rems); dates[p]=nd; rems[p]=nr
439
+ return dates,rems,day_code(nd,ed) if ed else "",f"βœ… **{Path(p).name}** β†’ {nd}"
440
+ sv_btn.click(on_save,[paths_st,dates_st,rems_st,cur_idx,sel_dt,sel_rm,exp_date],[dates_st,rems_st,sel_dc,sv_st])
441
+
442
+ def on_export(paths,dates,rems,en,ed,un,pc):
443
+ if not paths: return "⚠️ Upload first.",None,None
444
+ tmp=tempfile.mkdtemp(); rows=[]; rl=[]
445
+ for p in paths:
446
+ imd=dates.get(p,detect_image_date(p)); rm=rems.get(p,"")
447
+ row=dict(image_path=Path(p).name,experiment_name=en or"",experiment_date=ed or"",
448
+ image_date=imd,day_code=day_code(imd,ed) if ed else"",user_name=un or"",
449
+ plates_count=int(pc) if pc else 1,remind_me=rm)
450
+ rows.append(row)
451
+ if rm.strip(): rl.append({**row})
452
+ cp=Path(tmp)/"image_metadata.csv"
453
+ with open(cp,"w",newline="") as f: w=csv.DictWriter(f,fieldnames=list(rows[0].keys())); w.writeheader(); w.writerows(rows)
454
+ jp=Path(tmp)/"image_metadata.json"
455
+ with open(jp,"w") as f: json.dump(rows,f,indent=2)
456
+ zf=[cp,jp]
457
+ if rl: ip=Path(tmp)/"reminders.ics"; write_ics(rl,str(ip)); zf.append(ip)
458
+ zp=Path(tmp)/"image_metadata.zip"
459
+ with zipfile.ZipFile(zp,"w") as z:
460
+ for f2 in zf: z.write(f2,f2.name)
461
+ return f"βœ… Exported **{len(rows)}** images.",pd.DataFrame(rows),str(zp)
462
+ exp_btn.click(on_export,[paths_st,dates_st,rems_st,exp_name,exp_date,user_name,plates_count],[exp_st,meta_preview,meta_dl])
463
+
464
+ def on_run(paths,dates,en,ed,un,pc,thresh,full_pipeline,progress=gr.Progress()):
465
+ if not paths: return "⚠️ Upload images first.",[],[],None,None,[]
466
+ try: load_model()
467
+ except Exception as e: return f"❌ Model failed: {e}",[],[],None,None,[]
468
+
469
+ results=[]; vis=[]; errors=[]
470
+
471
+ if not full_pipeline:
472
+ for p in progress.tqdm(paths, desc="Segmenting"):
473
+ try:
474
+ img=Image.open(p).convert("RGB")
475
+ overlay,mask=infer_image(img,thresh)
476
+ mask_px=np.sum(np.array(mask)>0)
477
+ vis.append((img,f"{Path(p).name} β€” Raw"))
478
+ vis.append((mask,f"{Path(p).name} β€” Mask"))
479
+ vis.append((overlay,f"{Path(p).name} β€” Overlay"))
480
+ log.info("%s: done (mask_pixels=%d, threshold=%.2f)", Path(p).name, mask_px, thresh)
481
+ except Exception as e:
482
+ log.error("%s: %s",Path(p).name,e); errors.append(f"{Path(p).name}: {e}")
483
+ em=f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
484
+ ok=len(paths)-len(errors)
485
+ return f"βœ… **{ok}/{len(paths)}** segmented (fast mode, threshold={thresh:.2f}).{em}",vis,[],None,None,[]
486
+
487
+ # Full pipeline
488
+ for p in progress.tqdm(paths, desc="Full pipeline"):
489
+ imd=dates.get(p,detect_image_date(p))
490
+ try:
491
+ img_bgr=cv2.imread(str(p))
492
+ if img_bgr is None: raise RuntimeError(f"Cannot read: {p}")
493
+ model=load_model(); img_rgb=cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
494
+ img_resized=cv2.resize(img_rgb,(MODEL_SZ,MODEL_SZ))
495
+ x=torch.from_numpy(img_resized.transpose(2,0,1)).float()/255.0; x=x.unsqueeze(0)
496
+ with torch.no_grad(): prob=model(x)[0,0].detach().cpu().numpy()
497
+ mask_small=(prob>thresh).astype(np.uint8)*255; h,w=img_bgr.shape[:2]
498
+ colony_mask=cv2.resize(mask_small,(w,h),interpolation=cv2.INTER_NEAREST)>0
499
+ dish_info=detect_dish(img_bgr)
500
+ gray=cv2.cvtColor(img_bgr,cv2.COLOR_BGR2GRAY).astype(np.float64)/255.0
501
+ crack_mask=detect_cracks(gray,colony_mask); hyph_f,hyph_m,hyph_h=detect_hyphae(gray,colony_mask)
502
+ if dish_info: dcx,dcy,dr,px2mm=dish_info
503
+ else: dcx,dcy=w//2,h//2; dr=min(h,w)//2; px2mm=1.0
504
+ metrics=compute_metrics(colony_mask,gray,px2mm,dcx,dcy,crack_mask,hyph_f,hyph_m,hyph_h)
505
+ metrics.update(colony_pixels=int(colony_mask.sum()),dish_detected=dish_info is not None,
506
+ dish_radius_px=dr,px_to_mm=round(px2mm,6),
507
+ calibration_diameter_mm=round(2*dr*px2mm,4),
508
+ calibration_error_pct=round(abs(2*dr*px2mm-90)/90*100,4) if dish_info else 0,
509
+ image_path=Path(p).name,experiment_name=en or"",experiment_date=ed or"",
510
+ image_date=imd,day_code=day_code(imd,ed) if ed else"",
511
+ user_name=un or"",plates_count=int(pc) if pc else 1)
512
+ results.append(metrics)
513
+ panels=create_full_overlays(img_bgr,colony_mask,crack_mask,hyph_h,dish_info,Path(p).name)
514
+ vis.extend(panels); log.info("%s: area=%.1f mmΒ²",Path(p).name,metrics["area_mm2"])
515
+ except Exception as e:
516
+ log.error("%s: %s",Path(p).name,e); errors.append(f"{Path(p).name}: {e}")
517
+ results.append({"image_path":Path(p).name,"error":str(e)})
518
+
519
+ ok_results=[r for r in results if not r.get("error")]
520
+ if len(ok_results)>1:
521
+ ok_results.sort(key=lambda r:r.get("image_date",""))
522
+ try: base=dt.date.fromisoformat(ok_results[0].get("image_date",""))
523
+ except: base=None
524
+ for i,r in enumerate(ok_results):
525
+ try: r["days_since_start"]=(dt.date.fromisoformat(r.get("image_date",""))-base).days if base else 0
526
+ except: r["days_since_start"]=0
527
+ if i==0: r["rgr_per_day"]=""; r["relative_growth_per_day"]=""; continue
528
+ prev=ok_results[i-1]
529
+ try:
530
+ dd=(dt.date.fromisoformat(r["image_date"])-dt.date.fromisoformat(prev["image_date"])).days
531
+ a2,a1=float(r.get("area_mm2",0)),float(prev.get("area_mm2",0))
532
+ if dd>0 and a1>0 and a2>0:
533
+ r["rgr_per_day"]=round((math.log(a2)-math.log(a1))/dd,6)
534
+ r["relative_growth_per_day"]=round((a2-a1)/dd,4)
535
+ else: r["rgr_per_day"]=""; r["relative_growth_per_day"]=""
536
+ except: r["rgr_per_day"]=""; r["relative_growth_per_day"]=""
537
+ chart_items=make_growth_charts(ok_results) if len(ok_results)>=2 else []
538
+ tmp=tempfile.mkdtemp(); all_results=ok_results+[r for r in results if r.get("error")]
539
+ cp=Path(tmp)/"analysis_full.csv"
540
+ if all_results:
541
+ ks=list(all_results[0].keys())
542
+ with open(cp,"w",newline="") as f: w=csv.DictWriter(f,fieldnames=ks,extrasaction="ignore"); w.writeheader(); w.writerows(all_results)
543
+ jp=Path(tmp)/"analysis_full.json"
544
+ with open(jp,"w") as f: json.dump(all_results,f,indent=2,default=str)
545
+ for i,(cimg,cap) in enumerate(chart_items): cimg.save(str(Path(tmp)/f"chart_{i}.png"))
546
+ zp=Path(tmp)/"analysis_full.zip"
547
+ with zipfile.ZipFile(zp,"w") as z:
548
+ for fp in Path(tmp).glob("*"):
549
+ if fp.name!="analysis_full.zip": z.write(fp,fp.name)
550
+ em=f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
551
+ cm=f"\n\nπŸ“Š **{len(chart_items)} charts**" if chart_items else ""
552
+ return (f"βœ… **{len(ok_results)}/{len(results)}** analyzed.{cm}{em}",
553
+ vis,chart_items,pd.DataFrame(all_results),str(zp),all_results)
554
+
555
+ run_btn.click(on_run,[paths_st,dates_st,exp_name,exp_date,user_name,plates_count,threshold_slider,full_pipeline_cb],
556
+ [run_st,overlay_gallery,chart_gallery,results_df,results_dl,results_st])
557
+
558
+ if __name__=="__main__":
559
+ demo.launch(server_name="0.0.0.0",server_port=7860)