import gradio as gr
import tempfile
from pathlib import Path
from wrapper import run_pipeline_on_image
import numpy as np
from PIL import Image
from itertools import product


def _stretch_to_uint8(channel):
    """Contrast-stretch a 2-D numeric array to uint8 for display.

    The display range is the 1st..99th percentile of the values; if that range
    is non-finite or degenerate, the full min..max range is used instead.
    NaN and +/-inf are replaced with 0 before stretching.
    """
    a = np.nan_to_num(channel.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    vmin = np.percentile(a, 1.0)
    vmax = np.percentile(a, 99.0)
    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax <= vmin:
        vmin, vmax = float(np.min(a)), float(np.max(a))
    # Guard against division by zero for constant images.
    denom = max(vmax - vmin, 1e-6)
    return (np.clip((a - vmin) / denom, 0.0, 1.0) * 255.0).astype(np.uint8)


def show_preview(image):
    """Render an uploaded image for display, including 16-bit/single-channel inputs.

    - Palette (mode "P") images are expanded to RGB so real colors are shown.
    - 4-channel images (e.g. RGBA) are converted to RGB (alpha stripped).
    - RGB uint8: returned as-is.
    - RGB with non-uint8 dtype: each channel is 1-99 percentile stretched to 8-bit.
    - Single-channel: 1-99 percentile stretched to an 8-bit grayscale image.
    - Anything else: converted to RGB by Pillow.

    Returns None for None input. This is a best-effort preview: on any error the
    original image is returned unchanged.
    """
    if image is None:
        return None
    try:
        # Palette images store indices, not intensities; stretching those would
        # show meaningless grays, so expand them to their actual colors first.
        if image.mode == "P":
            image = image.convert("RGB")
        arr = np.array(image)

        # 4-channel (RGBA/CMYK) -> RGB
        if arr.ndim == 3 and arr.shape[2] == 4:
            image = image.convert("RGB")
            arr = np.array(image)

        # RGB
        if arr.ndim == 3 and arr.shape[2] == 3:
            # High bit-depth or non-uint8: normalize per channel for visualization.
            if arr.dtype != np.uint8 or np.max(arr) > 255:
                vis = np.stack(
                    [_stretch_to_uint8(arr[..., c]) for c in range(3)], axis=-1
                )
                return Image.fromarray(vis)  # uint8 HxWx3 -> "RGB"
            return image

        # Single-channel (any bit depth)
        if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1):
            if arr.ndim == 3:
                arr = arr[..., 0]
            return Image.fromarray(_stretch_to_uint8(arr))  # uint8 HxW -> "L"

        # Fallback (e.g. 2-channel images)
        return image.convert("RGB")
    except Exception:
        # Preview is cosmetic; never block the upload on a rendering failure.
        return image


def _save_input(image, directory):
    """Save *image* into *directory* for the pipeline and return the file path.

    The image's original format (``image.format``) is preferred. If Pillow
    cannot write the image in that format (e.g. RGBA as JPEG, or a format that
    is read-only in Pillow), it falls back to lossless PNG.
    """
    ext = image.format.lower() if image.format else "png"
    img_path = directory / f"input.{ext}"
    try:
        image.save(img_path)
    except (OSError, ValueError, KeyError):
        # OSError: mode not writable in this format; ValueError: unknown
        # extension; KeyError: format known for reading but has no writer.
        img_path = directory / "input.png"
        image.save(img_path)
    return img_path


def _load_rgb(path_str):
    """Load an image file fully into memory as RGB.

    Returns None when *path_str* is empty/None or the file cannot be loaded
    (best-effort: a missing artifact should not fail the whole request).
    """
    if not path_str:
        return None
    try:
        # The context manager closes the file handle; convert() loads the pixel
        # data and returns a new image object, so the result remains valid
        # after the temporary directory is removed.
        with Image.open(path_str) as im:
            return im.convert("RGB")
    except Exception:
        return None


def process(image):
    """Run the analysis pipeline on an uploaded PIL image.

    Returns a 5-tuple matching the outputs wired in ``run.click``:
    ``(composite, overlay, mask, gallery_items, stats_text)``, where the first
    three are RGB PIL images or None, ``gallery_items`` is a list of RGB PIL
    images (missing/unreadable indices are omitted), and ``stats_text`` is a
    string.
    """
    if image is None:
        # Must provide one value per output component (5), not 4.
        return None, None, None, [], ""

    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = _save_input(image, Path(tmpdir))
        outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)

        # Load every artifact while tmpdir still exists.
        overlay = _load_rgb(outputs.get("Overlay"))
        mask = _load_rgb(outputs.get("Mask"))
        composite = _load_rgb(outputs.get("Composite"))

        order = ("NDVI", "ARI", "GNDVI")
        loaded = (_load_rgb(outputs.get(k)) for k in order)
        # Drop failed loads so the gallery never receives None entries.
        gallery_items = [im for im in loaded if im is not None]

        stats_text = outputs.get("StatsText", "")

    return composite, overlay, mask, gallery_items, stats_text


with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Sorghum Plant Analysis Demo")
    gr.Markdown("Upload a sorghum plant image to analyze vegetation indices, segmentation overlay, and stats.")

    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Upload Image")
            run = gr.Button("Run Pipeline", variant="primary")
        with gr.Column():
            preview = gr.Image(type="pil", label="Uploaded Image Preview", interactive=False)

    with gr.Row():
        composite_img = gr.Image(type="pil", label="Composite (Segmentation Input)", interactive=False)
        overlay_img = gr.Image(type="pil", label="Segmentation Overlay", interactive=False)
        mask_img = gr.Image(type="pil", label="Mask", interactive=False)

    gallery = gr.Gallery(label="Vegetation Indices", columns=3, height="auto")
    stats = gr.Textbox(label="Statistics", lines=4)

    # Update preview when image is uploaded
    inp.change(fn=show_preview, inputs=inp, outputs=preview)

    run.click(process, inputs=inp, outputs=[composite_img, overlay_img, mask_img, gallery, stats])


if __name__ == "__main__":
    demo.launch()