# Author: Fahimeh Orvati Nia
# Commit 7c80781 ("update"); file size 3.92 kB.
import gradio as gr
import tempfile
from pathlib import Path
from wrapper import run_pipeline_on_image
import numpy as np
from PIL import Image
from itertools import product
def _stretch_to_uint8(arr):
    """Contrast-stretch a 2-D numeric array to uint8 for display.

    Non-finite values (NaN, +inf, -inf) are first replaced with 0. The
    1st-99th percentile range is then mapped linearly onto 0-255, with values
    outside that range clipped. If the percentile range is degenerate, the
    full min-max range is used instead.

    Args:
        arr: 2-D array of any real numeric (or bool) dtype.

    Returns:
        A uint8 array with the same shape as ``arr``.
    """
    a = np.nan_to_num(
        np.asarray(arr, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0
    )
    # Robust contrast stretch. ``a`` is all-finite here, so the percentiles
    # are finite as well.
    vmin, vmax = (float(v) for v in np.percentile(a, [1.0, 99.0]))
    if vmax <= vmin:
        vmin, vmax = float(a.min()), float(a.max())
    denom = max(vmax - vmin, 1e-6)  # avoid divide-by-zero on constant images
    return (np.clip((a - vmin) / denom, 0.0, 1.0) * 255.0).astype(np.uint8)


def show_preview(image):
    """Render an uploaded image for display, including 16-bit/single-channel inputs.

    - ``None`` -> ``None``.
    - RGB images are returned unchanged.
    - Single-band numeric images (e.g. modes "1", "L", "I", "I;16", "F") are
      contrast-stretched to 8-bit grayscale with a 1-99 percentile window
      (falling back to min-max).
    - Everything else (RGBA, CMYK, palette "P", "LA", YCbCr, ...) is converted
      to RGB. Palette images are deliberately kept out of the stretch path,
      because their pixel values are palette indices, not intensities.

    This is best-effort: if any conversion fails, the original image is
    returned unchanged.
    """
    if image is None:
        return None
    try:
        if image.mode == "RGB":
            return image
        if image.mode not in ("P", "PA"):
            arr = np.asarray(image)
            if arr.ndim == 3 and arr.shape[2] == 1:
                arr = arr[..., 0]
            if arr.ndim == 2:
                # A 2-D uint8 array is mapped to mode "L" automatically; the
                # ``mode=`` argument of fromarray is deprecated in Pillow.
                return Image.fromarray(_stretch_to_uint8(arr))
        # RGBA -> RGB (alpha dropped), and any other multi-band/palette mode.
        return image.convert("RGB")
    except Exception:
        # Deliberate best-effort: the preview must never break the upload flow.
        return image
def _save_input(image, tmpdir):
    """Save *image* into directory *tmpdir* (a Path) and return the written path.

    The image's original format is preferred, so that e.g. TIFF data reaches
    the pipeline as TIFF. If Pillow cannot write that format (it can read
    more formats than it can write), a lossless PNG is written instead.
    """
    fmt = image.format or "PNG"
    path = tmpdir / f"input.{fmt.lower()}"
    try:
        image.save(path, format=fmt)
    except (KeyError, OSError, ValueError):
        # KeyError: no writer for this format; OSError: mode unsupported by it.
        path = tmpdir / "input.png"
        image.save(path, format="PNG")
    return path


def _load_pil(path_str):
    """Load an image file fully into memory as RGB, or return None.

    Best-effort helper for display. It returns ``None`` when ``path_str`` is
    falsy or the file cannot be opened or decoded.
    """
    if not path_str:
        return None
    try:
        # convert() loads the pixel data into a new image, so the result stays
        # valid after the file is closed and the temp dir is removed.
        with Image.open(path_str) as im:
            return im.convert("RGB")
    except Exception:
        return None


def process(image):
    """Run the analysis pipeline on an uploaded image and collect display outputs.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing is uploaded.

    Returns:
        A 5-tuple ``(composite, overlay, mask, gallery_items, stats_text)``,
        in the same order as the outputs wired to the "Run Pipeline" button.
        Each image is an RGB PIL image, or ``None`` if it is missing or could
        not be loaded. ``gallery_items`` lists the NDVI/ARI/GNDVI images that
        were produced and loaded successfully, in that order. ``stats_text``
        is the pipeline's ``StatsText`` entry, or ``""``.
    """
    if image is None:
        # One value per wired output component (5 in total). Returning fewer
        # makes Gradio reject the result.
        return None, None, None, [], ""
    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = _save_input(image, Path(tmpdir))
        outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)
        # Load everything while tmpdir (which holds the artifacts) still exists.
        composite = _load_pil(outputs.get("Composite"))
        overlay = _load_pil(outputs.get("Overlay"))
        mask = _load_pil(outputs.get("Mask"))
        loaded = (
            _load_pil(outputs[key])
            for key in ("NDVI", "ARI", "GNDVI")
            if key in outputs
        )
        # Skip images that failed to load rather than passing None to the gallery.
        gallery_items = [img for img in loaded if img is not None]
        stats_text = outputs.get("StatsText", "")
    return composite, overlay, mask, gallery_items, stats_text
with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Sorghum Plant Analysis Demo")
    gr.Markdown("Upload a sorghum plant image to analyze vegetation indices, segmentation overlay, and stats.")
    with gr.Row():
        with gr.Column():
            # NOTE(review): gr.Image may convert uploads to its `image_mode`
            # (believed to default to "RGB") before they reach show_preview /
            # process. That would bypass the 16-bit preview path and clear
            # `image.format`. Confirm against the installed Gradio version and
            # consider `image_mode=None`.
            inp = gr.Image(type="pil", label="Upload Image")
            run = gr.Button("Run Pipeline", variant="primary")
        with gr.Column():
            preview = gr.Image(type="pil", label="Uploaded Image Preview", interactive=False)
    with gr.Row():
        composite_img = gr.Image(type="pil", label="Composite (Segmentation Input)", interactive=False)
        overlay_img = gr.Image(type="pil", label="Segmentation Overlay", interactive=False)
        mask_img = gr.Image(type="pil", label="Mask", interactive=False)
    gallery = gr.Gallery(label="Vegetation Indices", columns=3, height="auto")
    stats = gr.Textbox(label="Statistics", lines=4)
    # Update preview when image is uploaded
    inp.change(fn=show_preview, inputs=inp, outputs=preview)
    # Output order must match the 5-tuple returned by process().
    run.click(process, inputs=inp, outputs=[composite_img, overlay_img, mask_img, gallery, stats])

if __name__ == "__main__":
    demo.launch()