"""
app.py — Gray Leaf Spot Colony Segmentation Pipeline
"""

import csv, json, math, os, re, logging, tempfile, zipfile, io
import datetime as dt
from pathlib import Path

import cv2
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ExifTags
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("app")

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
THUMB_SIZE = (160, 160)
DATE_RE = re.compile(r"(20\d{2})(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])")
MAX_IMAGES = 50
MODEL_REPO = "rotsl/grayleafspot-segmentation"
MODEL_FILE = "best_area_w_0.7.pt"
DISH_MM = 90.0   # petri dish diameter (mm) used for px -> mm calibration
MODEL_SZ = 256   # square input size fed to the segmentation model
HF_TOKEN = os.environ.get("HF_TOKEN")
CSS = ".gallery-wrap{max-height:65vh;overflow-y:auto} .footer-text{text-align:center;margin-top:8px}"
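
# DATE_RE extracts YYYYMMDD stamps from filenames, e.g. a (hypothetical)
# "plate03_20250403.jpg" yields 2025-04-03; EXIF and file mtime are fallbacks.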


class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class DownBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.pool(x))


class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)


class SmallUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, base_channels=16):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8
        bottleneck = base_channels * 16
        self.enc1 = ConvBlock(in_channels, c1)
        self.enc2 = DownBlock(c1, c2)
        self.enc3 = DownBlock(c2, c3)
        self.enc4 = DownBlock(c3, c4)
        self.bottleneck = DownBlock(c4, bottleneck)
        self.up4 = UpBlock(bottleneck + c4, c4)
        self.up3 = UpBlock(c4 + c3, c3)
        self.up2 = UpBlock(c3 + c2, c2)
        self.up1 = UpBlock(c2 + c1, c1)
        self.head = nn.Conv2d(c1, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(s1)
        s3 = self.enc3(s2)
        s4 = self.enc4(s3)
        b = self.bottleneck(s4)
        x = self.up4(b, s4)
        x = self.up3(x, s3)
        x = self.up2(x, s2)
        x = self.up1(x, s1)
        x = self.head(x)
        return self.activation(x)
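
# Shape walk-through for a 256x256 input with base_channels=16:
# enc1 256px/16ch -> enc2 128px/32ch -> enc3 64px/64ch -> enc4 32px/128ch
# -> bottleneck 16px/256ch. Each UpBlock doubles the spatial size and
# concatenates the matching encoder skip before its ConvBlock, so up4
# convolves 256 + 128 = 384 channels, i.e. UpBlock(bottleneck + c4, c4).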


_model = None


def load_model():
    """Download (once) and cache the SmallUNet checkpoint from the Hub."""
    global _model
    if _model is None:
        p = hf_hub_download(MODEL_REPO, MODEL_FILE, token=HF_TOKEN)
        _model = SmallUNet(in_channels=3, out_channels=1, base_channels=16)
        ckpt = torch.load(p, map_location="cpu", weights_only=False)
        # Accept either a full training checkpoint or a bare state_dict.
        sd = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
        _model.load_state_dict(sd, strict=True)
        _model.eval()
        log.info("Model loaded: SmallUNet (%s)", MODEL_FILE)
    return _model


def infer_image(img_pil, threshold):
    model = load_model()
    img_arr = np.array(img_pil.convert("RGB"))
    img_resized = cv2.resize(img_arr, (MODEL_SZ, MODEL_SZ))
    x = torch.from_numpy(img_resized.transpose(2, 0, 1)).float() / 255.0
    x = x.unsqueeze(0)
    with torch.no_grad():
        prob = model(x)[0, 0].detach().cpu().numpy()
    log.info(" output range: [%.4f, %.4f] mean=%.4f >0.5:%d >0.3:%d >0.1:%d",
             prob.min(), prob.max(), prob.mean(),
             (prob > 0.5).sum(), (prob > 0.3).sum(), (prob > 0.1).sum())
    mask = (prob > threshold).astype(np.uint8) * 255
    mask = cv2.resize(mask, (img_pil.width, img_pil.height), interpolation=cv2.INTER_NEAREST)
    overlay = img_arr.copy()
    overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([255, 0, 0]) * 0.5).astype(np.uint8)
    return Image.fromarray(overlay), Image.fromarray(mask)
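
# Hypothetical usage; lowering the threshold admits more of the probability map:
#   overlay, mask = infer_image(Image.open("dish.jpg"), threshold=0.3)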


def make_thumbnail(p):
    try:
        im = Image.open(p)
        im.thumbnail(THUMB_SIZE, Image.LANCZOS)
        return im
    except Exception:
        # Fall back to a neutral gray placeholder for unreadable files.
        return Image.new("RGB", THUMB_SIZE, (200, 200, 200))


def detect_image_date(p):
    # Priority: date embedded in the filename, then EXIF, then file mtime.
    m = DATE_RE.search(Path(p).stem)
    if m:
        try:
            return dt.date(int(m[1]), int(m[2]), int(m[3])).isoformat()
        except ValueError:
            pass
    try:
        ex = Image.open(p).getexif()
        if ex:
            for tid, tn in ExifTags.TAGS.items():
                if tn == "DateTimeOriginal":
                    v = ex.get(tid)
                    if v:
                        return dt.datetime.strptime(v, "%Y:%m:%d %H:%M:%S").date().isoformat()
    except Exception:
        pass
    try:
        return dt.date.fromtimestamp(os.path.getmtime(p)).isoformat()
    except OSError:
        return dt.date.today().isoformat()


def day_code(img_d, exp_d):
    try:
        d = (dt.date.fromisoformat(img_d) - dt.date.fromisoformat(exp_d)).days + 1
        return f"d{max(d, 1):02d}"
    except (ValueError, TypeError):
        return "d??"
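
# Example: experiment_date 2025-04-01 and image_date 2025-04-03 give "d03"
# (day 1 is the experiment date itself; earlier dates clamp to "d01").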


def write_ics(rems, path):
    L = ["BEGIN:VCALENDAR", "VERSION:2.0", "PRODID:-//FungalPipeline//EN"]
    for r in rems:
        uid = r["image_path"].replace("/", "_")
        ds = r["remind_me"].replace("-", "").replace(" ", "T").replace(":", "") + "00"
        L += ["BEGIN:VEVENT", f"UID:{uid}@fp", f"DTSTART:{ds}",
              f"SUMMARY:Reminder - {r['experiment_name']}: {Path(r['image_path']).name}", "END:VEVENT"]
    L.append("END:VCALENDAR")
    with open(path, "w") as f:
        f.write("\r\n".join(L))
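
# DTSTART is remind_me ("YYYY-MM-DD HH:MM") collapsed into an iCalendar
# local timestamp, e.g. "2025-04-01 09:30" -> "20250401T093000".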


def fig_to_pil(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="white")
    buf.seek(0)
    img = Image.open(buf).copy()
    buf.close()
    plt.close(fig)
    return img


def _load_skimage():
    # scikit-image is only needed for the full pipeline, so import it lazily.
    from skimage import filters, measure, morphology
    from skimage.filters import frangi, meijering
    from skimage.morphology import skeletonize, disk, opening, closing, erosion, dilation
    return filters, measure, morphology, frangi, meijering, skeletonize, disk, opening, closing, erosion, dilation


def detect_dish(img_bgr):
    try:
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (9, 9), 2)
        h, w = gray.shape
        mn, mx = int(min(h, w) * 0.25), int(min(h, w) * 0.52)
        circles = cv2.HoughCircles(blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min(h, w) // 2,
                                   param1=100, param2=40, minRadius=mn, maxRadius=mx)
        if circles is None:
            return None
        circles = np.round(circles[0]).astype(int)
        ic, jc = w / 2, h / 2
        bi, bs = 0, -1
        # Prefer large circles close to the image centre.
        for i, (cx, cy, r) in enumerate(circles):
            s = r / (1 + math.hypot(cx - ic, cy - jc) / 100)
            if s > bs:
                bs, bi = s, i
        cx, cy, r = int(circles[bi][0]), int(circles[bi][1]), int(circles[bi][2])
        return cx, cy, r, DISH_MM / (2 * r)
    except Exception:
        return None
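
# Calibration: a DISH_MM (90 mm) dish spanning 2*r pixels gives
# px2mm = 90 / (2 * r) mm per pixel; e.g. r = 450 px -> 0.1 mm/px,
# and areas scale by px2mm**2.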


def detect_cracks(gray, colony_mask):
    filters, measure, _, _, _, _, disk, opening, _, erosion, _ = _load_skimage()
    if colony_mask.sum() < 100:
        return np.zeros_like(colony_mask, dtype=bool)
    interior = gray.copy()
    interior[~colony_mask] = 0
    er = erosion(colony_mask, disk(5))
    iu = (interior * 255 if interior.max() <= 1 else interior).astype(np.uint8)
    lt = filters.threshold_local(iu, block_size=51, method="gaussian")
    # Cracks are locally dark, elongated regions inside the eroded colony.
    dk = (iu < (lt - 15)) & er
    dk = opening(dk, disk(1))
    lb = measure.label(dk)
    cm = np.zeros_like(dk, dtype=bool)
    for rp in measure.regionprops(lb):
        if rp.area < 10:
            continue
        if rp.major_axis_length > 0 and rp.minor_axis_length > 0:
            if rp.major_axis_length / rp.minor_axis_length > 2.5 or rp.eccentricity > 0.85:
                cm[lb == rp.label] = True
    return cm


def detect_hyphae(gray, colony_mask):
    _, _, _, frangi, meijering, skeletonize, disk, _, _, _, dilation = _load_skimage()
    if colony_mask.sum() < 100:
        z = np.zeros_like(colony_mask, dtype=bool)
        return z, z.copy(), z.copy()
    g = gray.astype(np.float64)
    ex = dilation(colony_mask, disk(20))
    fr = frangi(g, sigmas=range(1, 5), black_ridges=False)
    fr[~ex] = 0
    th = fr[ex].mean() + 2 * fr[ex].std() if ex.sum() > 0 else 0.01
    fs = skeletonize(fr > th)
    mr = meijering(g, sigmas=range(1, 5), black_ridges=False)
    mr[~ex] = 0
    th2 = mr[ex].mean() + 2 * mr[ex].std() if ex.sum() > 0 else 0.01
    ms = skeletonize(mr > th2)
    return fs, ms, fs | ms
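
# Each vesselness map is thresholded at mean + 2*sigma inside the dilated
# colony region and skeletonized to one-pixel centrelines; the hybrid
# estimate is the union of the Frangi and Meijering skeletons.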


def compute_metrics(mask_bool, gray, px2mm, dcx, dcy, crack_mask, hyph_f, hyph_m, hyph_h):
    filters, measure, morphology, _, _, _, _, _, _, _, _ = _load_skimage()
    mm2 = px2mm ** 2
    if mask_bool.sum() < 50:
        return {k: 0 for k in ["area_mm2", "diameter_mm", "perimeter_mm", "eccentricity", "edge_roughness",
                               "centre_delta_mm", "entropy", "texture_std", "crack_px", "crack_area_mm2",
                               "crack_coverage_pct", "crack_count", "hyph_frangi_mm", "hyph_meijering_mm", "hyph_hybrid_mm"]}
    pr = measure.regionprops(mask_bool.astype(np.uint8))[0]
    R = {}
    R["area_mm2"] = round(pr.area * mm2, 4)
    pm = measure.perimeter(mask_bool)
    R["perimeter_mm"] = round(pm * px2mm, 4)
    R["diameter_mm"] = round(pr.equivalent_diameter_area * px2mm, 4)
    R["eccentricity"] = round(pr.eccentricity, 6)
    eq = math.pi * pr.equivalent_diameter_area
    R["edge_roughness"] = round(pm / eq, 6) if eq > 0 else 0
    cy, cx = pr.centroid
    R["centre_delta_mm"] = round(math.hypot(cx - dcx, cy - dcy) * px2mm, 4)
    gu8 = (gray * 255).astype(np.uint8) if gray.max() <= 1 else gray.astype(np.uint8)
    R["entropy"] = round(float(filters.rank.entropy(gu8, morphology.disk(5), mask=mask_bool)[mask_bool].mean()), 6) if pr.area > 100 else 0
    R["texture_std"] = round(float(gray[mask_bool].std()), 6)
    R["crack_px"] = int(crack_mask.sum())
    R["crack_area_mm2"] = round(crack_mask.sum() * mm2, 6)
    R["crack_coverage_pct"] = round(100 * crack_mask.sum() / pr.area, 4) if pr.area > 0 else 0
    R["crack_count"] = int(measure.label(crack_mask).max())
    R["hyph_frangi_mm"] = round(int(hyph_f.sum()) * px2mm, 4)
    R["hyph_meijering_mm"] = round(int(hyph_m.sum()) * px2mm, 4)
    R["hyph_hybrid_mm"] = round(int(hyph_h.sum()) * px2mm, 4)
    return R
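
# edge_roughness = P / (pi * d_eq) equals 1 for a perfect circle and grows
# with boundary irregularity; e.g. P = 314 px with d_eq = 90 px gives
# 314 / (pi * 90) ≈ 1.11.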


def create_full_overlays(img_bgr, colony_mask, crack_mask, hyph_hybrid, dish_info, fname):
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    h, w = img_bgr.shape[:2]
    dcx, dcy, dr = (dish_info[0], dish_info[1], dish_info[2]) if dish_info else (w // 2, h // 2, min(h, w) // 2)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    # Panel 1: raw image with dish circle (green) and colony contour (red).
    p1 = img_rgb.copy()
    if dish_info:
        cv2.circle(p1, (dcx, dcy), dr, (0, 255, 0), 3)
    cts, _ = cv2.findContours(colony_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(p1, cts, -1, (255, 0, 0), 2)
    # Panel 2: binary colony mask.
    p2 = np.zeros_like(img_rgb)
    p2[colony_mask] = [255, 255, 255]
    # Panel 3: colony blended in red.
    p3 = img_rgb.copy()
    if colony_mask.sum() > 0:
        p3[colony_mask] = (p3[colony_mask].astype(np.float32) * 0.5 + np.array([255, 0, 0], dtype=np.float32) * 0.5).astype(np.uint8)
    if dish_info:
        cv2.circle(p3, (dcx, dcy), dr, (0, 255, 0), 2)
    # Panel 4: cracks highlighted in yellow.
    p4 = img_rgb.copy()
    if crack_mask.sum() > 0:
        ck = cv2.dilate(crack_mask.astype(np.uint8), kernel) > 0
        p4[ck] = (p4[ck].astype(np.float32) * 0.3 + np.array([255, 255, 0], dtype=np.float32) * 0.7).astype(np.uint8)
    if dish_info:
        cv2.circle(p4, (dcx, dcy), dr, (0, 255, 0), 2)
        cv2.drawContours(p4, cts, -1, (255, 0, 0), 1)
    # Panel 5: hyphae skeleton highlighted in cyan.
    p5 = img_rgb.copy()
    if hyph_hybrid.sum() > 0:
        hy = cv2.dilate(hyph_hybrid.astype(np.uint8), kernel) > 0
        p5[hy] = (p5[hy].astype(np.float32) * 0.3 + np.array([0, 255, 255], dtype=np.float32) * 0.7).astype(np.uint8)
    if dish_info:
        cv2.circle(p5, (dcx, dcy), dr, (0, 255, 0), 2)
        cv2.drawContours(p5, cts, -1, (255, 0, 0), 1)
    # Panel 6: combined view.
    p6 = img_rgb.copy()
    if colony_mask.sum() > 0:
        p6[colony_mask] = (p6[colony_mask].astype(np.float32) * 0.6 + np.array([255, 0, 0], dtype=np.float32) * 0.4).astype(np.uint8)
    if crack_mask.sum() > 0:
        p6[cv2.dilate(crack_mask.astype(np.uint8), kernel) > 0] = [255, 255, 0]
    if hyph_hybrid.sum() > 0:
        p6[cv2.dilate(hyph_hybrid.astype(np.uint8), kernel) > 0] = [0, 255, 255]
    if dish_info:
        cv2.circle(p6, (dcx, dcy), dr, (0, 255, 0), 2)
    return [(Image.fromarray(p1), f"{fname} — Raw+Dish"), (Image.fromarray(p2), f"{fname} — Mask"),
            (Image.fromarray(p3), f"{fname} — Colony"), (Image.fromarray(p4), f"{fname} — Cracks"),
            (Image.fromarray(p5), f"{fname} — Hyphae"), (Image.fromarray(p6), f"{fname} — Combined")]


def make_growth_charts(results):
    """Generate time-series charts for every morphometric parameter.

    All spatial metrics are already in mm (or mm²) via the per-image
    px_to_mm calibration computed from dish detection, so images of
    different resolutions are correctly comparable.
    """
    if len(results) < 2:
        return []
    df = pd.DataFrame(results)
    if "error" in df.columns:
        df = df[df["error"].fillna("").astype(str).str.strip() == ""].copy()
    if len(df) < 2:
        return []

    numeric_cols = [
        "days_since_start", "area_mm2", "diameter_mm", "perimeter_mm",
        "eccentricity", "edge_roughness", "centre_delta_mm",
        "entropy", "texture_std",
        "crack_area_mm2", "crack_coverage_pct", "crack_count",
        "hyph_frangi_mm", "hyph_meijering_mm", "hyph_hybrid_mm",
        "rgr_per_day", "relative_growth_per_day",
    ]
    for c in numeric_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    df = df.sort_values("days_since_start").reset_index(drop=True)
    charts = []

    # (column, y-label, title, line colour, marker, fill under curve)
    chart_defs = [
        # size
        ("area_mm2", "Area (mm²)", "Colony Area", "#e74c3c", "o", True),
        ("diameter_mm", "Diameter (mm)", "Colony Diameter", "#2980b9", "s", False),
        ("perimeter_mm", "Perimeter (mm)", "Colony Perimeter", "#8e44ad", "^", False),
        # shape
        ("eccentricity", "Eccentricity", "Colony Eccentricity", "#e67e22", "D", False),
        ("edge_roughness", "Edge Roughness", "Edge Roughness (P / πd)", "#16a085", "v", False),
        ("centre_delta_mm", "Centre Offset (mm)", "Colony Centre Offset", "#d35400", "p", False),
        # texture
        ("entropy", "Entropy", "Colony Texture Entropy", "#7f8c8d", "h", False),
        ("texture_std", "Texture Std Dev", "Colony Texture Std Dev", "#2c3e50", "*", False),
        # cracks
        ("crack_area_mm2", "Crack Area (mm²)", "Crack Area", "#f1c40f", "o", True),
        ("crack_coverage_pct", "Crack Coverage (%)", "Crack Coverage", "#d4ac0d", "s", False),
        ("crack_count", "Crack Count", "Number of Cracks", "#b7950b", "^", False),
        # hyphae
        ("hyph_frangi_mm", "Length (mm)", "Hyphae Length — Frangi", "#1abc9c", "o", False),
        ("hyph_meijering_mm", "Length (mm)", "Hyphae Length — Meijering", "#3498db", "s", False),
        ("hyph_hybrid_mm", "Length (mm)", "Hyphae Length — Hybrid", "#2ecc71", "D", False),
        # growth rates
        ("rgr_per_day", "RGR (ln mm² / day)", "Relative Growth Rate", "#c0392b", "o", False),
        ("relative_growth_per_day", "Growth (mm² / day)", "Absolute Growth Rate", "#e74c3c", "s", False),
    ]

    for col, ylabel, title, color, marker, fill in chart_defs:
        if col not in df.columns:
            continue
        valid = df[col].notna()
        # Guard against empty strings in any column that escaped coercion.
        valid = valid & (df[col].astype(str).str.strip() != "")
        if valid.sum() < 2:
            continue
        sub = df.loc[valid].copy()
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(sub["days_since_start"], sub[col], f"{marker}-",
                color=color, lw=2, ms=8)
        if fill:
            ax.fill_between(sub["days_since_start"], 0, sub[col],
                            alpha=0.15, color=color)
        ax.set(xlabel="Days", ylabel=ylabel, title=title)
        ax.grid(True, alpha=0.3)
        charts.append((fig_to_pil(fig), title))

    return charts


with gr.Blocks(title="Gray Leaf Spot Segmentation", css=CSS) as demo:
    paths_st = gr.State([])
    dates_st = gr.State({})
    rems_st = gr.State({})
    cur_idx = gr.State(-1)
    results_st = gr.State([])

    gr.Markdown("# 🔬 Gray Leaf Spot Colony Segmentation\n"
                "Upload → **Run Inference** → instant results | Toggle *Full Pipeline* for morphometrics\n\n"
                "Model: [`rotsl/grayleafspot-segmentation/best_area_w_0.7.pt`]"
                "(https://huggingface.co/rotsl/grayleafspot-segmentation) · SmallUNet (area-consistency w=0.7)")

    with gr.Accordion("📂 Step 1 — Upload Images", open=True):
        upload = gr.File(label="Drag & drop petri dish images", file_count="multiple",
                         file_types=["image"])
        up_st = gr.Markdown("")

    with gr.Accordion("⚙️ Step 2 — Settings", open=True):
        with gr.Row():
            threshold_slider = gr.Slider(label="Mask confidence threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
            full_pipeline_cb = gr.Checkbox(label="Full Pipeline (slower: dish, cracks, hyphae, morphometrics)", value=False)
        with gr.Row():
            exp_name = gr.Textbox(label="Experiment Name", placeholder="MagExp01")
            exp_date = gr.Textbox(label="Experiment Date", placeholder="2025-04-01")
            user_name = gr.Textbox(label="User Name", placeholder="Your name")
            plates_count = gr.Number(label="Plates", value=1, minimum=1, maximum=200, precision=0)

    with gr.Accordion("🖼️ Step 3 — Review & Edit Dates", open=False):
        gr.Markdown("*Click thumbnail → edit date → Save*")
        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(label="Images", columns=4, height=400, object_fit="contain", allow_preview=False, interactive=False)
            with gr.Column(scale=1):
                sel_img = gr.Image(label="Selected", height=200, interactive=False)
                sel_fn = gr.Textbox(label="Filename", interactive=False)
                sel_dt = gr.Textbox(label="Image Date", interactive=True)
                sel_dc = gr.Textbox(label="Day Code", interactive=False)
                sel_rm = gr.Textbox(label="Remind Me", placeholder="YYYY-MM-DD HH:MM", interactive=True)
                sv_btn = gr.Button("💾 Save Date", variant="primary")
                sv_st = gr.Markdown("")

    with gr.Accordion("📥 Step 4 — Export Metadata", open=False):
        exp_btn = gr.Button("📥 Export CSV / JSON / ICS", variant="primary")
        exp_st = gr.Markdown("")
        meta_preview = gr.Dataframe(label="image_metadata.csv", interactive=False, wrap=True)
        meta_dl = gr.File(label="⬇️ Download metadata zip", interactive=False)

    with gr.Accordion("🚀 Step 5 — Run Inference", open=True):
        run_btn = gr.Button("🚀 Run Inference", variant="primary", size="lg")
        run_st = gr.Markdown("")
        gr.Markdown("### Results")
        overlay_gallery = gr.Gallery(label="Segmentation results", columns=3, height=500, object_fit="contain", allow_preview=True)
        gr.Markdown("### Growth Charts (full pipeline, ≥2 images)")
        chart_gallery = gr.Gallery(label="Growth curves", columns=3, height=400, object_fit="contain", allow_preview=True)
        gr.Markdown("### Results Table (full pipeline)")
        results_df = gr.Dataframe(label="analysis_full.csv", interactive=False, wrap=True)
        results_dl = gr.File(label="⬇️ Download analysis zip", interactive=False)

    gr.Markdown("<div class='footer-text'>\n\n---\nDeveloped by [Rohan R](https://rotsl.github.io/)\n</div>")

    def on_upload(files):
        if not files:
            return [], {}, {}, [], "", -1
        paths = [str(f) for f in files if Path(str(f)).suffix.lower() in IMAGE_EXTS][:MAX_IMAGES]
        if not paths:
            return [], {}, {}, [], "", -1
        dates = {p: detect_image_date(p) for p in paths}
        rems = {p: "" for p in paths}
        return paths, dates, rems, [(p, Path(p).name) for p in paths], f"✅ **{len(paths)}** images loaded.", -1

    upload.upload(on_upload, [upload], [paths_st, dates_st, rems_st, gallery, up_st, cur_idx])

    def on_sel(paths, dates, rems, ed, evt: gr.SelectData):
        i = evt.index
        if i < 0 or i >= len(paths):
            return -1, None, "", "", "", ""
        p = paths[i]
        return (i, make_thumbnail(p), Path(p).name, dates.get(p, ""),
                day_code(dates.get(p, ""), ed) if ed else "", rems.get(p, ""))

    gallery.select(on_sel, [paths_st, dates_st, rems_st, exp_date], [cur_idx, sel_img, sel_fn, sel_dt, sel_dc, sel_rm])

    def on_save(paths, dates, rems, i, nd, nr, ed):
        if i < 0 or i >= len(paths):
            return dates, rems, "", "⚠️ Select image."
        p = paths[i]
        dates = dict(dates)
        rems = dict(rems)
        dates[p] = nd
        rems[p] = nr
        return dates, rems, day_code(nd, ed) if ed else "", f"✅ **{Path(p).name}** → {nd}"

    sv_btn.click(on_save, [paths_st, dates_st, rems_st, cur_idx, sel_dt, sel_rm, exp_date], [dates_st, rems_st, sel_dc, sv_st])

    def on_export(paths, dates, rems, en, ed, un, pc):
        if not paths:
            return "⚠️ Upload first.", None, None
        tmp = tempfile.mkdtemp()
        rows, rl = [], []
        for p in paths:
            imd = dates.get(p, detect_image_date(p))
            rm = rems.get(p, "")
            row = dict(image_path=Path(p).name, experiment_name=en or "", experiment_date=ed or "",
                       image_date=imd, day_code=day_code(imd, ed) if ed else "", user_name=un or "",
                       plates_count=int(pc) if pc else 1, remind_me=rm)
            rows.append(row)
            if rm.strip():
                rl.append({**row})
        cp = Path(tmp) / "image_metadata.csv"
        with open(cp, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)
        jp = Path(tmp) / "image_metadata.json"
        with open(jp, "w") as f:
            json.dump(rows, f, indent=2)
        zf = [cp, jp]
        if rl:
            ip = Path(tmp) / "reminders.ics"
            write_ics(rl, str(ip))
            zf.append(ip)
        zp = Path(tmp) / "image_metadata.zip"
        with zipfile.ZipFile(zp, "w") as z:
            for f2 in zf:
                z.write(f2, f2.name)
        return f"✅ Exported **{len(rows)}** images.", pd.DataFrame(rows), str(zp)

    exp_btn.click(on_export, [paths_st, dates_st, rems_st, exp_name, exp_date, user_name, plates_count], [exp_st, meta_preview, meta_dl])

    def on_run(paths, dates, en, ed, un, pc, thresh, full_pipeline, progress=gr.Progress()):
        if not paths:
            return "⚠️ Upload images first.", [], [], None, None, []
        try:
            load_model()
        except Exception as e:
            return f"❌ Model failed: {e}", [], [], None, None, []

        results, vis, errors = [], [], []

        if not full_pipeline:
            # Fast mode: segmentation overlays only, no morphometrics.
            for p in progress.tqdm(paths, desc="Segmenting"):
                try:
                    img = Image.open(p).convert("RGB")
                    overlay, mask = infer_image(img, thresh)
                    mask_px = np.sum(np.array(mask) > 0)
                    vis.append((img, f"{Path(p).name} — Raw"))
                    vis.append((mask, f"{Path(p).name} — Mask"))
                    vis.append((overlay, f"{Path(p).name} — Overlay"))
                    log.info("%s: done (mask_pixels=%d, threshold=%.2f)", Path(p).name, mask_px, thresh)
                except Exception as e:
                    log.error("%s: %s", Path(p).name, e)
                    errors.append(f"{Path(p).name}: {e}")
            em = f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
            ok = len(paths) - len(errors)
            return f"✅ **{ok}/{len(paths)}** segmented (fast mode, threshold={thresh:.2f}).{em}", vis, [], None, None, []

        # Full pipeline: segmentation plus dish calibration and morphometrics.
        for p in progress.tqdm(paths, desc="Full pipeline"):
            imd = dates.get(p, detect_image_date(p))
            try:
                img_bgr = cv2.imread(str(p))
                if img_bgr is None:
                    raise RuntimeError(f"Cannot read: {p}")
                model = load_model()
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                img_resized = cv2.resize(img_rgb, (MODEL_SZ, MODEL_SZ))
                x = torch.from_numpy(img_resized.transpose(2, 0, 1)).float() / 255.0
                x = x.unsqueeze(0)
                with torch.no_grad():
                    prob = model(x)[0, 0].detach().cpu().numpy()
                mask_small = (prob > thresh).astype(np.uint8) * 255
                h, w = img_bgr.shape[:2]
                colony_mask = cv2.resize(mask_small, (w, h), interpolation=cv2.INTER_NEAREST) > 0
                dish_info = detect_dish(img_bgr)
                gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY).astype(np.float64) / 255.0
                crack_mask = detect_cracks(gray, colony_mask)
                hyph_f, hyph_m, hyph_h = detect_hyphae(gray, colony_mask)
                if dish_info:
                    dcx, dcy, dr, px2mm = dish_info
                else:
                    dcx, dcy, dr, px2mm = w // 2, h // 2, min(h, w) // 2, 1.0
                metrics = compute_metrics(colony_mask, gray, px2mm, dcx, dcy, crack_mask, hyph_f, hyph_m, hyph_h)
                metrics.update(colony_pixels=int(colony_mask.sum()), dish_detected=dish_info is not None,
                               dish_radius_px=dr, px_to_mm=round(px2mm, 6),
                               calibration_diameter_mm=round(2 * dr * px2mm, 4),
                               calibration_error_pct=round(abs(2 * dr * px2mm - DISH_MM) / DISH_MM * 100, 4) if dish_info else 0,
                               image_path=Path(p).name, experiment_name=en or "", experiment_date=ed or "",
                               image_date=imd, day_code=day_code(imd, ed) if ed else "",
                               user_name=un or "", plates_count=int(pc) if pc else 1)
                results.append(metrics)
                panels = create_full_overlays(img_bgr, colony_mask, crack_mask, hyph_h, dish_info, Path(p).name)
                vis.extend(panels)
                log.info("%s: area=%.1f mm²", Path(p).name, metrics["area_mm2"])
            except Exception as e:
                log.error("%s: %s", Path(p).name, e)
                errors.append(f"{Path(p).name}: {e}")
                results.append({"image_path": Path(p).name, "error": str(e)})

        ok_results = [r for r in results if not r.get("error")]
        if len(ok_results) > 1:
            ok_results.sort(key=lambda r: r.get("image_date", ""))
            try:
                base = dt.date.fromisoformat(ok_results[0].get("image_date", ""))
            except ValueError:
                base = None
            for i, r in enumerate(ok_results):
                try:
                    r["days_since_start"] = (dt.date.fromisoformat(r.get("image_date", "")) - base).days if base else 0
                except (ValueError, TypeError):
                    r["days_since_start"] = 0
                if i == 0:
                    r["rgr_per_day"] = ""
                    r["relative_growth_per_day"] = ""
                    continue
                prev = ok_results[i - 1]
                try:
                    dd = (dt.date.fromisoformat(r["image_date"]) - dt.date.fromisoformat(prev["image_date"])).days
                    a2, a1 = float(r.get("area_mm2", 0)), float(prev.get("area_mm2", 0))
                    if dd > 0 and a1 > 0 and a2 > 0:
                        r["rgr_per_day"] = round((math.log(a2) - math.log(a1)) / dd, 6)
                        r["relative_growth_per_day"] = round((a2 - a1) / dd, 4)
                    else:
                        r["rgr_per_day"] = ""
                        r["relative_growth_per_day"] = ""
                except (ValueError, TypeError, KeyError):
                    r["rgr_per_day"] = ""
                    r["relative_growth_per_day"] = ""
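
        # RGR = (ln A2 - ln A1) / dt assumes exponential growth between
        # consecutive images: growing 100 -> 200 mm² in one day gives
        # ln(2) ≈ 0.693 per day; relative_growth_per_day would be 100 mm²/day.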
        chart_items = make_growth_charts(ok_results) if len(ok_results) >= 2 else []
        tmp = tempfile.mkdtemp()
        all_results = ok_results + [r for r in results if r.get("error")]
        cp = Path(tmp) / "analysis_full.csv"
        if all_results:
            ks = list(all_results[0].keys())
            with open(cp, "w", newline="") as f:
                w = csv.DictWriter(f, fieldnames=ks, extrasaction="ignore")
                w.writeheader()
                w.writerows(all_results)
            jp = Path(tmp) / "analysis_full.json"
            with open(jp, "w") as f:
                json.dump(all_results, f, indent=2, default=str)
        for i, (cimg, cap) in enumerate(chart_items):
            cimg.save(str(Path(tmp) / f"chart_{i}.png"))
        zp = Path(tmp) / "analysis_full.zip"
        with zipfile.ZipFile(zp, "w") as z:
            for fp in Path(tmp).glob("*"):
                if fp.name != "analysis_full.zip":
                    z.write(fp, fp.name)
        em = f"\n\n⚠️ Errors: {'; '.join(errors)}" if errors else ""
        cm = f"\n\n📊 **{len(chart_items)} charts**" if chart_items else ""
        return (f"✅ **{len(ok_results)}/{len(results)}** analyzed.{cm}{em}",
                vis, chart_items, pd.DataFrame(all_results), str(zp), all_results)

    run_btn.click(on_run, [paths_st, dates_st, exp_name, exp_date, user_name, plates_count, threshold_slider, full_pipeline_cb],
                  [run_st, overlay_gallery, chart_gallery, results_df, results_dl, results_st])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)