File size: 7,507 Bytes
49e75a7
76c7b4d
 
 
 
 
 
 
 
3a08f1d
 
b3b2242
49e75a7
1b6735c
 
 
 
673aea5
9021e76
 
0c8976a
9021e76
be207af
0c8976a
1b6735c
 
 
 
38dfa48
3a08f1d
 
 
 
49e75a7
9021e76
101609b
2d58642
101609b
9021e76
3a08f1d
101609b
9021e76
 
49e75a7
 
 
 
 
 
 
 
 
9021e76
18df9a4
9021e76
76c7b4d
9021e76
49e75a7
203d933
9021e76
76c7b4d
9021e76
 
49e75a7
673aea5
 
 
4298555
 
be207af
9021e76
 
203d933
18df9a4
b3b2242
 
 
9021e76
 
49e75a7
 
9021e76
18df9a4
 
 
49e75a7
 
 
 
 
673aea5
 
 
 
 
 
 
49e75a7
9021e76
673aea5
9021e76
2c6124d
673aea5
 
 
 
49e75a7
9021e76
76c7b4d
4298555
 
 
592792e
49e75a7
76c7b4d
b3b2242
9021e76
18df9a4
76c7b4d
673aea5
 
4298555
673aea5
4298555
673aea5
 
 
9021e76
b470ccf
4298555
9021e76
18df9a4
 
9021e76
 
 
 
101609b
 
 
 
 
 
 
 
 
 
2d58642
101609b
 
2d58642
101609b
2d58642
101609b
3a08f1d
 
 
 
 
 
 
 
 
 
 
 
 
49e75a7
9021e76
3a08f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824527d
3a08f1d
 
 
 
824527d
3a08f1d
 
 
824527d
 
3a08f1d
 
824527d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# environment setup
import os
# Ensure OMP_NUM_THREADS holds a positive integer before the numeric libraries
# below (torch, matplotlib, ...) are imported; anything else -- unset, empty,
# non-numeric, zero or negative -- is replaced with "1".
# NOTE(review): presumably guards against OpenMP runtimes rejecting an invalid
# value injected by the hosting environment -- confirm.
val = os.environ.get("OMP_NUM_THREADS", "")
try:
    _omp_ok = int(val) > 0
except ValueError:  # int() on a str can only fail with ValueError
    _omp_ok = False
if not _omp_ok:
    os.environ["OMP_NUM_THREADS"] = "1"

# Default cache/config directories for matplotlib, Hugging Face and torch;
# setdefault leaves any value that is already exported untouched.
os.environ.setdefault("MPLCONFIGDIR", "/var/tmp/matplotlib")
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TORCH_HOME", "/data/.cache/torch")

import tempfile, shutil, glob
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from affine import Affine
import rasterio
from rasterio.plot import show
import torch
import gradio as gr

from detectron2.engine import DefaultPredictor
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectree2.models.predict import predict_on_data
from detectree2.models.train import setup_cfg

# config
# Checkpoint directory; override with the DTR2_WEIGHT_DIR environment variable.
WEIGHT_DIR = os.getenv("DTR2_WEIGHT_DIR", "/data/weights")
try:
    os.makedirs(WEIGHT_DIR, exist_ok=True)
except OSError as exc:
    # Don't abort at import time when the directory cannot be created (e.g. no
    # writable /data volume): the UI still starts, and missing checkpoints show
    # up as errors when a prediction or upload is attempted.
    import warnings
    warnings.warn(f"Could not create weight directory {WEIGHT_DIR!r}: {exc}")

# One checkpoint per "Environment" choice; the UI dropdown lists these keys.
WEIGHTS = {
    "Flexible": os.path.join(WEIGHT_DIR, "250312_flexi.pth"),
    "Forest": os.path.join(WEIGHT_DIR, "250711_tropical_closed_canopy.pth"),
    "Urban": os.path.join(WEIGHT_DIR, "urban_trees_Cambridge_20230630.pth"),
}

DEFAULT_ENV = "Flexible"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keep PyTorch from oversubscribing CPU
try:
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
except RuntimeError:
    # set_num_interop_threads may only be called once, before inter-op parallel
    # work has started; torch raises RuntimeError otherwise. Best effort only.
    pass

# Tiling defaults (positional buffer / tile width / tile height for tile_data)
BUFFER, TILE_W, TILE_H = 30, 40, 40

def _find_predictions_dir(candidates):
    """Return the first existing directory among *candidates*, or None."""
    return next((d for d in candidates if os.path.isdir(d)), None)


def _save_overlay(raster_path, crowns, out_path):
    """Draw *crowns* outlines over the raster at *raster_path* and save a PNG.

    The figure is always closed, even when plotting or saving raises, so
    repeated requests do not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        with rasterio.open(raster_path) as src:
            show(src, ax=ax)
            if not crowns.empty:
                crowns.plot(ax=ax, facecolor="none", edgecolor="cyan", linewidth=1.2)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        fig.savefig(out_path, dpi=220, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)


def run_detectree2(geotiff_path: str, environment: str, conf_threshold: float):
    """Run detectree2 crown delineation on one georeferenced GeoTIFF.

    The input is copied into a fresh temporary working directory, tiled with
    ``tile_data``, predicted tile-by-tile, projected to GeoJSON, stitched,
    cleaned with ``clean_crowns(..., 0.6, confidence=0)``, filtered by
    confidence and simplified (tolerance 0.4).

    Args:
        geotiff_path: Path to the input .tif/.tiff. It must have a CRS and a
            non-identity affine transform.
        environment: Key into ``WEIGHTS``; unknown or empty keys fall back to
            ``DEFAULT_ENV``.
        conf_threshold: Crowns whose ``Confidence_score`` is not strictly
            greater than this value are dropped. The model's own detection
            threshold comes from ``DTR2_SCORE_THRESH`` (default 0.25).

    Returns:
        ``(overlay_png_path, gpkg_path, {"total_trees": n}, "")`` on success,
        or ``(None, None, None, "Error: ...")`` on failure. On success the
        outputs live in the temporary working directory, which is kept so the
        UI can serve them; on failure that directory is removed.
    """
    if not geotiff_path:
        return None, None, None, "Error: no input file provided."

    workdir = tempfile.mkdtemp(prefix="dtr2_")
    tiles_dir = os.path.join(workdir, "tilespred")
    os.makedirs(tiles_dir, exist_ok=True)

    # chdir is process-wide, so remember where we were and restore it in
    # `finally`. NOTE(review): the chdir itself is kept on the assumption that
    # some detectree2 helper writes relative to the cwd -- confirm.
    prev_cwd = os.getcwd()
    succeeded = False
    os.chdir(workdir)

    try:
        infile = os.path.join(workdir, os.path.basename(geotiff_path))
        shutil.copy2(geotiff_path, infile)

        # Must be georeferenced
        with rasterio.open(infile) as src:
            if src.crs is None or src.transform == Affine.identity():
                return None, None, None, "Error: input GeoTIFF has no CRS/transform."

        tile_data(infile, tiles_dir, BUFFER, TILE_W, TILE_H)

        weights = WEIGHTS.get(environment) or WEIGHTS[DEFAULT_ENV]
        cfg = setup_cfg(update_model=weights)
        cfg.MODEL.DEVICE = DEVICE
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(os.getenv("DTR2_SCORE_THRESH", "0.25"))
        cfg.OUTPUT_DIR = os.path.join(workdir, "train_outputs")
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

        predictor = DefaultPredictor(cfg)

        with torch.inference_mode():
            predict_on_data(tiles_dir, predictor=predictor)

        # The predictions folder location is checked in several places.
        preds_dir = _find_predictions_dir((
            os.path.join(tiles_dir, "predictions"),
            os.path.join(cfg.OUTPUT_DIR, "predictions"),
            os.path.join(workdir, "predictions"),
        ))
        if preds_dir is None:
            return None, None, None, "Error: no predictions folder found."

        if not glob.glob(os.path.join(preds_dir, "*.json")):
            return None, None, None, (
                f"Error: no prediction JSONs in {preds_dir}. "
                f"Try lowering SCORE_THRESH_TEST (now {cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST})."
            )

        # project to geojson
        preds_geo_dir = os.path.join(tiles_dir, "predictions_geo")
        os.makedirs(preds_geo_dir, exist_ok=True)
        project_to_geojson(tiles_dir, preds_dir, preds_geo_dir)

        if not glob.glob(os.path.join(preds_geo_dir, "*.geojson")):
            return None, None, None, f"Error: No geojson files found in {preds_geo_dir}."

        # stitch + clean, then filter first and simplify only what is kept
        crowns = stitch_crowns(preds_geo_dir, 1)
        cleaned = clean_crowns(crowns, 0.6, confidence=0)
        kept = cleaned[cleaned["Confidence_score"] > float(conf_threshold)]
        kept = kept.set_geometry(kept.simplify(0.4))

        # save vector
        gpkg_path = os.path.join(workdir, "crowns_out.gpkg")
        kept.to_file(gpkg_path, driver="GPKG")

        # save overlay
        overlay_path = os.path.join(workdir, "crowns_out.png")
        _save_overlay(infile, kept, overlay_path)

        stats = {"total_trees": int(len(kept))}
        succeeded = True
        return overlay_path, gpkg_path, stats, ""

    except Exception as e:
        return None, None, None, f"Error: {e}"

    finally:
        try:
            os.chdir(prev_cwd)
        except OSError:
            pass  # original cwd vanished; nothing sensible to restore
        if not succeeded:
            # Nothing from a failed run is handed to the UI, so drop it.
            shutil.rmtree(workdir, ignore_errors=True)

def infer(file_path, environment, confidence):
    """Gradio click handler for the "Run prediction" button.

    Forwards the three UI inputs unchanged to ``run_detectree2`` and hands back
    its 4-tuple (overlay path, GPKG path, stats dict, message) as-is.
    """
    result = run_detectree2(file_path, environment, confidence)
    return result

def save_to_persistent(path):
    """Copy an uploaded checkpoint into ``WEIGHT_DIR`` and report the outcome.

    Args:
        path: Filesystem path of the uploaded file. Objects exposing the path
            as a ``.name`` attribute (older Gradio ``gr.File`` payloads) are
            accepted as well; ``None``/empty yields a failure message.

    Returns:
        A human-readable status string for the UI. Failures never raise; they
        are reported as ``"Upload failed:\\n"`` followed by the traceback.
    """
    import traceback

    path = getattr(path, "name", path)
    if not path:
        return "Upload failed: no file received."

    try:
        os.makedirs(WEIGHT_DIR, exist_ok=True)
        fname = os.path.basename(path)
        dest = os.path.join(WEIGHT_DIR, fname)

        # Check the filesystem that actually holds WEIGHT_DIR (configurable via
        # DTR2_WEIGHT_DIR) rather than a hard-coded /data mount.
        free = shutil.disk_usage(WEIGHT_DIR).free
        need = os.path.getsize(path)
        if need > free:
            return f"Not enough space: need {need/1024**2:.2f} MB, free {free/1024**2:.2f} MB."

        # Copy under a temporary sibling name, then rename into place, so an
        # interrupted copy never leaves a truncated checkpoint under `dest`.
        partial = dest + ".part"
        try:
            shutil.copy2(path, partial)
            os.replace(partial, dest)
        finally:
            if os.path.exists(partial):
                os.remove(partial)

        return f"Saved {fname} to {WEIGHT_DIR}"
    except Exception:
        return "Upload failed:\n" + traceback.format_exc()


def list_persistent(root_dir="/data"):
    """Return a newline-separated listing of every file under *root_dir*.

    Each line is ``<path>\\t<size> MB`` with the size in MiB to two decimals;
    files whose size cannot be read (e.g. broken symlinks, or files removed
    mid-walk) are listed by path alone. Lines are sorted lexicographically.
    Returns ``"(empty)"`` when nothing is found, including when *root_dir*
    does not exist (``os.walk`` yields nothing in that case).
    """
    entries = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for name in filenames:
            path = os.path.join(dirpath, name)
            try:
                size_mb = os.path.getsize(path) / (1024 ** 2)
            except OSError:
                entries.append(path)
            else:
                entries.append(f"{path}\t{size_mb:.2f} MB")
    return "\n".join(sorted(entries)) or "(empty)"

# gradio ui
# Labels interpolate WEIGHT_DIR so they stay accurate when the weight directory
# is overridden via DTR2_WEIGHT_DIR (identical text with the default).
with gr.Blocks(title="Detectree2 – Landscape Prediction") as demo:
    gr.Markdown(f"# Detectree2\nModels load from `{WEIGHT_DIR}` (persistent).")

    with gr.Tabs():
        with gr.Tab("Predict"):
            with gr.Row():
                inp_file = gr.File(label="Input (.tif / .tiff)", file_types=[".tif", ".tiff"], type="filepath")
                env_dd = gr.Dropdown(label="Environment", choices=list(WEIGHTS.keys()), value=DEFAULT_ENV)
                conf = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence threshold")
            run_btn = gr.Button("Run prediction", variant="primary")

            out_img = gr.Image(label="Overlay (PNG)", type="filepath")
            out_gpkg = gr.File(label="Crowns (GPKG)")
            out_stats = gr.JSON(label="Statistics")
            out_err = gr.Textbox(label="Logs / Errors", interactive=False)

            # One prediction at a time for this event.
            run_btn.click(
                infer,
                inputs=[inp_file, env_dd, conf],
                outputs=[out_img, out_gpkg, out_stats, out_err],
                concurrency_limit=1,
            )

        with gr.Tab(f"Manage Weights ({WEIGHT_DIR})"):
            # type="filepath" (as on the Predict tab) so save_to_persistent
            # always receives a path string.
            up = gr.File(label="Upload .pth", file_types=[".pth"], type="filepath")
            up_out = gr.Textbox(label="Upload status")
            up.upload(save_to_persistent, up, up_out)

            list_btn = gr.Button("List /data contents")
            list_out = gr.Textbox(label="Persistent files", lines=12)
            list_btn.click(list_persistent, None, list_out)

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    queued_app = demo.queue()
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        max_threads=1,
        show_api=False,
    )