Ayoub committed on
Commit
75109a1
·
1 Parent(s): ee9ae36
Files changed (1) hide show
  1. app copy.py +0 -91
app copy.py DELETED
@@ -1,91 +0,0 @@
1
import gradio as gr
import torch
import numpy as np
from hydra.utils import instantiate
from omegaconf import OmegaConf
from patchify import patchify, unpatchify
from PIL import Image
from skimage import io
import os

from src.unet import UNet


# --- Load Hydra config and model ---
model_path = "model"
cfg = OmegaConf.load(os.path.join(model_path, "config.yaml"))
# Merge the Hydra run-time config on top of the base config (later keys win).
cfg = OmegaConf.merge(cfg, OmegaConf.load(os.path.join(model_path, ".hydra/config.yaml")))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BUG FIX: the original line read `model = UNet*()`, which is a SyntaxError.
# NOTE(review): constructing with no arguments; if UNet requires config
# kwargs, use `instantiate(cfg.model)` instead — confirm against src/unet.py.
model = UNet()
# Load trained weights onto whichever device is available.
model.load_state_dict(torch.load(os.path.join(model_path, "model.pt"), map_location=device))
model.to(device)
model.eval()
24
-
25
-
26
def load_image(img_file, is_dem=False):
    """Read an uploaded raster file and normalise its channel layout.

    Parameters
    ----------
    img_file : file-like object with a ``.name`` attribute (gradio upload).
    is_dem : bool
        True  -> return a single-channel 2-D array (DEM height map);
        False -> return a 3-channel RGB array.

    Raises
    ------
    ValueError
        If the file extension is not one of the supported raster formats.
    """
    extension = os.path.splitext(img_file.name)[-1].lower()
    if extension not in (".png", ".jpg", ".jpeg", ".tif", ".tiff"):
        raise ValueError(f"Unsupported file format: {extension}")
    data = io.imread(img_file.name)

    if is_dem:
        # DEM must be single-band: keep only the first band if multi-channel.
        return data[:, :, 0] if data.ndim == 3 else data

    if data.ndim == 2:
        # Grayscale -> replicate into three identical channels.
        data = np.stack([data] * 3, axis=-1)
    elif data.shape[2] > 3:
        # Strip alpha / extra bands down to RGB.
        data = data[:, :, :3]
    return data
46
-
47
-
48
def predict(image_file, dem_file):
    """Run tiled inference on an uploaded RGB image plus a DEM raster.

    The RGB image and the DEM are stacked into a 4-channel input, cut into
    non-overlapping patches, pushed through the model one patch at a time,
    and the per-patch predictions are stitched back into a full-resolution
    grayscale segmentation map.

    Parameters
    ----------
    image_file : gradio file wrapper for the RGB image (.png/.tif).
    dem_file : gradio file wrapper for the DEM (.tif).

    Returns
    -------
    PIL.Image.Image
        8-bit grayscale segmentation map, same height/width as the input.
    """
    img = load_image(image_file, is_dem=False)
    dem = load_image(dem_file, is_dem=True)

    # Append the DEM as a 4th channel behind the RGB bands.
    img = np.concatenate((img[:, :, :3], np.expand_dims(dem, 2)), axis=2)

    patch_shape = cfg.dataset.shape
    # BUG FIX: the step was hard-coded to 256; it must equal the patch size,
    # otherwise unpatchify's tiling arithmetic fails whenever
    # cfg.dataset.shape != 256. With shape == 256 behavior is unchanged.
    patches = patchify(img, (patch_shape, patch_shape, cfg.in_channels), step=patch_shape)

    pred_patches = []
    for i in range(patches.shape[0]):
        for j in range(patches.shape[1]):
            single_patch = torch.Tensor(np.array(patches[i, j, :, :, :, :]))
            # NCHW layout, scaled from uint8 to [0, 1].
            single_patch = single_patch.permute(0, 3, 1, 2) / 255.

            with torch.no_grad():
                patch_pred = model(single_patch.to(device))

            pred_patches.append(patch_pred.cpu())

    pred = np.array(pred_patches)
    pred = np.reshape(pred, (patches.shape[0], patches.shape[1], 1, patch_shape, patch_shape, 1))
    pred = unpatchify(pred, (img.shape[0], img.shape[1], 1))

    # Model output is assumed to lie in [0, 1] (TODO confirm: sigmoid head);
    # scale to 8-bit grayscale for display.
    pred = np.uint8(pred.reshape(img.shape[0], img.shape[1]) * 255)
    return Image.fromarray(pred)
76
-
77
-
78
# --- Gradio UI ---
# Declare the components up front, then wire them into the Interface.
_rgb_input = gr.File(type="file", label="Input RGB Image (.png or .tif)")
_dem_input = gr.File(type="file", label="DEM Image (.tif)")
_mask_output = gr.Image(type="pil", label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=[_rgb_input, _dem_input],
    outputs=_mask_output,
    title="Fractex2D Segmentation",
    description="Upload an RGB image (png or tif) and a DEM image (tif). The model outputs a segmentation map."
)

if __name__ == "__main__":
    # Start the local Gradio server only when executed as a script.
    demo.launch()