|
|
import gradio as gr |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from wrapper import run_pipeline_on_image |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from itertools import product |
|
|
|
|
|
def show_preview(image):
    """Render an uploaded image for display, including 16-bit/single-channel inputs.

    - ``None`` input returns ``None``.
    - Palette ("P"/"PA") images are resolved through their palette to RGB; their
      raw pixel values are colour-table indices, not intensities.
    - 3-channel images are returned as-is when already "RGB", otherwise
      converted to RGB.
    - Single-channel images (2-D arrays, or 3-D arrays with one channel, e.g.
      8/16-bit grayscale or float) are contrast-stretched to 8-bit grayscale
      using the 1st-99th percentile window, falling back to the full min-max
      range when that window is degenerate.
    - Anything else (RGBA, CMYK, LA, ...) is converted to RGB (alpha dropped).

    The preview is best-effort: if any step raises, the original image is
    returned unchanged.
    """
    if image is None:
        return None

    try:
        mode = getattr(image, "mode", None)

        # np.asarray on a palette image yields colour-table indices; stretching
        # those as intensities would show meaningless grey levels.
        if mode in ("P", "PA"):
            return image.convert("RGB")

        arr = np.asarray(image)

        if arr.ndim == 3 and arr.shape[2] == 3:
            return image if mode == "RGB" else image.convert("RGB")

        if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1):
            if arr.ndim == 3:
                arr = arr[..., 0]
            # Non-finite values (float images) would poison the percentiles.
            a = np.nan_to_num(arr.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)

            # Robust 1-99 % window so a few outlier pixels don't wash out the image.
            vmin, vmax = np.percentile(a, [1.0, 99.0])
            if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax <= vmin:
                vmin, vmax = float(a.min()), float(a.max())

            denom = max(vmax - vmin, 1e-6)
            vis = np.clip((a - vmin) / denom, 0.0, 1.0) * 255.0
            # A 2-D uint8 array maps to mode "L" automatically; the `mode=`
            # argument of Image.fromarray is deprecated in recent Pillow.
            return Image.fromarray(np.rint(vis).astype(np.uint8))

        return image.convert("RGB")
    except Exception:
        # Best-effort preview: show the upload unchanged rather than break the UI.
        return image
|
|
|
|
|
def _load_pil(path_str):
    """Load an image file as an in-memory RGB PIL image.

    Returns ``None`` when *path_str* is empty/``None`` or the file cannot be
    opened or decoded. This is best-effort: a missing artifact should not fail
    the whole request.
    """
    if not path_str:
        return None
    try:
        # convert() returns a new, fully loaded image that no longer depends on
        # the file, so the handle can be closed (even on error) by ``with``.
        with Image.open(path_str) as im:
            return im.convert("RGB")
    except Exception:
        return None


def process(image):
    """Run the analysis pipeline on *image* and collect its display artifacts.

    Returns a 5-tuple matching the click handler's outputs:
    ``(composite, overlay, mask, gallery_items, stats_text)``. The first three
    are RGB PIL images, or ``None`` if the artifact is missing or unreadable.
    ``gallery_items`` lists the available NDVI, ARI and GNDVI images, in that
    order. ``stats_text`` is the pipeline's 'StatsText' entry (default ``""``).
    """
    if image is None:
        # Must match the number of wired output components (5), or Gradio errors.
        return None, None, None, [], ""

    with tempfile.TemporaryDirectory() as tmpdir:
        # Keep the upload's original container when known; otherwise use PNG.
        ext = image.format.lower() if image.format else "png"
        img_path = Path(tmpdir) / f"input.{ext}"
        image.save(img_path)

        outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)

        # Artifacts are written under tmpdir, which is deleted when this block
        # exits, so every file must be read into memory before leaving it.
        composite = _load_pil(outputs.get("Composite"))
        overlay = _load_pil(outputs.get("Overlay"))
        mask = _load_pil(outputs.get("Mask"))

        index_keys = ("NDVI", "ARI", "GNDVI")
        loaded = (_load_pil(outputs[key]) for key in index_keys if key in outputs)
        # Drop unreadable artifacts; the Gallery expects only real images.
        gallery_items = [img for img in loaded if img is not None]

        stats_text = outputs.get("StatsText", "")

    return composite, overlay, mask, gallery_items, stats_text
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Sorghum Plant Analysis Demo")
    gr.Markdown("Upload a sorghum plant image to analyze vegetation indices, segmentation overlay, and stats.")

    with gr.Row():
        with gr.Column():
            # NOTE(review): gr.Image converts uploads to its `image_mode`
            # (documented default "RGB"), which would stop show_preview's
            # 16-bit / single-channel path from ever triggering. Consider
            # image_mode=None if raw bit depth must reach the pipeline; confirm
            # against the Gradio version in use.
            inp = gr.Image(type="pil", label="Upload Image")
            run = gr.Button("Run Pipeline", variant="primary")
        with gr.Column():
            preview = gr.Image(type="pil", label="Uploaded Image Preview", interactive=False)

    with gr.Row():
        composite_img = gr.Image(type="pil", label="Composite (Segmentation Input)", interactive=False)
        overlay_img = gr.Image(type="pil", label="Segmentation Overlay", interactive=False)
        mask_img = gr.Image(type="pil", label="Mask", interactive=False)

    gallery = gr.Gallery(label="Vegetation Indices", columns=3, height="auto")
    stats = gr.Textbox(label="Statistics", lines=4)

    # Refresh the preview whenever the input value changes; the (slower)
    # pipeline only runs on an explicit click.
    inp.change(fn=show_preview, inputs=inp, outputs=preview)
    # Output order must match the 5-tuple returned by process().
    run.click(process, inputs=inp, outputs=[composite_img, overlay_img, mask_img, gallery, stats])


if __name__ == "__main__":
    demo.launch()