"""Gradio front end for Detectree2 tree-crown prediction on GeoTIFF imagery.

The "Predict" tab tiles an uploaded GeoTIFF, runs a Detectron2 predictor with
the selected weights, stitches and cleans the predicted crowns, and returns a
PNG overlay, a GeoPackage of crown polygons and a small statistics dict.
The "Manage Weights" tab copies uploaded ``.pth`` files into ``WEIGHT_DIR``
and lists the files under ``/data``.
"""

# environment setup
# NOTE: the variables below are set *before* the third-party imports that
# follow (presumably so those libraries pick them up when they initialise),
# which is why the imports are not all grouped at the very top of the file.
import os

val = os.environ.get("OMP_NUM_THREADS", "")
try:
    int(val)
except ValueError:
    # Unset or non-numeric: fall back to a single OpenMP thread.
    os.environ["OMP_NUM_THREADS"] = "1"
os.environ.setdefault("MPLCONFIGDIR", "/var/tmp/matplotlib")
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TORCH_HOME", "/data/.cache/torch")

import glob
import shutil
import tempfile
import traceback

import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from affine import Affine
import rasterio
from rasterio.plot import show
import torch
import gradio as gr
from detectron2.engine import DefaultPredictor
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectree2.models.predict import predict_on_data
from detectree2.models.train import setup_cfg

# config
WEIGHT_DIR = os.getenv("DTR2_WEIGHT_DIR", "/data/weights")
os.makedirs(WEIGHT_DIR, exist_ok=True)

# Environment name (as shown in the UI dropdown) -> weights file path.
WEIGHTS = {
    "Flexible": os.path.join(WEIGHT_DIR, "250312_flexi.pth"),
    "Forest": os.path.join(WEIGHT_DIR, "250711_tropical_closed_canopy.pth"),
    "Urban": os.path.join(WEIGHT_DIR, "urban_trees_Cambridge_20230630.pth"),
}
DEFAULT_ENV = "Flexible"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keep PyTorch from oversubscribing CPU
try:
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
except Exception:
    # Deliberately best-effort: thread limits are a tuning hint, not a requirement.
    pass

# Tiling defaults
BUFFER, TILE_W, TILE_H = 30, 40, 40


def _find_predictions_dir(candidates):
    """Return the first existing directory among *candidates*, or ``None``."""
    return next((d for d in candidates if os.path.isdir(d)), None)


def _render_overlay(raster_path, crowns, out_path):
    """Draw *crowns* outlines over the raster at *raster_path* and save a PNG.

    The figure is always closed, even if plotting or saving raises, so that
    repeated failing requests do not accumulate open matplotlib figures.
    """
    with rasterio.open(raster_path) as src:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            show(src, ax=ax)
            if not crowns.empty:
                crowns.plot(ax=ax, facecolor="none", edgecolor="cyan", linewidth=1.2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            fig.savefig(out_path, dpi=220, bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)


def _discard(*paths):
    """Best-effort removal of files or directory trees; failures are ignored."""
    for path in paths:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                pass


def run_detectree2(geotiff_path: str, environment: str, conf_threshold: float):
    """Run the full Detectree2 pipeline on one GeoTIFF.

    Args:
        geotiff_path: Path to the input GeoTIFF. It is copied into a fresh
            temporary work directory before processing.
        environment: Key into ``WEIGHTS``; unknown or empty values fall back
            to ``DEFAULT_ENV``.
        conf_threshold: Crowns whose ``Confidence_score`` is not strictly
            greater than this value are dropped from the outputs.

    Returns:
        A 4-tuple ``(overlay_path, gpkg_path, stats, error)``. On success the
        first three are the PNG overlay path, the GeoPackage path and
        ``{"total_trees": <int>}``, and ``error`` is ``""``. On failure the
        first three are ``None`` and ``error`` is a message starting with
        ``"Error: "``. Exceptions raised inside the pipeline are caught and
        reported through ``error`` rather than propagated.
    """
    if not geotiff_path:
        return None, None, None, "Error: no input file provided."

    weights = WEIGHTS.get(environment) or WEIGHTS[DEFAULT_ENV]
    if not os.path.isfile(weights):
        # Fail fast with a clear message instead of a loader error after tiling.
        return None, None, None, (
            f"Error: model weights not found at {weights}. "
            "Upload them in the 'Manage Weights' tab."
        )

    try:
        prev_cwd = os.getcwd()
    except OSError:
        prev_cwd = None  # current directory no longer exists; nothing to restore

    workdir = tempfile.mkdtemp(prefix="dtr2_")
    tiles_dir = os.path.join(workdir, "tilespred")
    outputs_dir = os.path.join(workdir, "train_outputs")
    infile = os.path.join(workdir, os.path.basename(geotiff_path))
    succeeded = False

    try:
        os.makedirs(tiles_dir, exist_ok=True)
        # NOTE(review): the original code runs the pipeline from inside the work
        # directory, presumably because downstream code writes relative paths.
        # The previous working directory is restored in ``finally``.
        os.chdir(workdir)

        shutil.copy2(geotiff_path, infile)

        # Must be georeferenced
        with rasterio.open(infile) as src:
            if src.crs is None or src.transform == Affine.identity():
                return None, None, None, "Error: input GeoTIFF has no CRS/transform."

        tile_data(infile, tiles_dir, BUFFER, TILE_W, TILE_H)

        cfg = setup_cfg(update_model=weights)
        cfg.MODEL.DEVICE = DEVICE
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(os.getenv("DTR2_SCORE_THRESH", "0.25"))
        cfg.OUTPUT_DIR = outputs_dir
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

        predictor = DefaultPredictor(cfg)
        with torch.inference_mode():
            predict_on_data(tiles_dir, predictor=predictor)

        # The predictions folder location is probed in this order.
        preds_dir = _find_predictions_dir(
            (
                os.path.join(tiles_dir, "predictions"),
                os.path.join(cfg.OUTPUT_DIR, "predictions"),
                os.path.join(workdir, "predictions"),
            )
        )
        if preds_dir is None:
            return None, None, None, "Error: no predictions folder found."

        if not glob.glob(os.path.join(preds_dir, "*.json")):
            return None, None, None, (
                f"Error: no prediction JSONs in {preds_dir}. "
                f"Try lowering SCORE_THRESH_TEST (now {cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST})."
            )

        # project to geojson
        preds_geo_dir = os.path.join(tiles_dir, "predictions_geo")
        os.makedirs(preds_geo_dir, exist_ok=True)
        project_to_geojson(tiles_dir, preds_dir, preds_geo_dir)

        geos = glob.glob(os.path.join(preds_geo_dir, "*.geojson"))
        if not geos:
            return None, None, None, f"Error: No geojson files found in {preds_geo_dir}."

        # stitch + clean + filter
        crowns = stitch_crowns(preds_geo_dir, 1)
        cleaned = clean_crowns(crowns, 0.6, confidence=0)
        # Filter first, then simplify only the rows that are kept.
        cleaned = cleaned[cleaned["Confidence_score"] > float(conf_threshold)]
        cleaned = cleaned.set_geometry(cleaned.simplify(0.4))

        # save vector
        gpkg_path = os.path.join(workdir, "crowns_out.gpkg")
        cleaned.to_file(gpkg_path, driver="GPKG")

        # save overlay
        overlay_path = os.path.join(workdir, "crowns_out.png")
        _render_overlay(infile, cleaned, overlay_path)

        stats = {"total_trees": int(len(cleaned))}
        succeeded = True
        return overlay_path, gpkg_path, stats, ""

    except Exception as e:
        # Top-level boundary for the UI: report the failure in the error textbox.
        return None, None, None, f"Error: {e}"

    finally:
        if prev_cwd is not None:
            try:
                os.chdir(prev_cwd)
            except OSError:
                pass
        if succeeded:
            # Keep only the two returned artefacts; drop tiles, predictions
            # and the copied input so they do not pile up in the temp dir.
            _discard(tiles_dir, outputs_dir, infile)
        else:
            _discard(workdir)


def infer(file_path, environment, confidence):
    """Gradio click handler: forward the UI inputs to ``run_detectree2``."""
    return run_detectree2(file_path, environment, confidence)


def save_to_persistent(path):
    """Copy an uploaded file into ``WEIGHT_DIR`` and return a status message.

    Free space is checked on the filesystem that holds ``WEIGHT_DIR`` before
    copying. Any exception is caught and returned as a formatted traceback so
    that it shows up in the "Upload status" textbox.
    """
    try:
        if not path:
            return "Upload failed: no file received."
        os.makedirs(WEIGHT_DIR, exist_ok=True)
        fname = os.path.basename(path)
        dest = os.path.join(WEIGHT_DIR, fname)

        # Check the destination filesystem (WEIGHT_DIR may be overridden via
        # DTR2_WEIGHT_DIR and need not live under /data).
        free = shutil.disk_usage(WEIGHT_DIR).free
        need = os.path.getsize(path)
        if need > free:
            return f"Not enough space: need {need/1024**2:.2f} MB, free {free/1024**2:.2f} MB."

        shutil.copy2(path, dest)
        return f"Saved {fname} to {WEIGHT_DIR}"
    except Exception:
        return "Upload failed:\n" + traceback.format_exc()


def list_persistent():
    """Return a sorted, newline-separated listing of every file under ``/data``.

    Each line is ``"<path> — <size> MB"``; if the size cannot be read, only
    the path is listed. Returns ``"(empty)"`` when no files are found.
    """
    out = []
    for root, _dirs, files in os.walk("/data"):
        for f in files:
            p = os.path.join(root, f)
            try:
                sz = os.path.getsize(p) / (1024**2)
                out.append(f"{p} — {sz:.2f} MB")
            except OSError:
                out.append(p)
    return "\n".join(sorted(out)) or "(empty)"


# gradio ui
# NOTE(review): the widget nesting below was reconstructed from a source whose
# indentation was lost; confirm the Row contents against the intended layout.
with gr.Blocks(title="Detectree2 – Landscape Prediction") as demo:
    gr.Markdown("# Detectree2\nModels load from `/data/weights` (persistent).")

    with gr.Tabs():
        with gr.Tab("Predict"):
            with gr.Row():
                inp_file = gr.File(
                    label="Input (.tif / .tiff)",
                    file_types=[".tif", ".tiff"],
                    type="filepath",
                )
                env_dd = gr.Dropdown(
                    label="Environment",
                    choices=list(WEIGHTS.keys()),
                    value=DEFAULT_ENV,
                )
                conf = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence threshold")
            run_btn = gr.Button("Run prediction", variant="primary")
            out_img = gr.Image(label="Overlay (PNG)", type="filepath")
            out_gpkg = gr.File(label="Crowns (GPKG)")
            out_stats = gr.JSON(label="Statistics")
            out_err = gr.Textbox(label="Logs / Errors", interactive=False)
            run_btn.click(
                infer,
                inputs=[inp_file, env_dd, conf],
                outputs=[out_img, out_gpkg, out_stats, out_err],
                concurrency_limit=1,
            )

        with gr.Tab("Manage Weights (/data/weights)"):
            up = gr.File(label="Upload .pth", file_types=[".pth"])
            up_out = gr.Textbox(label="Upload status")
            up.upload(save_to_persistent, up, up_out)
            list_btn = gr.Button("List /data contents")
            list_out = gr.Textbox(label="Persistent files", lines=12)
            list_btn.click(list_persistent, None, list_out)


if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.queue().launch(server_name="0.0.0.0", server_port=port, max_threads=1, show_api=False)