Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from patchify import patchify, unpatchify | |
| from phasepack import phasecong | |
| from PIL import Image | |
| from segmentation_models_pytorch import Segformer | |
| from skimage import color, io | |
| from skimage.feature import canny | |
| from skimage.filters import sato | |
| from src.unet import UNet | |
| from src.train import eval_single | |
| from src.dataset_benchm import expand_wide_fractures_gt, dilate_labels | |
| # ------------------------------------------------------------ | |
| # Device | |
| # ------------------------------------------------------------ | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| # ------------------------------------------------------------ | |
| # Canny edge detection | |
| # ------------------------------------------------------------ | |
def canny_fn(img, sigma, lt, ht):
    """Apply Canny edge detection using skimage.

    Args:
        img: Image array, either 2-D (grayscale) or 3-D (passed through
            ``rgb2gray``); ``None`` when no image is loaded.
        sigma: Standard deviation of the Gaussian smoothing.
        lt: Low hysteresis threshold.
        ht: High hysteresis threshold.

    Returns:
        A ``uint8`` array in which edge pixels are 0 and all other pixels
        are 255, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None
    if img.ndim == 3:
        gray = color.rgb2gray(img)
    else:
        gray = img.astype(float) / 255.
    # The two threshold sliders move independently, so the "low" value can
    # end up above the "high" one. skimage rejects that ordering with a
    # ValueError, so put the pair in ascending order before the call.
    low, high = sorted((lt, ht))
    edges = canny(
        gray,
        sigma=sigma,
        low_threshold=low,
        high_threshold=high,
    )
    # Invert so that edges are drawn dark on a white background.
    return (255 - edges * 255).astype(np.uint8)
| # ------------------------------------------------------------ | |
| # Phase Congruency (phasepack) | |
| # ------------------------------------------------------------ | |
def phase_congruency_fn(
    x,
    img,
    nscale,
    norient,
    minWaveLength,
    mult,
    sigmaOnf,
    k,
    cutOff,
    g,
    noiseMethod,
):
    """Threshold the phase-congruency response of an image.

    ``x`` is the cut-off applied to the first array returned by
    ``phasecong``; every other argument after ``img`` is forwarded to
    ``phasecong`` under the same keyword name.

    Returns a ``uint8`` array that is 255 where the response is below ``x``
    and 0 elsewhere, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # 3-D input goes through rgb2gray; anything else is scaled by 1/255.
    gray = color.rgb2gray(img) if img.ndim == 3 else img.astype(float) / 255.

    settings = dict(
        nscale=nscale,
        norient=norient,
        minWaveLength=minWaveLength,
        mult=mult,
        sigmaOnf=sigmaOnf,
        k=k,
        cutOff=cutOff,
        g=g,
        noiseMethod=noiseMethod,
    )
    # Only the first of the seven returned values is used.
    response, _m, _ori, _ft, _pc_all, _eo, _t = phasecong(gray, **settings)

    # Threshold using slider
    below_cutoff = response < x
    return (below_cutoff * 255).astype(np.uint8)
| # ------------------------------------------------------------ | |
| # Sato vesselness-like filter | |
| # ------------------------------------------------------------ | |
# Selectable sigma sets for the Sato filter; sato_fn picks one by list index.
sato_sigmas_list = [range(1, 5), range(1, 20, 4), (2,), (1,)]
def sato_fn(img, x, sigmas):
    """Sato ridge detection over selected sigma set.

    Args:
        img: Image array, either 2-D (grayscale) or 3-D (passed through
            ``rgb2gray``); ``None`` when no image is loaded.
        x: Threshold; pixels whose filter response is below ``x`` become 1.
        sigmas: Index into the module-level ``sato_sigmas_list``.

    Returns:
        A ``float64`` array of 0.0 / 1.0 values, or ``None`` when ``img`` is
        ``None``.
    """
    if img is None:
        return None
    # Same grayscale handling as canny_fn / phase_congruency_fn: only call
    # rgb2gray when a channel axis is present, so 2-D input is accepted too.
    if img.ndim == 3:
        gray = color.rgb2gray(img)
    else:
        gray = img.astype(float) / 255.
    return np.float64(sato(gray, sato_sigmas_list[sigmas]) < x)
| # ------------------------------------------------------------ | |
| # Compute metrics | |
| # ------------------------------------------------------------ | |
def compute_metrics_ui(gt_img, pred_img, threshold):
    """Score a prediction against a ground-truth image for the metrics tab.

    Both images are cast to ``uint8`` and, when they carry a channel axis,
    reduced to their first channel. The ground truth is passed through
    ``dilate_labels`` before ``eval_single`` is called with ``threshold``
    rescaled to an integer on the 0-255 range.

    Returns a one-row ``DataFrame`` of the metrics rounded to three
    decimals, or ``None`` when either image is missing.
    """
    if gt_img is None or pred_img is None:
        return None

    # Cast to 8-bit and keep a single plane per image.
    planes = []
    for image in (gt_img, pred_img):
        arr = np.array(image, dtype=np.uint8)
        planes.append(arr[..., 0] if arr.ndim == 3 else arr)
    truth, prediction = planes

    truth = dilate_labels(truth)
    cutoff = int(threshold*255)
    scores = eval_single(truth, prediction, threshold=cutoff, device=device)
    return pd.DataFrame([scores]).round(3)
| # ------------------------------------------------------------ | |
| # Deep learning model loading | |
| # ------------------------------------------------------------ | |
def load_model(model_name: str):
    """Build the requested segmentation network and load its trained weights.

    Args:
        model_name: ``"unet"`` or ``"segformer"`` (case-insensitive).

    Returns:
        The model with weights loaded from ``model/unet.pt`` or
        ``model/segformer.pt``, moved to the module-level ``device`` and
        switched to evaluation mode.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    name = model_name.lower()
    if name == "unet":
        model = UNet(init_features=64)
        weight_path = "model/unet.pt"
    elif name == "segformer":
        model = Segformer(
            encoder_name='resnet34',
            encoder_depth=5,
            # The checkpoint loaded below (strict load_state_dict) replaces
            # every parameter, so fetching ImageNet encoder weights first
            # would only add a needless download.
            encoder_weights=None,
            decoder_segmentation_channels=256,
            in_channels=4,
            classes=1,
            activation='sigmoid'
        )
        weight_path = "model/segformer.pt"
    else:
        raise ValueError(f"Unknown model: {model_name}")
    # Deserialize on the CPU first so loading works without a GPU; the model
    # is moved to the target device afterwards.
    model.load_state_dict(
        torch.load(weight_path, weights_only=True, map_location=torch.device('cpu'))
    )
    model.to(device)
    model.eval()
    return model
| # ------------------------------------------------------------ | |
| # Inference on RGB + DEM pair | |
| # ------------------------------------------------------------ | |
def run_inference(img_path, dem_path, model_name):
    """Run patch-based inference for fracture segmentation.

    Args:
        img_path: Path of the image file (grayscale, RGB or RGBA).
        dem_path: Path of the DEM raster; it must be single-band and have
            the same height and width as the image.
        model_name: Model name understood by ``load_model``.

    Returns:
        A ``(rgb, prediction)`` pair of PIL images: the three-channel input
        and the inverted 8-bit prediction map, cropped back to the input's
        height and width.

    Raises:
        ValueError: If either path is missing, or if the DEM does not match
            the image resolution.
    """
    # Fail early with a readable message instead of an opaque reader error.
    if not img_path or not dem_path:
        raise ValueError("Both an RGB image and a DEM file are required.")

    img = io.imread(img_path)
    dem = io.imread(dem_path)

    # Ensure RGB format: replicate a grayscale plane, drop any alpha channel.
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    img = img[:, :, :3]

    if dem.shape != img.shape[:2]:
        raise ValueError(
            f"DEM shape {dem.shape} does not match image resolution "
            f"{img.shape[:2]}."
        )

    # Load the weights only once the inputs are known to be usable.
    model = load_model(model_name)

    # Merge RGB + DEM into a single 4-channel array.
    combined = np.concatenate((img, np.expand_dims(dem, 2)), 2)
    patch_shape = 256
    h, w, c = combined.shape

    # Zero-pad bottom/right so both sides are multiples of the patch size.
    pad_h = (patch_shape - h % patch_shape) % patch_shape
    pad_w = (patch_shape - w % patch_shape) % patch_shape
    combined_padded = np.pad(
        combined,
        ((0, pad_h), (0, pad_w), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # Non-overlapping tiles, indexed as [row, col, 0, y, x, channel].
    patches = patchify(
        combined_padded,
        (patch_shape, patch_shape, c),
        step=patch_shape,
    )
    n_rows, n_cols = patches.shape[:2]

    pred_patches = []
    with torch.no_grad():
        for i in range(n_rows):
            for j in range(n_cols):
                single_patch = torch.Tensor(np.array(patches[i, j]))
                # Channels-last -> channels-first, scaled by 1/255.
                single_patch = single_patch.permute(0, 3, 1, 2) / 255.
                patch_pred = model(single_patch.to(device))
                # Convert to NumPy here rather than letting np.array()
                # walk a list of tensors afterwards.
                pred_patches.append(patch_pred.cpu().numpy())

    # Reshape back to full image
    pred = np.reshape(
        np.array(pred_patches),
        (n_rows, n_cols, 1, patch_shape, patch_shape, 1),
    )
    pred = unpatchify(pred, combined_padded.shape[:2] + (1,))
    pred = pred[:h, :w, :]  # discard the padding
    pred = (255 - pred * 255).astype(np.uint8)
    return Image.fromarray(img), Image.fromarray(pred.reshape(h, w))
| # ------------------------------------------------------------ | |
| # User Interface | |
| # ------------------------------------------------------------ | |
# Build the Gradio interface: one tab per method, each wiring its input
# components to the matching processing function defined earlier in this file.
# NOTE(review): confirm the row/column nesting below matches the intended layout.
with gr.Blocks(title="Fractex2D Segmentation") as demo:
    gr.Markdown("# **Fractex2D – Fracture Detection**")
    gr.Markdown(
        """
        Try out deep models that use RGB+DEM inputs along with classic vision methods that work on RGB images.
        Support for RGB-only deep models is on the way.
        """
    )
    with gr.Row():
        # ------------------------------------------------------------
        # TAB 1 — DEEP LEARNING
        # ------------------------------------------------------------
        with gr.Tab("DEEP LEARNING"):
            gr.Markdown(
                """
                ## Deep Learning Segmentation
                Patch-based fracture segmentation using **UNet** or **SegFormer** trained on [FraXet]() dataset.
                **Requirements before running:**
                - RGB image: `.png` or `.tif`
                - DEM: `.tif`
                - Both must have **same resolution**
                The model processes the RGB + DEM pair in 256×256 patches internally to produce a binary fracture map, while still allowing you to **input images of any size**.
                """
            )
            # File inputs hand run_inference a path on disk (type="filepath").
            with gr.Row():
                rgb_input = gr.File(type="filepath", label="RGB image (.png/.tif)")
                dem_input = gr.File(type="filepath", label="DEM (.tif)")
                model_choice = gr.Dropdown(
                    choices=["unet", "segformer"],
                    value="segformer",
                    label="Model",
                )
            with gr.Row():
                with gr.Column(scale=1):  # empty column to push btn to center
                    pass
                run_btn = gr.Button("Run Segmentation", elem_id="run-button")
                with gr.Column(scale=1):  # empty column to balance
                    pass
            with gr.Row():
                rgb_show = gr.Image(type="pil", label="Input RGB")
                pred_show = gr.Image(type="pil", label="Prediction")
            # Each example row fills (RGB path, DEM path, model name).
            gr.Examples(
                examples=[
                    ["examples/kl5-s3_1.png", "examples/kl5-s3-dem_1.tif", "unet"],
                    ["examples/kl5-s3_1.png", "examples/kl5-s3-dem_1.tif", "segformer"],
                ],
                inputs=[rgb_input, dem_input, model_choice],
            )
            # run_inference returns (input image, prediction), shown side by side.
            run_btn.click(
                fn=run_inference,
                inputs=[rgb_input, dem_input, model_choice],
                outputs=[rgb_show, pred_show],
            )
        # ------------------------------------------------------------
        # TAB 2 — SATO FILTER
        # ------------------------------------------------------------
        with gr.Tab("Sato"):
            gr.Markdown(
                """
                ## Sato Ridge Detection
                Vesselness-inspired filter (scikit-image) useful for enhancing elongated structures https://doi.org/10.1016/S1361-8415(98)80009-1.
                Adjust threshold and σ-sets to explore different ridge responses.
                """
            )
            with gr.Row():
                with gr.Column(scale=1):
                    sato_in = gr.Image(value="examples/kl5-s3_1.png")
                    sato_x = gr.Slider(0, 1, value=0.08, step=0.01, label="Threshold")
                    # Radio values are indices into sato_sigmas_list.
                    sato_sigmas = gr.Radio(
                        [('range(1,5)', 0),
                         ('range(1,20,4)', 1),
                         ('(2,)', 2),
                         ('(1,)', 3)],
                        label="Sigma set",
                        value=0,
                    )
                with gr.Column(scale=1):
                    sato_out = gr.Image()
            # Auto update
            sato_in.change(sato_fn, [sato_in, sato_x, sato_sigmas], sato_out)
            sato_x.change(sato_fn, [sato_in, sato_x, sato_sigmas], sato_out)
            sato_sigmas.change(sato_fn, [sato_in, sato_x, sato_sigmas], sato_out)
        # ------------------------------------------------------------
        # TAB 3 — CANNY
        # ------------------------------------------------------------
        with gr.Tab("Canny"):
            gr.Markdown(
                """
                ## Canny edge detection
                Canny edge detection (scikit-image) with normalised thresholds https://doi.org/10.1109/TPAMI.1986.4767851.
                - **sigma** controls Gaussian smoothing
                - **lt / ht** are low/high thresholds in the range 0–1
                """
            )
            with gr.Row():
                with gr.Column(scale=1):
                    canny_in = gr.Image(value="examples/kl5-s3_1.png")
                    canny_sigma = gr.Slider(
                        0, 7,
                        value=1.37,
                        step=0.01,
                        label="Sigma"
                    )
                    canny_lt = gr.Slider(
                        0, 1,
                        value=0.37,
                        step=0.01,
                        label="Low threshold"
                    )
                    canny_ht = gr.Slider(
                        0, 1,
                        value=0.58,
                        step=0.01,
                        label="High threshold"
                    )
                with gr.Column(scale=1):
                    canny_out = gr.Image()
            # Re-run canny_fn whenever the image or any slider changes.
            canny_in.change(canny_fn, [canny_in, canny_sigma, canny_lt, canny_ht], canny_out)
            canny_sigma.change(canny_fn, [canny_in, canny_sigma, canny_lt, canny_ht], canny_out)
            canny_lt.change(canny_fn, [canny_in, canny_sigma, canny_lt, canny_ht], canny_out)
            canny_ht.change(canny_fn, [canny_in, canny_sigma, canny_lt, canny_ht], canny_out)
        # ------------------------------------------------------------
        # TAB 4 — PHASE CONGRUENCY
        # ------------------------------------------------------------
        with gr.Tab("Phase Congruency"):
            gr.Markdown(
                """
                ## Phase Congruency
                Edge/line detection ([phasepack](https://github.com/alimuldal/phasepack)) based on phase agreement in the frequency domain https://doi.org/10.1007/s004260000024.
                Computationally expensive → runs **only on button click**.
                Useful for illumination-invariant structural extraction.
                """
            )
            with gr.Row():
                pc_in = gr.Image(value="examples/kl5-s3_1.png")
                pc_out = gr.Image()
            with gr.Row():
                with gr.Column(scale=1):  # empty column to push btn to center
                    pass
                pc_btn = gr.Button("Detect")
                with gr.Column(scale=1):  # empty column to balance
                    pass
            # Slider labels match the phasecong keyword they feed.
            with gr.Row():
                with gr.Column(scale=1):
                    x_pc = gr.Slider(0, 1, value=0.15, step=0.01, label="Threshold")
                    pc_nscale = gr.Slider(3, 10, value=6, step=1, label="nscale")
                    pc_norient = gr.Slider(1, 16, value=8, step=1, label="norient")
                    pc_minWL = gr.Slider(1, 10, value=4, step=1, label="minWaveLength")
                    pc_mult = gr.Slider(1.0, 5.0, value=2.1, step=0.1, label="mult")
                with gr.Column(scale=1):
                    pc_sigma = gr.Slider(0.1, 1.0, value=0.35, step=0.05, label="sigmaOnf")
                    pc_k = gr.Slider(0.1, 10.0, value=2.8, step=0.1, label="k")
                    pc_cutoff = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="cutOff")
                    pc_g = gr.Slider(0.1, 50.0, value=10.6, step=0.5, label="g")
                    pc_noise = gr.Slider(-2, 2, value=-1, step=1, label="noiseMethod")
            # No .change handlers here: only the button triggers the computation.
            # The input order must match phase_congruency_fn's positional
            # parameters (threshold first, then the image).
            pc_btn.click(
                fn=phase_congruency_fn,
                inputs=[
                    x_pc, pc_in, pc_nscale, pc_norient, pc_minWL, pc_mult,
                    pc_sigma, pc_k, pc_cutoff, pc_g, pc_noise
                ],
                outputs=pc_out,
            )
        # ------------------------------------------------------------
        # TAB 5 — METRICS
        # ------------------------------------------------------------
        with gr.Tab("Metrics computation"):
            gr.Markdown(
                """
                ## Segmentation Metrics
                Compute quantitative metrics between a **prediction** and a **ground-truth** (1px wide annotation).
                Both images must be aligned and have the same resolution.
                """
            )
            with gr.Row():
                gt_input = gr.Image(label="Ground truth", type="numpy")
                pred_input = gr.Image(label="Prediction", type="numpy")
            with gr.Row():
                thresh = gr.Slider(
                    0, 1,
                    value=0.1,
                    step=0.01,
                    label="Binarisation threshold"
                )
            # Empty columns on both sides keep the button centred.
            with gr.Row():
                with gr.Column(scale=1):
                    pass
                metric_btn = gr.Button("Compute metrics")
                with gr.Column(scale=1):
                    pass
            # NOTE(review): headers presumably mirror the keys returned by
            # eval_single — confirm against src.train.
            metric_table = gr.Dataframe(
                headers=[
                    "mse", "psnr", "ssim", "ae",
                    "acc", "prec", "rec", "spec",
                    "f1", "dice", "iou", "ck", "roc_auc"
                ],
                label="Metrics (single image pair)"
            )
            metric_btn.click(
                fn=compute_metrics_ui,
                inputs=[gt_input, pred_input, thresh],
                outputs=metric_table,
            )
            gr.Examples(
                examples=[
                    ["examples/kl5-s3_1-gt.png", "examples/unet-p1_pred_kl5-s3_1.png", 0.1],
                ],
                inputs=[gt_input, pred_input, thresh],
            )
    # ------------------------------------------------------------
    # Extra reference
    # ------------------------------------------------------------
    gr.Markdown(
        """
        The sample images included with this interface originate from:
        Nordbäck, N., & Ovaskainen, N. (2022). UAV-acquired orthomosaics of \
        Loviisa shoreline outcrops (Version 1.0.0) [Dataset]. Zenodo. \
        https://doi.org/10.5281/zenodo.7077519
        """
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()