Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| app.py — Gray Leaf Spot Colony Segmentation Pipeline | |
| """ | |
| import csv, json, math, os, re, logging, tempfile, zipfile, io | |
| import datetime as dt | |
| from pathlib import Path | |
| import cv2 | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image, ExifTags | |
| from huggingface_hub import hf_hub_download | |
# Logging: timestamped INFO-level messages through a single module-wide logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("app")

# Upload handling.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
THUMB_SIZE = (160, 160)
# YYYYMMDD with a 20xx year, month 01-12 and day 01-31.
DATE_RE = re.compile(r"(20\d{2})(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])")
MAX_IMAGES = 50

# Segmentation model checkpoint on the Hugging Face Hub.
MODEL_REPO = "rotsl/grayleafspot-segmentation"
MODEL_FILE = "best_area_w_0.7.pt"

# Dish diameter in millimetres, used for the pixel-to-mm calibration.
DISH_MM = 90.0
# Side length (pixels) of the square input fed to the network.
MODEL_SZ = 256
# Optional Hub access token; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Extra stylesheet passed to gr.Blocks.
CSS = (
    ".gallery-wrap{max-height:65vh;overflow-y:auto} "
    ".footer-row{display:flex;justify-content:center;width:100%} "
    ".footer-row p, .footer-row a{text-align:center}"
)
| # ── SmallUNet — exact architecture from model_small_unet.py ── | |
class ConvBlock(nn.Module):
    """Two bias-free 3x3 convolutions, each followed by an in-place ReLU.

    Padding of 1 keeps the spatial size; only the channel count changes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        layers = []
        for fan_in in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(fan_in, out_channels, kernel_size=3, padding=1, bias=False)
            )
            layers.append(nn.ReLU(inplace=True))
        # The Sequential indices (0..3) are part of the state-dict key names.
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
class DownBlock(nn.Module):
    """Encoder stage: 2x2 max-pool with stride 2, then a ConvBlock."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(x)
        return self.conv(pooled)
class UpBlock(nn.Module):
    """Decoder stage: bilinear 2x upsample, concatenate the skip, then ConvBlock.

    ``in_channels`` is the channel count *after* concatenation
    (skip channels + upsampled channels).
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x)
        target_hw = skip.shape[-2:]
        if upsampled.shape[-2:] != target_hw:
            # Doubling can miss the skip's size (e.g. odd dimensions); resample to match.
            upsampled = F.interpolate(
                upsampled, size=target_hw, mode="bilinear", align_corners=False
            )
        merged = torch.cat([skip, upsampled], dim=1)
        return self.conv(merged)
class SmallUNet(nn.Module):
    """Four-level U-Net whose channel width doubles at each encoder stage.

    With ``base_channels = c`` the encoder widths are c, 2c, 4c, 8c and the
    bottleneck is 16c. The 1x1 head is followed by a sigmoid, so ``forward``
    returns values in [0, 1].
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=16):
        super().__init__()
        c1, c2, c3, c4, bottleneck = (base_channels * 2 ** level for level in range(5))

        # Encoder path.
        self.enc1 = ConvBlock(in_channels, c1)
        self.enc2 = DownBlock(c1, c2)
        self.enc3 = DownBlock(c2, c3)
        self.enc4 = DownBlock(c3, c4)
        self.bottleneck = DownBlock(c4, bottleneck)

        # Decoder path; each stage consumes (previous output + matching skip) channels.
        self.up4 = UpBlock(bottleneck + c4, c4)
        self.up3 = UpBlock(c4 + c3, c3)
        self.up2 = UpBlock(c3 + c2, c2)
        self.up1 = UpBlock(c2 + c1, c1)

        self.head = nn.Conv2d(c1, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        skip4 = self.enc4(skip3)

        out = self.bottleneck(skip4)
        out = self.up4(out, skip4)
        out = self.up3(out, skip3)
        out = self.up2(out, skip2)
        out = self.up1(out, skip1)

        return self.activation(self.head(out))
| # ── Model loading ── | |
_model = None


def load_model():
    """Return the process-wide SmallUNet, downloading and loading it on first use.

    The checkpoint is fetched from ``MODEL_REPO``/``MODEL_FILE`` and may be
    either a bare state dict or a dict holding one under
    ``"model_state_dict"``. The module-level cache is only populated after the
    weights have loaded successfully, so a failed attempt is retried on the
    next call instead of returning an untrained network.

    Raises:
        Whatever ``hf_hub_download``, ``torch.load`` or ``load_state_dict``
        raise (download failure, unreadable file, key mismatch).
    """
    global _model
    if _model is not None:
        return _model

    checkpoint_path = hf_hub_download(MODEL_REPO, MODEL_FILE, token=HF_TOKEN)
    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # downloaded file. Switch to weights_only=True once the checkpoint is
    # confirmed to contain only tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Build into a local first: publishing the global before the weights are
    # in place would cache a randomly initialised model if loading fails.
    model = SmallUNet(in_channels=3, out_channels=1, base_channels=16)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    _model = model
    log.info("Model loaded: SmallUNet (%s)", MODEL_FILE)
    return _model
| # ── Core inference ── | |
def infer_image(img_pil, threshold):
    """Segment one PIL image and return ``(overlay, mask)`` as PIL images.

    The image is resized to ``MODEL_SZ`` x ``MODEL_SZ``, scaled to [0, 1] and
    passed through the model. The probability map is thresholded, resized
    back to the input size with nearest-neighbour interpolation, and blended
    50/50 with red on a copy of the RGB input to form the overlay.
    """
    model = load_model()
    rgb = np.array(img_pil.convert("RGB"))
    net_input = cv2.resize(rgb, (MODEL_SZ, MODEL_SZ))

    # HWC uint8 -> 1xCxHxW float in [0, 1].
    batch = (torch.from_numpy(net_input.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0)
    with torch.no_grad():
        prob = model(batch)[0, 0].detach().cpu().numpy()

    log.info(" output range: [%.4f, %.4f] mean=%.4f >0.5:%d >0.3:%d >0.1:%d",
             prob.min(), prob.max(), prob.mean(),
             (prob > 0.5).sum(), (prob > 0.3).sum(), (prob > 0.1).sum())

    small_mask = (prob > threshold).astype(np.uint8) * 255
    full_size = (img_pil.width, img_pil.height)
    mask = cv2.resize(small_mask, full_size, interpolation=cv2.INTER_NEAREST)

    overlay = rgb.copy()
    hit = mask > 0
    red = np.array([255, 0, 0])
    overlay[hit] = (overlay[hit] * 0.5 + red * 0.5).astype(np.uint8)

    return Image.fromarray(overlay), Image.fromarray(mask)
| # ── Helpers ── | |
def make_thumbnail(p):
    """Return a thumbnail of the image at ``p`` that fits within ``THUMB_SIZE``.

    Best effort: if the file cannot be opened or decoded, a plain grey
    ``THUMB_SIZE`` placeholder is returned instead of raising.
    """
    try:
        # The context manager releases the file handle; the copy keeps the
        # decoded pixels alive after the source image is closed.
        with Image.open(p) as src:
            src.thumbnail(THUMB_SIZE, Image.LANCZOS)
            return src.copy()
    except Exception as exc:  # unreadable/corrupt image -> placeholder
        log.warning("Could not create thumbnail for %s: %s", p, exc)
        return Image.new("RGB", THUMB_SIZE, (200, 200, 200))
def _parse_exif_datetime(value):
    """Return the ISO date of an EXIF ``YYYY:MM:DD HH:MM:SS`` string, else None."""
    if not (value and isinstance(value, str) and len(value) >= 10):
        return None
    try:
        return dt.datetime.strptime(value.strip(), "%Y:%m:%d %H:%M:%S").date().isoformat()
    except ValueError:
        return None


def detect_image_date(p):
    """Detect image date from embedded EXIF metadata, then filename, then mtime.

    Priority:
      1. EXIF DateTimeOriginal  (Exif IFD 0x9003) — when the photo was taken
      2. EXIF DateTimeDigitized (Exif IFD 0x9004) — when it was digitised
      3. EXIF DateTime          (IFD0 0x0132)     — last modified in-camera
      4. Filename regex (first valid YYYYMMDD pattern in the file stem)
      5. File modification time (unreliable after copy / Gradio upload)

    Returns:
        An ISO ``YYYY-MM-DD`` string. Falls back to today's date when none of
        the sources yields a usable value.
    """
    # ── 1-3. EXIF metadata (embedded in file bytes, survives Gradio copy) ──
    try:
        with Image.open(p) as im:
            exif = im.getexif()
            if exif:
                # DateTimeOriginal & DateTimeDigitized live in the Exif sub-IFD.
                try:
                    exif_ifd = exif.get_ifd(0x8769)
                except Exception:
                    exif_ifd = {}
                candidates = (
                    exif_ifd.get(0x9003),  # DateTimeOriginal
                    exif_ifd.get(0x9004),  # DateTimeDigitized
                    exif.get(0x0132),      # DateTime (root IFD)
                )
                for raw in candidates:
                    iso = _parse_exif_datetime(raw)
                    if iso:
                        return iso
    except Exception:
        # Deliberately best effort: a non-image or damaged file simply has no
        # EXIF date and the remaining sources are tried.
        pass

    # ── 4. Filename regex ──
    for m in DATE_RE.finditer(Path(p).stem):
        try:
            return dt.date(int(m[1]), int(m[2]), int(m[3])).isoformat()
        except ValueError:
            continue  # e.g. 20250231 — not a real calendar date

    # ── 5. File mtime (last resort) ──
    try:
        return dt.date.fromtimestamp(os.path.getmtime(p)).isoformat()
    except (OSError, OverflowError, ValueError):
        return dt.date.today().isoformat()
def day_code(img_d, exp_d):
    """Day code = (image_date − experiment_date) + 1, formatted as d01, d03, etc.

    Both arguments are ISO ``YYYY-MM-DD`` strings; surrounding whitespace is
    ignored. Images dated on or before the experiment date are clamped to
    ``d01``. Returns ``"d??"`` when either date is missing or unparsable.
    """
    try:
        image_day = dt.date.fromisoformat(img_d.strip() if isinstance(img_d, str) else img_d)
        start_day = dt.date.fromisoformat(exp_d.strip() if isinstance(exp_d, str) else exp_d)
    except (TypeError, ValueError):
        return "d??"
    elapsed = (image_day - start_day).days + 1
    return f"d{max(elapsed, 1):02d}"
def write_ics(rems, path):
    """Write reminder events to ``path`` as an iCalendar (.ics) file.

    Args:
        rems: iterable of dicts with keys ``image_path``, ``remind_me``
            (``YYYY-MM-DD HH:MM``; seconds or a ``T`` separator are also
            accepted) and ``experiment_name``.
        path: destination file path.

    Entries whose ``remind_me`` cannot be parsed are skipped with a warning
    rather than emitted with a malformed DTSTART. Lines are CRLF-terminated
    and every event carries a DTSTAMP.
    """
    def _escape(text):
        # Escape backslash, ';', ',' and line breaks in a TEXT value.
        return (str(text)
                .replace("\\", "\\\\")
                .replace(";", "\\;")
                .replace(",", "\\,")
                .replace("\r\n", "\\n")
                .replace("\n", "\\n"))

    accepted_formats = ("%Y-%m-%d %H:%M", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M")
    stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ")

    lines = ["BEGIN:VCALENDAR", "VERSION:2.0", "PRODID:-//FungalPipeline//EN"]
    for rem in rems:
        raw = str(rem["remind_me"]).strip()
        start = None
        for fmt in accepted_formats:
            try:
                start = dt.datetime.strptime(raw, fmt)
                break
            except ValueError:
                continue
        if start is None:
            logging.getLogger("app").warning(
                "Skipping reminder for %s: cannot parse %r", rem["image_path"], raw)
            continue

        uid = rem["image_path"].replace("/", "_")
        summary = f"Reminder - {rem['experiment_name']}: {Path(rem['image_path']).name}"
        lines += [
            "BEGIN:VEVENT",
            f"UID:{uid}@fp",
            f"DTSTAMP:{stamp}",
            f"DTSTART:{start.strftime('%Y%m%dT%H%M%S')}",
            f"SUMMARY:{_escape(summary)}",
            "END:VEVENT",
        ]
    lines.append("END:VCALENDAR")

    # newline="" keeps the explicit CRLFs intact on every platform.
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write("\r\n".join(lines) + "\r\n")
def fig_to_pil(fig):
    """Render a Matplotlib figure to a PIL image and close the figure.

    The figure is saved as a 120-dpi PNG on a white background with tight
    bounding box. It is closed even when rendering fails, so a bad chart does
    not leave an open figure behind.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="white")
            buf.seek(0)
            # copy() forces the decode while the buffer is still open.
            return Image.open(buf).copy()
    finally:
        plt.close(fig)
| # ── Full pipeline helpers (lazy skimage) ── | |
def _load_skimage():
    """Import scikit-image on demand and return the pieces the pipeline uses.

    Returned tuple, in order:
    ``(filters, measure, morphology, frangi, meijering, skeletonize,
    disk, opening, closing, erosion, dilation)``.
    Callers unpack positionally, so the order must not change.
    """
    from skimage import filters, measure, morphology
    from skimage.filters import frangi, meijering
    from skimage.morphology import (
        closing,
        dilation,
        disk,
        erosion,
        opening,
        skeletonize,
    )

    return (
        filters, measure, morphology,
        frangi, meijering,
        skeletonize, disk, opening, closing, erosion, dilation,
    )
def detect_dish(img_bgr):
    """Locate the petri dish with a Hough circle transform.

    Candidate radii are limited to 25–52 % of the shorter image side. Among
    the detected circles the one with the best score wins, where the score
    favours large radii and penalises distance from the image centre.

    Returns:
        ``(cx, cy, r, mm_per_px)`` with integer pixel coordinates and
        ``mm_per_px = DISH_MM / (2 * r)``, or ``None`` when no circle is found
        or detection fails.
    """
    try:
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (9, 9), 2)
        h, w = gray.shape
        short_side = min(h, w)

        circles = cv2.HoughCircles(
            blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=short_side // 2,
            param1=100, param2=40,
            minRadius=int(short_side * 0.25), maxRadius=int(short_side * 0.52),
        )
        if circles is None:
            return None

        candidates = np.round(circles[0]).astype(int)
        centre_x, centre_y = w / 2, h / 2

        def score(circle):
            cx, cy, r = circle
            return r / (1 + math.hypot(cx - centre_x, cy - centre_y) / 100)

        # max() keeps the first candidate on ties, like a strict ">" scan.
        cx, cy, r = (int(v) for v in max(candidates, key=score))
        if r <= 0:
            return None
        return cx, cy, r, DISH_MM / (2 * r)
    except Exception as exc:
        # Best effort: callers fall back to an uncalibrated dish when None.
        log.warning("Dish detection failed: %s", exc)
        return None
def detect_cracks(gray, colony_mask):
    """Return a boolean mask of elongated dark regions inside the colony.

    Pixels darker than their local (51-px Gaussian) threshold minus 15 grey
    levels are kept when they lie inside the colony eroded by a radius-5
    disk. After a radius-1 opening, connected regions of at least 10 pixels
    are accepted when their major/minor axis ratio exceeds 2.5 or their
    eccentricity exceeds 0.85.

    Args:
        gray: greyscale image; values <= 1 are treated as a 0–1 float image
            and rescaled to 0–255.
        colony_mask: boolean colony mask of the same shape.

    Returns:
        Boolean array shaped like ``colony_mask``; all False when the colony
        has fewer than 100 pixels.
    """
    filters, measure, _, _, _, _, disk, opening, _, erosion, _ = _load_skimage()
    if colony_mask.sum() < 100:
        return np.zeros_like(colony_mask, dtype=bool)

    interior = gray.copy()
    interior[~colony_mask] = 0
    eroded = erosion(colony_mask, disk(5))
    interior_u8 = (interior * 255 if interior.max() <= 1 else interior).astype(np.uint8)

    local_thr = filters.threshold_local(interior_u8, block_size=51, method="gaussian")
    dark = (interior_u8 < (local_thr - 15)) & eroded
    dark = opening(dark, disk(1))
    labels = measure.label(dark)

    crack_labels = []
    for region in measure.regionprops(labels):
        if region.area < 10:
            continue
        major, minor = region.major_axis_length, region.minor_axis_length
        if major > 0 and minor > 0 and (major / minor > 2.5 or region.eccentricity > 0.85):
            crack_labels.append(region.label)

    # One pass over the label image instead of one full-image comparison per
    # accepted region; an empty label list yields an all-False mask.
    return np.isin(labels, crack_labels)
def detect_hyphae(gray, colony_mask):
    """Skeletonise bright ridge structures in and around the colony.

    Frangi and Meijering ridge filters (sigmas 1–4, bright ridges) are run on
    the whole image, zeroed outside the colony dilated by a radius-20 disk,
    thresholded at mean + 2·std of the response inside that region, and
    skeletonised.

    Returns:
        ``(frangi_skeleton, meijering_skeleton, union)`` as boolean arrays;
        three all-False arrays when the colony has fewer than 100 pixels.
    """
    _, _, _, frangi, meijering, skeletonize, disk, _, _, _, dilation = _load_skimage()

    if colony_mask.sum() < 100:
        empty = np.zeros_like(colony_mask, dtype=bool)
        return empty, empty.copy(), empty.copy()

    image = gray.astype(np.float64)
    search = dilation(colony_mask, disk(20))
    has_search_area = search.sum() > 0

    def ridge_skeleton(ridge_filter):
        response = ridge_filter(image, sigmas=range(1, 5), black_ridges=False)
        response[~search] = 0
        if has_search_area:
            cutoff = response[search].mean() + 2 * response[search].std()
        else:
            cutoff = .01
        return skeletonize(response > cutoff)

    frangi_skel = ridge_skeleton(frangi)
    meijering_skel = ridge_skeleton(meijering)
    return frangi_skel, meijering_skel, frangi_skel | meijering_skel
def compute_metrics(mask_bool, gray, px2mm, dcx, dcy, crack_mask, hyph_f, hyph_m, hyph_h):
    """Compute colony morphometrics, texture, crack and hyphae statistics.

    Args:
        mask_bool: boolean colony mask; treated as a single region.
        gray: greyscale image (values <= 1 are rescaled to 0–255 for entropy).
        px2mm: millimetres per pixel.
        dcx, dcy: dish centre in pixels, used for the centre offset.
        crack_mask: boolean crack mask.
        hyph_f, hyph_m, hyph_h: Frangi, Meijering and combined skeletons.

    Returns:
        Dict of metrics. Every value is 0 when the mask has fewer than 50
        pixels. Hyphae lengths are skeleton pixel counts times ``px2mm``.
    """
    filters, measure, morphology, *_ = _load_skimage()
    mm2 = px2mm ** 2

    if mask_bool.sum() < 50:
        return dict.fromkeys(
            ["area_mm2", "diameter_mm", "perimeter_mm", "eccentricity", "edge_roughness",
             "centre_delta_mm", "entropy", "texture_std", "crack_px", "crack_area_mm2",
             "crack_coverage_pct", "crack_count", "hyph_frangi_mm", "hyph_meijering_mm",
             "hyph_hybrid_mm"],
            0,
        )

    region = measure.regionprops(mask_bool.astype(np.uint8))[0]
    perimeter_px = measure.perimeter(mask_bool)
    eq_diameter_px = region.equivalent_diameter_area
    # Perimeter of the circle with the same area as the colony.
    circle_perimeter_px = math.pi * eq_diameter_px
    centroid_y, centroid_x = region.centroid

    if gray.max() <= 1:
        gray_u8 = (gray * 255).astype(np.uint8)
    else:
        gray_u8 = gray.astype(np.uint8)

    if region.area > 100:
        local_entropy = filters.rank.entropy(gray_u8, morphology.disk(5), mask=mask_bool)
        entropy = round(float(local_entropy[mask_bool].mean()), 6)
    else:
        entropy = 0

    # Key order is significant: it becomes the CSV column order downstream.
    return {
        "area_mm2": round(region.area * mm2, 4),
        "perimeter_mm": round(perimeter_px * px2mm, 4),
        "diameter_mm": round(eq_diameter_px * px2mm, 4),
        "eccentricity": round(region.eccentricity, 6),
        "edge_roughness": (round(perimeter_px / circle_perimeter_px, 6)
                           if circle_perimeter_px > 0 else 0),
        "centre_delta_mm": round(math.hypot(centroid_x - dcx, centroid_y - dcy) * px2mm, 4),
        "entropy": entropy,
        "texture_std": round(float(gray[mask_bool].std()), 6),
        "crack_px": int(crack_mask.sum()),
        "crack_area_mm2": round(crack_mask.sum() * mm2, 6),
        "crack_coverage_pct": (round(100 * crack_mask.sum() / region.area, 4)
                               if region.area > 0 else 0),
        "crack_count": int(measure.label(crack_mask).max()),
        "hyph_frangi_mm": round(int(hyph_f.sum()) * px2mm, 4),
        "hyph_meijering_mm": round(int(hyph_m.sum()) * px2mm, 4),
        "hyph_hybrid_mm": round(int(hyph_h.sum()) * px2mm, 4),
    }
def _tint(panel, region, colour, keep, add):
    """Blend ``colour`` into ``panel`` in place where ``region`` is True.

    Result per pixel is ``keep * original + add * colour`` (float32, then
    truncated to uint8). Does nothing when ``region`` is empty.
    """
    if region.sum() > 0:
        blended = (panel[region].astype(np.float32) * keep
                   + np.array(colour, dtype=np.float32) * add)
        panel[region] = blended.astype(np.uint8)


def _thicken(mask):
    """Dilate a thin boolean mask with a 3x3 ellipse so it stays visible."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.dilate(mask.astype(np.uint8), kernel) > 0


def create_full_overlays(img_bgr, colony_mask, crack_mask, hyph_hybrid, dish_info, fname):
    """Build the six captioned result panels for one image.

    Panels, in order: Raw+Dish, Mask, Colony, Cracks, Hyphae, Combined.
    The dish circle (green) is drawn only when ``dish_info`` is available.
    The colony outline (red) is drawn on the Raw+Dish, Cracks and Hyphae
    panels regardless of dish detection.

    Args:
        img_bgr: source image in BGR channel order.
        colony_mask, crack_mask, hyph_hybrid: boolean masks shaped like the image.
        dish_info: ``(cx, cy, r, mm_per_px)`` or ``None``.
        fname: file name used as caption prefix.

    Returns:
        List of ``(PIL.Image, caption)`` tuples.
    """
    RED, GREEN, YELLOW, CYAN = (255, 0, 0), (0, 255, 0), (255, 255, 0), (0, 255, 255)

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    h, w = img_bgr.shape[:2]
    if dish_info:
        dcx, dcy, dr = dish_info[0], dish_info[1], dish_info[2]
    else:
        dcx, dcy, dr = w // 2, h // 2, min(h, w) // 2

    contours, _ = cv2.findContours(
        colony_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    thick_cracks = _thicken(crack_mask)
    thick_hyphae = _thicken(hyph_hybrid)

    def draw_dish(panel, thickness):
        if dish_info:
            cv2.circle(panel, (dcx, dcy), dr, GREEN, thickness)

    # 1. Raw image with dish circle and colony outline.
    raw_panel = img_rgb.copy()
    draw_dish(raw_panel, 3)
    cv2.drawContours(raw_panel, contours, -1, RED, 2)

    # 2. Binary colony mask.
    mask_panel = np.zeros_like(img_rgb)
    mask_panel[colony_mask] = [255, 255, 255]

    # 3. Colony tinted red.
    colony_panel = img_rgb.copy()
    _tint(colony_panel, colony_mask, RED, 0.5, 0.5)
    draw_dish(colony_panel, 2)

    # 4. Cracks tinted yellow, with colony outline.
    crack_panel = img_rgb.copy()
    _tint(crack_panel, thick_cracks, YELLOW, 0.3, 0.7)
    draw_dish(crack_panel, 2)
    cv2.drawContours(crack_panel, contours, -1, RED, 1)

    # 5. Hyphae tinted cyan, with colony outline.
    hyphae_panel = img_rgb.copy()
    _tint(hyphae_panel, thick_hyphae, CYAN, 0.3, 0.7)
    draw_dish(hyphae_panel, 2)
    cv2.drawContours(hyphae_panel, contours, -1, RED, 1)

    # 6. Everything: colony tint, then solid cracks and hyphae on top.
    combined_panel = img_rgb.copy()
    _tint(combined_panel, colony_mask, RED, 0.6, 0.4)
    combined_panel[thick_cracks] = YELLOW
    combined_panel[thick_hyphae] = CYAN
    draw_dish(combined_panel, 2)

    return [
        (Image.fromarray(raw_panel), f"{fname} — Raw+Dish"),
        (Image.fromarray(mask_panel), f"{fname} — Mask"),
        (Image.fromarray(colony_panel), f"{fname} — Colony"),
        (Image.fromarray(crack_panel), f"{fname} — Cracks"),
        (Image.fromarray(hyphae_panel), f"{fname} — Hyphae"),
        (Image.fromarray(combined_panel), f"{fname} — Combined"),
    ]
def make_growth_charts(results):
    """Plot each available metric against ``days_since_start``.

    Rows with a non-empty ``error`` field are dropped. A chart is produced
    for every known metric column that has at least two non-missing numeric
    values.

    Returns:
        List of ``(PIL.Image, title)`` tuples; empty when fewer than two
        usable rows exist.
    """
    if len(results) < 2:
        return []

    df = pd.DataFrame(results)
    if "error" in df.columns:
        no_error = df["error"].fillna("").astype(str).str.strip() == ""
        df = df[no_error].copy()
    if len(df) < 2:
        return []

    numeric_cols = [
        "days_since_start", "area_mm2", "diameter_mm", "perimeter_mm", "eccentricity",
        "edge_roughness", "centre_delta_mm", "entropy", "texture_std", "crack_area_mm2",
        "crack_coverage_pct", "crack_count", "hyph_frangi_mm", "hyph_meijering_mm",
        "hyph_hybrid_mm", "rgr_per_day", "relative_growth_per_day",
    ]
    for name in numeric_cols:
        if name in df.columns:
            # Blank placeholders (e.g. "" growth rates) become NaN.
            df[name] = pd.to_numeric(df[name], errors="coerce")
    df = df.sort_values("days_since_start").reset_index(drop=True)

    # (column, y-label, title, colour, marker, fill area under curve)
    chart_defs = [
        ("area_mm2", "Area (mm²)", "Colony Area", "#e74c3c", "o", True),
        ("diameter_mm", "Diameter (mm)", "Colony Diameter", "#2980b9", "s", False),
        ("perimeter_mm", "Perimeter (mm)", "Colony Perimeter", "#8e44ad", "^", False),
        ("eccentricity", "Eccentricity", "Colony Eccentricity", "#e67e22", "D", False),
        ("edge_roughness", "Edge Roughness", "Edge Roughness (P / πd)", "#16a085", "v", False),
        ("centre_delta_mm", "Centre Offset (mm)", "Colony Centre Offset", "#d35400", "p", False),
        ("entropy", "Entropy", "Colony Texture Entropy", "#7f8c8d", "h", False),
        ("texture_std", "Texture Std Dev", "Colony Texture Std Dev", "#2c3e50", "*", False),
        ("crack_area_mm2", "Crack Area (mm²)", "Crack Area", "#f1c40f", "o", True),
        ("crack_coverage_pct", "Crack Coverage (%)", "Crack Coverage", "#d4ac0d", "s", False),
        ("crack_count", "Crack Count", "Number of Cracks", "#b7950b", "^", False),
        ("hyph_frangi_mm", "Length (mm)", "Hyphae Length — Frangi", "#1abc9c", "o", False),
        ("hyph_meijering_mm", "Length (mm)", "Hyphae Length — Meijering", "#3498db", "s", False),
        ("hyph_hybrid_mm", "Length (mm)", "Hyphae Length — Hybrid", "#2ecc71", "D", False),
        ("rgr_per_day", "RGR (ln mm² / day)", "Relative Growth Rate", "#c0392b", "o", False),
        ("relative_growth_per_day", "Growth (mm² / day)", "Absolute Growth Rate", "#e74c3c", "s", False),
    ]

    charts = []
    for col, ylabel, title, color, marker, fill in chart_defs:
        if col not in df.columns:
            continue
        usable = df[col].notna() & (df[col].astype(str).str.strip() != "")
        if usable.sum() < 2:
            continue
        points = df.loc[usable].copy()
        days = points["days_since_start"]

        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(days, points[col], f"{marker}-", color=color, lw=2, ms=8)
        if fill:
            ax.fill_between(days, 0, points[col], alpha=0.15, color=color)
        ax.set(xlabel="Day Code", ylabel=ylabel, title=title)
        ax.grid(True, alpha=0.3)
        charts.append((fig_to_pil(fig), title))
    return charts
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Gradio UI | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
# ── Handlers ──
# The callbacks only use their arguments and module-level helpers, so they are
# defined ahead of the layout; the event wiring stays inside the Blocks context.

def on_upload(files):
    """Register uploaded images and detect a date for each.

    Non-image extensions are ignored and at most ``MAX_IMAGES`` files are kept.

    Returns:
        ``(paths, dates, reminders, gallery_items, status_markdown, selected_index)``.
    """
    nothing = ([], {}, {}, [], "", -1)  # dates/reminders states are dicts
    if not files:
        return nothing

    # Upload items may be plain path strings or file wrappers exposing `.name`.
    candidates = [str(getattr(f, "name", f)) for f in files]
    images = [p for p in candidates if Path(p).suffix.lower() in IMAGE_EXTS]
    paths = images[:MAX_IMAGES]
    if not paths:
        return nothing

    dates = {p: detect_image_date(p) for p in paths}
    rems = dict.fromkeys(paths, "")
    status = f"✅ **{len(paths)}** images loaded."
    if len(images) > MAX_IMAGES:
        status += f" (first {MAX_IMAGES} of {len(images)} kept)"
    return paths, dates, rems, [(p, Path(p).name) for p in paths], status, -1


def on_sel(paths, dates, rems, ed, evt: gr.SelectData):
    """Fill the detail panel for the thumbnail the user clicked."""
    i = evt.index
    if i < 0 or i >= len(paths):
        return -1, None, "", "", "", ""
    p = paths[i]
    image_date = dates.get(p, "")
    code = day_code(image_date, ed) if ed else ""
    return i, make_thumbnail(p), Path(p).name, image_date, code, rems.get(p, "")


def on_save(paths, dates, rems, i, nd, nr, ed):
    """Store the edited date and reminder for the selected image.

    The date must be ISO ``YYYY-MM-DD``; anything else is rejected with a
    warning and the stored values are left untouched.

    Returns:
        ``(dates, reminders, day_code, status_markdown)``.
    """
    if i < 0 or i >= len(paths):
        return dates, rems, "", "⚠️ Select image."

    nd = (nd or "").strip()
    nr = (nr or "").strip()
    try:
        dt.date.fromisoformat(nd)
    except ValueError:
        return dates, rems, "", f"⚠️ Invalid date **{nd}** — use YYYY-MM-DD."

    p = paths[i]
    dates = dict(dates)
    rems = dict(rems)
    dates[p] = nd
    rems[p] = nr
    return dates, rems, day_code(nd, ed) if ed else "", f"✅ **{Path(p).name}** → {nd}"


def on_export(paths, dates, rems, en, ed, un, pc):
    """Write image metadata as CSV and JSON (plus ICS when reminders exist) and zip them.

    Returns:
        ``(status_markdown, preview_dataframe, zip_path)``.
    """
    if not paths:
        return "⚠️ Upload first.", None, None

    out_dir = Path(tempfile.mkdtemp())
    plates = int(pc) if pc else 1
    rows, reminder_rows = [], []
    for p in paths:
        # Only probe the file when no date is stored for it.
        image_date = dates[p] if p in dates else detect_image_date(p)
        reminder = rems.get(p, "")
        row = dict(image_path=Path(p).name, experiment_name=en or "", experiment_date=ed or "",
                   image_date=image_date, day_code=day_code(image_date, ed) if ed else "",
                   user_name=un or "", plates_count=plates, remind_me=reminder)
        rows.append(row)
        if reminder.strip():
            reminder_rows.append({**row})

    csv_path = out_dir / "image_metadata.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    json_path = out_dir / "image_metadata.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)

    members = [csv_path, json_path]
    if reminder_rows:
        ics_path = out_dir / "reminders.ics"
        write_ics(reminder_rows, str(ics_path))
        members.append(ics_path)

    zip_path = out_dir / "image_metadata.zip"
    with zipfile.ZipFile(zip_path, "w") as z:
        for member in members:
            z.write(member, member.name)
    return f"✅ Exported **{len(rows)}** images.", pd.DataFrame(rows), str(zip_path)


def _predict_colony_mask(img_bgr, thresh):
    """Run the model on a BGR image and return a full-resolution boolean mask."""
    model = load_model()
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(img_rgb, (MODEL_SZ, MODEL_SZ))
    x = (torch.from_numpy(resized.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0)
    with torch.no_grad():
        prob = model(x)[0, 0].detach().cpu().numpy()
    small = (prob > thresh).astype(np.uint8) * 255
    h, w = img_bgr.shape[:2]
    return cv2.resize(small, (w, h), interpolation=cv2.INTER_NEAREST) > 0


def _annotate_growth(rows, ed):
    """Sort rows by image date and add days_since_start and growth-rate columns in place."""
    rows.sort(key=lambda r: r.get("image_date", ""))
    # days_since_start uses experiment date as base, +1, matching day_code exactly
    try:
        exp_base = dt.date.fromisoformat(ed)
    except (TypeError, ValueError):
        exp_base = None

    prev = None
    for row in rows:
        try:
            row["days_since_start"] = (
                (dt.date.fromisoformat(row.get("image_date", "")) - exp_base).days + 1
                if exp_base else 0
            )
        except (TypeError, ValueError):
            row["days_since_start"] = 0

        # Blank unless a positive interval and positive areas allow a rate.
        row["rgr_per_day"] = ""
        row["relative_growth_per_day"] = ""
        if prev is not None:
            try:
                gap_days = (dt.date.fromisoformat(row["image_date"])
                            - dt.date.fromisoformat(prev["image_date"])).days
                a2 = float(row.get("area_mm2", 0))
                a1 = float(prev.get("area_mm2", 0))
                if gap_days > 0 and a1 > 0 and a2 > 0:
                    row["rgr_per_day"] = round((math.log(a2) - math.log(a1)) / gap_days, 6)
                    row["relative_growth_per_day"] = round((a2 - a1) / gap_days, 4)
            except (KeyError, TypeError, ValueError):
                pass
        prev = row


def _export_analysis(all_results, chart_items):
    """Write CSV, JSON and chart PNGs to a temp directory and return the zip path."""
    out_dir = Path(tempfile.mkdtemp())
    if all_results:
        # Union of keys in first-seen order, so columns that only appear on
        # later rows (notably "error") are not dropped from the CSV.
        fieldnames = list(dict.fromkeys(k for row in all_results for k in row))
        with open(out_dir / "analysis_full.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
            writer.writeheader()
            writer.writerows(all_results)
    with open(out_dir / "analysis_full.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, default=str)
    for i, (chart, _caption) in enumerate(chart_items):
        chart.save(str(out_dir / f"chart_{i}.png"))

    zip_path = out_dir / "analysis_full.zip"
    members = [fp for fp in out_dir.glob("*") if fp.name != zip_path.name]
    with zipfile.ZipFile(zip_path, "w") as z:
        for fp in members:
            z.write(fp, fp.name)
    return str(zip_path)


def on_run(paths, dates, en, ed, un, pc, thresh, full_pipeline, progress=gr.Progress()):
    """Segment every uploaded image; optionally run the full morphometric pipeline.

    Fast mode returns raw/mask/overlay panels only. Full mode adds dish
    detection, crack and hyphae analysis, growth rates, charts and a
    downloadable zip. A failure on one image is reported and does not stop
    the others.

    Returns:
        ``(status_markdown, panels, charts, results_dataframe, zip_path, results)``.
    """
    if not paths:
        return "⚠️ Upload images first.", [], [], None, None, []
    try:
        load_model()
    except Exception as e:
        return f"❌ Model failed: {e}", [], [], None, None, []

    results, vis, errors = [], [], []

    if not full_pipeline:
        for p in progress.tqdm(paths, desc="Segmenting"):
            name = Path(p).name
            try:
                with Image.open(p) as raw:
                    img = raw.convert("RGB")
                overlay, mask = infer_image(img, thresh)
                mask_px = np.sum(np.array(mask) > 0)
                vis.append((img, f"{name} — Raw"))
                vis.append((mask, f"{name} — Mask"))
                vis.append((overlay, f"{name} — Overlay"))
                log.info("%s: done (mask_pixels=%d, threshold=%.2f)", name, mask_px, thresh)
            except Exception as e:
                log.error("%s: %s", name, e)
                errors.append(f"{name}: {e}")
        em = f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
        return (f"✅ **{len(paths)-len(errors)}/{len(paths)}** segmented (fast mode, threshold={thresh:.2f}).{em}",
                vis, [], None, None, [])

    plates = int(pc) if pc else 1
    for p in progress.tqdm(paths, desc="Full pipeline"):
        name = Path(p).name
        # Only probe the file when no date is stored for it.
        image_date = dates[p] if p in dates else detect_image_date(p)
        try:
            img_bgr = cv2.imread(str(p))
            if img_bgr is None:
                raise RuntimeError(f"Cannot read: {p}")
            h, w = img_bgr.shape[:2]

            colony_mask = _predict_colony_mask(img_bgr, thresh)
            dish_info = detect_dish(img_bgr)
            gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY).astype(np.float64) / 255.0
            crack_mask = detect_cracks(gray, colony_mask)
            hyph_f, hyph_m, hyph_h = detect_hyphae(gray, colony_mask)

            if dish_info:
                dcx, dcy, dr, px2mm = dish_info
            else:
                # No dish found: centre of image, uncalibrated (1 px = 1 "mm").
                dcx, dcy = w // 2, h // 2
                dr = min(h, w) // 2
                px2mm = 1.0

            metrics = compute_metrics(colony_mask, gray, px2mm, dcx, dcy,
                                      crack_mask, hyph_f, hyph_m, hyph_h)
            metrics.update(
                colony_pixels=int(colony_mask.sum()), dish_detected=dish_info is not None,
                dish_radius_px=dr, px_to_mm=round(px2mm, 6),
                calibration_diameter_mm=round(2 * dr * px2mm, 4),
                calibration_error_pct=(round(abs(2 * dr * px2mm - DISH_MM) / DISH_MM * 100, 4)
                                       if dish_info else 0),
                image_path=name, experiment_name=en or "", experiment_date=ed or "",
                image_date=image_date, day_code=day_code(image_date, ed) if ed else "",
                user_name=un or "", plates_count=plates)
            results.append(metrics)

            vis.extend(create_full_overlays(img_bgr, colony_mask, crack_mask, hyph_h, dish_info, name))
            log.info("%s: area=%.1f mm²", name, metrics["area_mm2"])
        except Exception as e:
            log.error("%s: %s", name, e)
            errors.append(f"{name}: {e}")
            results.append({"image_path": name, "error": str(e)})

    ok_results = [r for r in results if not r.get("error")]
    if len(ok_results) > 1:
        _annotate_growth(ok_results, ed)
    chart_items = make_growth_charts(ok_results) if len(ok_results) >= 2 else []

    all_results = ok_results + [r for r in results if r.get("error")]
    zip_path = _export_analysis(all_results, chart_items)

    em = f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
    cm = f"\n\n📊 **{len(chart_items)} charts**" if chart_items else ""
    return (f"✅ **{len(ok_results)}/{len(results)}** analyzed.{cm}{em}",
            vis, chart_items, pd.DataFrame(all_results), zip_path, all_results)


with gr.Blocks(title="Gray Leaf Spot Segmentation", css=CSS) as demo:
    # Per-session state.
    paths_st = gr.State([])
    dates_st = gr.State({})
    rems_st = gr.State({})
    cur_idx = gr.State(-1)
    results_st = gr.State([])

    gr.Markdown("# 🔬 Gray Leaf Spot Colony Segmentation\n"
                "Upload → **Run Inference** → instant results | Toggle *Full Pipeline* for morphometrics\n\n"
                "Model: [`rotsl/grayleafspot-segmentation/best_area_w_0.7.pt`]"
                "(https://huggingface.co/rotsl/grayleafspot-segmentation) · SmallUNet (area-consistency w=0.7)")

    with gr.Accordion("📂 Step 1 — Upload Images", open=True):
        upload = gr.File(label="Drag & drop petri dish images", file_count="multiple", file_types=["image"])
        up_st = gr.Markdown("")

    with gr.Accordion("⚙️ Step 2 — Settings", open=True):
        with gr.Row():
            threshold_slider = gr.Slider(label="Mask confidence threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
            full_pipeline_cb = gr.Checkbox(label="Full Pipeline (slower: dish, cracks, hyphae, morphometrics)", value=False)
        with gr.Row():
            exp_name = gr.Textbox(label="Experiment Name", placeholder="MagExp01")
            exp_date = gr.Textbox(label="Experiment Date", placeholder="2025-04-01")
            user_name = gr.Textbox(label="User Name", placeholder="Your name")
            plates_count = gr.Number(label="Plates", value=1, minimum=1, maximum=200, precision=0)

    with gr.Accordion("🖼️ Step 3 — Review & Edit Dates", open=False):
        gr.Markdown("*Click thumbnail → edit date → Save*")
        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(label="Images", columns=4, height=400, object_fit="contain", allow_preview=False, interactive=False)
            with gr.Column(scale=1):
                sel_img = gr.Image(label="Selected", height=200, interactive=False)
                sel_fn = gr.Textbox(label="Filename", interactive=False)
                sel_dt = gr.Textbox(label="Image Date", interactive=True)
                sel_dc = gr.Textbox(label="Day Code", interactive=False)
                sel_rm = gr.Textbox(label="Remind Me", placeholder="YYYY-MM-DD HH:MM", interactive=True)
                sv_btn = gr.Button("💾 Save Date", variant="primary")
                sv_st = gr.Markdown("")

    with gr.Accordion("📥 Step 4 — Export Metadata", open=False):
        exp_btn = gr.Button("📥 Export CSV / JSON / ICS", variant="primary")
        exp_st = gr.Markdown("")
        meta_preview = gr.Dataframe(label="image_metadata.csv", interactive=False, wrap=True)
        meta_dl = gr.File(label="⬇️ Download metadata zip", interactive=False)

    with gr.Accordion("🚀 Step 5 — Run Inference", open=True):
        run_btn = gr.Button("🚀 Run Inference", variant="primary", size="lg")
        run_st = gr.Markdown("")
        gr.Markdown("### Results")
        overlay_gallery = gr.Gallery(label="Segmentation results", columns=3, height=500, object_fit="contain", allow_preview=True)
        gr.Markdown("### Growth Charts (full pipeline, ≥2 images)")
        chart_gallery = gr.Gallery(label="Growth curves", columns=3, height=400, object_fit="contain", allow_preview=True)
        gr.Markdown("### Results Table (full pipeline)")
        results_df = gr.Dataframe(label="analysis_full.csv", interactive=False, wrap=True)
        results_dl = gr.File(label="⬇️ Download analysis zip", interactive=False)

    with gr.Row(elem_classes="footer-row"):
        gr.Markdown("---\nDeveloped by [Rohan R](https://rotsl.github.io/)")

    # ── Event wiring ──
    upload.upload(on_upload, [upload], [paths_st, dates_st, rems_st, gallery, up_st, cur_idx])
    gallery.select(on_sel, [paths_st, dates_st, rems_st, exp_date],
                   [cur_idx, sel_img, sel_fn, sel_dt, sel_dc, sel_rm])
    sv_btn.click(on_save, [paths_st, dates_st, rems_st, cur_idx, sel_dt, sel_rm, exp_date],
                 [dates_st, rems_st, sel_dc, sv_st])
    exp_btn.click(on_export, [paths_st, dates_st, rems_st, exp_name, exp_date, user_name, plates_count],
                  [exp_st, meta_preview, meta_dl])
    run_btn.click(on_run,
                  [paths_st, dates_st, exp_name, exp_date, user_name, plates_count, threshold_slider, full_pipeline_cb],
                  [run_st, overlay_gallery, chart_gallery, results_df, results_dl, results_st])
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)