Spaces:
Runtime error
Runtime error
File size: 7,507 Bytes
49e75a7 76c7b4d 3a08f1d b3b2242 49e75a7 1b6735c 673aea5 9021e76 0c8976a 9021e76 be207af 0c8976a 1b6735c 38dfa48 3a08f1d 49e75a7 9021e76 101609b 2d58642 101609b 9021e76 3a08f1d 101609b 9021e76 49e75a7 9021e76 18df9a4 9021e76 76c7b4d 9021e76 49e75a7 203d933 9021e76 76c7b4d 9021e76 49e75a7 673aea5 4298555 be207af 9021e76 203d933 18df9a4 b3b2242 9021e76 49e75a7 9021e76 18df9a4 49e75a7 673aea5 49e75a7 9021e76 673aea5 9021e76 2c6124d 673aea5 49e75a7 9021e76 76c7b4d 4298555 592792e 49e75a7 76c7b4d b3b2242 9021e76 18df9a4 76c7b4d 673aea5 4298555 673aea5 4298555 673aea5 9021e76 b470ccf 4298555 9021e76 18df9a4 9021e76 101609b 2d58642 101609b 2d58642 101609b 2d58642 101609b 3a08f1d 49e75a7 9021e76 3a08f1d 824527d 3a08f1d 824527d 3a08f1d 824527d 3a08f1d 824527d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | # environment setup
import os
def _ensure_positive_int_env(name, default="1"):
    """Set environment variable *name* to *default* unless it holds a positive integer.

    A missing, empty, non-numeric, zero or negative value is not a usable
    thread count, so it is overwritten; a valid value is left untouched.

    Args:
        name: Environment variable to check.
        default: String value written when the current value is unusable.
    """
    raw = os.environ.get(name, "")
    try:
        valid = int(raw) > 0
    except ValueError:
        valid = False
    if not valid:
        os.environ[name] = default


# Sanitise the OpenMP thread count before the numeric libraries are imported.
_ensure_positive_int_env("OMP_NUM_THREADS")
# Default cache/config locations for matplotlib, Hugging Face and torch hub;
# only applied when the caller has not already set them.
os.environ.setdefault("MPLCONFIGDIR", "/var/tmp/matplotlib")
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TORCH_HOME", "/data/.cache/torch")
import tempfile, shutil, glob
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from affine import Affine
import rasterio
from rasterio.plot import show
import torch
import gradio as gr
from detectron2.engine import DefaultPredictor
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectree2.models.predict import predict_on_data
from detectree2.models.train import setup_cfg
# config
# Directory holding the model weight files; override with DTR2_WEIGHT_DIR.
WEIGHT_DIR = os.getenv("DTR2_WEIGHT_DIR", "/data/weights")
try:
    os.makedirs(WEIGHT_DIR, exist_ok=True)
except OSError as exc:
    # Do not abort the whole app at import time when the directory cannot be
    # created (e.g. no persistent /data volume); later prediction/upload calls
    # can then surface the problem as an error message instead.
    print(f"Warning: could not create weight directory {WEIGHT_DIR!r}: {exc}", flush=True)

# UI environment label -> weights file inside WEIGHT_DIR.
WEIGHTS = {
    "Flexible": os.path.join(WEIGHT_DIR, "250312_flexi.pth"),
    "Forest": os.path.join(WEIGHT_DIR, "250711_tropical_closed_canopy.pth"),
    "Urban": os.path.join(WEIGHT_DIR, "urban_trees_Cambridge_20230630.pth"),
}
# Label used when an unknown environment is requested.
DEFAULT_ENV = "Flexible"
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Limit PyTorch to one intra-op and one inter-op thread so it does not
# oversubscribe the CPU. This is best effort: any failure is deliberately
# ignored (set_num_interop_threads, for instance, refuses to run once
# inter-op parallel work has already started).
try:
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
except Exception:
    pass

# Tiling parameters passed positionally to tile_data(): buffer, tile width,
# tile height (units as interpreted by detectree2 -- presumably map units; confirm).
BUFFER = 30
TILE_W = 40
TILE_H = 40
def _find_predictions_dir(candidates):
    """Return the first path in *candidates* that is an existing directory, else None."""
    return next((d for d in candidates if os.path.isdir(d)), None)


def _prune_workdir(workdir, keep):
    """Delete every entry directly inside *workdir* whose full path is not in *keep*."""
    for entry in os.listdir(workdir):
        path = os.path.join(workdir, entry)
        if path in keep:
            continue
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup only


def _render_overlay(raster_path, crowns, out_path):
    """Plot *raster_path* with *crowns* outlined in cyan and save it as a PNG at *out_path*."""
    with rasterio.open(raster_path) as src:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            show(src, ax=ax)
            if not crowns.empty:
                crowns.plot(ax=ax, facecolor="none", edgecolor="cyan", linewidth=1.2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            fig.savefig(out_path, dpi=220, bbox_inches="tight", pad_inches=0)
        finally:
            # Release the figure even when plotting fails, so repeated
            # requests do not accumulate open matplotlib figures.
            plt.close(fig)


def _predict_crowns(workdir, infile, environment, conf_threshold):
    """Tile, predict, stitch, filter and export crowns for *infile* inside *workdir*.

    Returns the same 4-tuple as ``run_detectree2``; on an expected failure the
    first three items are None and the fourth is an "Error: ..." message.
    """
    weights = WEIGHTS.get(environment) or WEIGHTS[DEFAULT_ENV]
    # Cheap check before the expensive tiling step.
    if not os.path.isfile(weights):
        return None, None, None, (
            f"Error: model weights not found at {weights}. "
            "Upload the .pth file in the Manage Weights tab."
        )

    # Must be georeferenced.
    with rasterio.open(infile) as src:
        if src.crs is None or src.transform == Affine.identity():
            return None, None, None, "Error: input GeoTIFF has no CRS/transform."

    tiles_dir = os.path.join(workdir, "tilespred")
    os.makedirs(tiles_dir, exist_ok=True)
    tile_data(infile, tiles_dir, BUFFER, TILE_W, TILE_H)

    cfg = setup_cfg(update_model=weights)
    cfg.MODEL.DEVICE = DEVICE
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(os.getenv("DTR2_SCORE_THRESH", "0.25"))
    cfg.OUTPUT_DIR = os.path.join(workdir, "train_outputs")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    predictor = DefaultPredictor(cfg)
    with torch.inference_mode():
        predict_on_data(tiles_dir, predictor=predictor)

    # The predictions folder location is probed rather than assumed.
    preds_dir = _find_predictions_dir((
        os.path.join(tiles_dir, "predictions"),
        os.path.join(cfg.OUTPUT_DIR, "predictions"),
        os.path.join(workdir, "predictions"),
    ))
    if preds_dir is None:
        return None, None, None, "Error: no predictions folder found."
    if not glob.glob(os.path.join(preds_dir, "*.json")):
        return None, None, None, (
            f"Error: no prediction JSONs in {preds_dir}. "
            f"Try lowering SCORE_THRESH_TEST (now {cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST})."
        )

    # Project tile predictions to georeferenced GeoJSON.
    preds_geo_dir = os.path.join(tiles_dir, "predictions_geo")
    os.makedirs(preds_geo_dir, exist_ok=True)
    project_to_geojson(tiles_dir, preds_dir, preds_geo_dir)
    if not glob.glob(os.path.join(preds_geo_dir, "*.geojson")):
        return None, None, None, f"Error: No geojson files found in {preds_geo_dir}."

    # Stitch + clean, then keep crowns above the user threshold and simplify
    # only the rows that survive the filter.
    crowns = stitch_crowns(preds_geo_dir, 1)
    cleaned = clean_crowns(crowns, 0.6, confidence=0)
    kept = cleaned[cleaned["Confidence_score"] > float(conf_threshold)]
    kept = kept.set_geometry(kept.simplify(0.4))

    gpkg_path = os.path.join(workdir, "crowns_out.gpkg")
    kept.to_file(gpkg_path, driver="GPKG")

    overlay_path = os.path.join(workdir, "crowns_out.png")
    _render_overlay(infile, kept, overlay_path)

    return overlay_path, gpkg_path, {"total_trees": int(len(kept))}, ""


def run_detectree2(geotiff_path: str, environment: str, conf_threshold: float):
    """Run detectree2 crown detection on a georeferenced GeoTIFF.

    Args:
        geotiff_path: Path to the input .tif/.tiff (as supplied by the upload widget).
        environment: Key of ``WEIGHTS`` selecting the model; unknown keys fall
            back to ``DEFAULT_ENV``.
        conf_threshold: Crowns with ``Confidence_score`` at or below this value
            are dropped.

    Returns:
        ``(overlay_png_path, gpkg_path, stats_dict, message)``. On success the
        message is ``""``; on failure the first three items are None and the
        message starts with ``"Error:"``. Exceptions are converted into such a
        message rather than raised.
    """
    if not geotiff_path:
        return None, None, None, "Error: please upload a GeoTIFF (.tif / .tiff)."
    if not os.path.isfile(geotiff_path):
        return None, None, None, f"Error: input file not found: {geotiff_path}"

    workdir = tempfile.mkdtemp(prefix="dtr2_")
    prev_cwd = os.getcwd()
    succeeded = False
    try:
        # Work inside the scratch dir in case detectree2/detectron2 write to
        # relative paths; chdir is process-wide, so it is undone in `finally`.
        os.chdir(workdir)
        infile = os.path.join(workdir, os.path.basename(geotiff_path))
        shutil.copy2(geotiff_path, infile)
        result = _predict_crowns(workdir, infile, environment, conf_threshold)
        succeeded = result[0] is not None
        return result
    except Exception as e:
        return None, None, None, f"Error: {e}"
    finally:
        os.chdir(prev_cwd)
        if succeeded:
            # Keep only the returned PNG/GPKG; tiles, predictions and the input
            # copy would otherwise accumulate on disk with every request.
            _prune_workdir(workdir, keep=set(result[:2]))
        else:
            shutil.rmtree(workdir, ignore_errors=True)
def infer(file_path, environment, confidence):
    """Click handler for the Predict tab; delegates to ``run_detectree2``.

    The returned 4-tuple (overlay path, GPKG path, stats, message) feeds the
    four output components wired to the "Run prediction" button.
    """
    return run_detectree2(
        geotiff_path=file_path,
        environment=environment,
        conf_threshold=confidence,
    )
def save_to_persistent(path):
    """Copy an uploaded weights file into ``WEIGHT_DIR`` and return a status message.

    Args:
        path: Value delivered by the upload event -- normally a filepath
            string; an object exposing the path as ``.name`` (presumably a
            Gradio temp-file wrapper) is accepted too.

    Returns:
        A human-readable status string. Failures (including exceptions) are
        reported in the string rather than raised.
    """
    import traceback

    src = getattr(path, "name", path)
    if not src:
        return "No file received."
    try:
        os.makedirs(WEIGHT_DIR, exist_ok=True)
        fname = os.path.basename(src)
        dest = os.path.join(WEIGHT_DIR, fname)
        # Check the filesystem that actually holds WEIGHT_DIR (configurable via
        # DTR2_WEIGHT_DIR) rather than a hard-coded mount point.
        free = shutil.disk_usage(WEIGHT_DIR).free
        need = os.path.getsize(src)
        if need > free:
            return f"Not enough space: need {need/1024**2:.2f} MB, free {free/1024**2:.2f} MB."
        # Copy under a temporary name and rename into place, so an interrupted
        # copy never leaves a truncated .pth where the loader looks for weights.
        tmp = dest + ".part"
        try:
            shutil.copy2(src, tmp)
            os.replace(tmp, dest)
        except Exception:
            if os.path.exists(tmp):
                os.remove(tmp)
            raise
        return f"Saved {fname} to {WEIGHT_DIR}"
    except Exception:
        return "Upload failed:\n" + traceback.format_exc()
def list_persistent():
    """Return a sorted, newline-separated listing of every file under ``/data``.

    Each line reads ``"<path> — <size> MB"``; a file whose size cannot be read
    is listed by path alone. Returns ``"(empty)"`` when nothing is found
    (including when ``/data`` does not exist).
    """
    lines = []
    for dirpath, _subdirs, filenames in os.walk("/data"):
        for name in filenames:
            full = os.path.join(dirpath, name)
            try:
                mb = os.path.getsize(full) / (1024**2)
            except Exception:
                lines.append(full)
            else:
                lines.append(f"{full} — {mb:.2f} MB")
    lines.sort()
    return "\n".join(lines) or "(empty)"
# gradio ui
# Gradio UI: a Predict tab and a tab for uploading/inspecting weight files.
# Texts reference WEIGHT_DIR so they stay correct when DTR2_WEIGHT_DIR overrides
# the default location.
with gr.Blocks(title="Detectree2 – Landscape Prediction") as demo:
    gr.Markdown(f"# Detectree2\nModels load from `{WEIGHT_DIR}` (persistent).")
    with gr.Tabs():
        with gr.Tab("Predict"):
            with gr.Row():
                inp_file = gr.File(label="Input (.tif / .tiff)", file_types=[".tif", ".tiff"], type="filepath")
                env_dd = gr.Dropdown(label="Environment", choices=list(WEIGHTS.keys()), value=DEFAULT_ENV)
                conf = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence threshold")
            run_btn = gr.Button("Run prediction", variant="primary")
            out_img = gr.Image(label="Overlay (PNG)", type="filepath")
            out_gpkg = gr.File(label="Crowns (GPKG)")
            out_stats = gr.JSON(label="Statistics")
            out_err = gr.Textbox(label="Logs / Errors", interactive=False)
            # One prediction at a time: the pipeline is memory/CPU heavy.
            run_btn.click(
                infer,
                inputs=[inp_file, env_dd, conf],
                outputs=[out_img, out_gpkg, out_stats, out_err],
                concurrency_limit=1,
            )
        with gr.Tab(f"Manage Weights ({WEIGHT_DIR})"):
            # Explicit filepath type so the handler receives a path string.
            up = gr.File(label="Upload .pth", file_types=[".pth"], type="filepath")
            up_out = gr.Textbox(label="Upload status")
            up.upload(save_to_persistent, up, up_out)
            list_btn = gr.Button("List /data contents")
            list_out = gr.Textbox(label="Persistent files", lines=12)
            list_btn.click(list_persistent, None, list_out)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT and falls back to 7860.
    server_port = int(os.getenv("PORT", "7860"))
    app = demo.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=server_port,
        max_threads=1,
        show_api=False,
    )
|