Spaces:
Runtime error
Runtime error
# environment setup
# These variables are set here, ahead of the torch / matplotlib / Hugging Face
# imports further down the file, so those libraries see them when they load.
import os


def _is_positive_int(value):
    """Return True if *value* parses as an integer >= 1, else False."""
    try:
        return int(value) >= 1
    except (TypeError, ValueError):
        return False


# OpenMP runtimes (e.g. libgomp) reject values they cannot parse as a thread
# count, such as an empty string or a container CPU quota like "3500m"; zero or
# negative counts are not usable either. Fall back to one thread in all cases.
if not _is_positive_int(os.environ.get("OMP_NUM_THREADS", "")):
    os.environ["OMP_NUM_THREADS"] = "1"
os.environ.setdefault("MPLCONFIGDIR", "/var/tmp/matplotlib")
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TORCH_HOME", "/data/.cache/torch")
| import tempfile, shutil, glob | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from affine import Affine | |
| import rasterio | |
| from rasterio.plot import show | |
| import torch | |
| import gradio as gr | |
| from detectron2.engine import DefaultPredictor | |
| from detectree2.preprocessing.tiling import tile_data | |
| from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns | |
| from detectree2.models.predict import predict_on_data | |
| from detectree2.models.train import setup_cfg | |
# config
import warnings

# Directory holding the model checkpoints; defaults to the persistent /data volume.
WEIGHT_DIR = os.getenv("DTR2_WEIGHT_DIR", "/data/weights")
try:
    os.makedirs(WEIGHT_DIR, exist_ok=True)
except OSError as exc:
    # Without persistent storage attached, /data may be missing or read-only.
    # Raising here would take the whole app down at import time; warn instead
    # and let prediction / weight upload report the problem to the user.
    warnings.warn(f"Cannot create weight directory {WEIGHT_DIR!r}: {exc}", RuntimeWarning)

# UI environment name -> checkpoint file inside WEIGHT_DIR.
WEIGHTS = {
    "Flexible": os.path.join(WEIGHT_DIR, "250312_flexi.pth"),
    "Forest": os.path.join(WEIGHT_DIR, "250711_tropical_closed_canopy.pth"),
    "Urban": os.path.join(WEIGHT_DIR, "urban_trees_Cambridge_20230630.pth"),
}
DEFAULT_ENV = "Flexible"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keep PyTorch from oversubscribing CPU. torch.set_num_interop_threads raises
# RuntimeError once inter-op parallel work has started (or if already set), so
# each call is attempted independently and a failure is ignored (best effort).
try:
    torch.set_num_threads(1)
except RuntimeError:
    pass
try:
    torch.set_num_interop_threads(1)
except RuntimeError:
    pass

# Tiling defaults (passed positionally to tile_data after the output directory)
BUFFER, TILE_W, TILE_H = 30, 40, 40
def run_detectree2(geotiff_path: str, environment: str, conf_threshold: float):
    """Run detectree2 crown delineation on a GeoTIFF and package the results.

    Args:
        geotiff_path: Path to the input GeoTIFF (falsy when nothing was uploaded).
        environment: Key into WEIGHTS selecting the model checkpoint; unknown
            keys fall back to DEFAULT_ENV.
        conf_threshold: Crowns whose ``Confidence_score`` is not strictly
            greater than this value are dropped.

    Returns:
        A 4-tuple ``(overlay_png_path, gpkg_path, stats, message)``. On success
        ``stats`` is ``{"total_trees": int}`` and ``message`` is ``""``; the two
        output files live in a fresh temporary work directory. On failure the
        first three items are None, ``message`` starts with ``"Error:"`` and the
        temporary work directory is removed. Exceptions are not propagated.
    """
    if not geotiff_path:
        return None, None, None, "Error: please upload a GeoTIFF (.tif/.tiff) file."

    workdir = tempfile.mkdtemp(prefix="dtr2_")
    prev_cwd = os.getcwd()
    result = (None, None, None, "Error: prediction did not complete.")
    try:
        tiles_dir = os.path.join(workdir, "tilespred")
        os.makedirs(tiles_dir, exist_ok=True)
        # NOTE(review): running with workdir as CWD presumably keeps any
        # relative-path output of the detectree2 helpers inside workdir.
        # The process-wide CWD is restored in `finally`.
        os.chdir(workdir)
        result = _run_pipeline(geotiff_path, environment, conf_threshold, workdir, tiles_dir)
    except Exception as e:
        import traceback

        # Keep the full traceback in the server logs; the UI only gets the message.
        traceback.print_exc()
        result = (None, None, None, f"Error: {e}")
    finally:
        os.chdir(prev_cwd)
        if result[3]:
            # Failed run: nothing in workdir is handed back, so don't leak it.
            shutil.rmtree(workdir, ignore_errors=True)
    return result


def _run_pipeline(geotiff_path, environment, conf_threshold, workdir, tiles_dir):
    """Tile, predict, stitch, filter and export crowns; return run_detectree2's 4-tuple."""
    infile = os.path.join(workdir, os.path.basename(geotiff_path))
    shutil.copy2(geotiff_path, infile)

    # Must be georeferenced (an identity transform means no geotransform was set).
    with rasterio.open(infile) as src:
        if src.crs is None or src.transform == Affine.identity():
            return None, None, None, "Error: input GeoTIFF has no CRS/transform."

    tile_data(infile, tiles_dir, BUFFER, TILE_W, TILE_H)

    weights = WEIGHTS.get(environment) or WEIGHTS[DEFAULT_ENV]
    if not os.path.isfile(weights):
        return None, None, None, (
            f"Error: model weights not found at {weights}. "
            "Upload them in the 'Manage Weights' tab."
        )

    cfg = setup_cfg(update_model=weights)
    cfg.MODEL.DEVICE = DEVICE
    # Detector-side score cutoff. conf_threshold filters again after cleaning;
    # NOTE(review): a slider value below this cutoff presumably cannot recover
    # lower-scoring crowns.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(os.getenv("DTR2_SCORE_THRESH", "0.25"))
    cfg.OUTPUT_DIR = os.path.join(workdir, "train_outputs")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    predictor = DefaultPredictor(cfg)
    with torch.inference_mode():
        predict_on_data(tiles_dir, predictor=predictor)

    # The prediction output location is not assumed; check candidates in order.
    preds_dir = _find_predictions_dir((
        os.path.join(tiles_dir, "predictions"),
        os.path.join(cfg.OUTPUT_DIR, "predictions"),
        os.path.join(workdir, "predictions"),
    ))
    if preds_dir is None:
        return None, None, None, "Error: no predictions folder found."
    if not glob.glob(os.path.join(preds_dir, "*.json")):
        return None, None, None, (
            f"Error: no prediction JSONs in {preds_dir}. "
            f"Try lowering SCORE_THRESH_TEST (now {cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST})."
        )

    # project to geojson
    preds_geo_dir = os.path.join(tiles_dir, "predictions_geo")
    os.makedirs(preds_geo_dir, exist_ok=True)
    project_to_geojson(tiles_dir, preds_dir, preds_geo_dir)
    if not glob.glob(os.path.join(preds_geo_dir, "*.geojson")):
        return None, None, None, f"Error: No geojson files found in {preds_geo_dir}."

    # stitch + clean + filter
    crowns = stitch_crowns(preds_geo_dir, 1)
    cleaned = clean_crowns(crowns, 0.6, confidence=0)
    # Filter by the UI threshold first, then simplify only the rows that were kept.
    kept = cleaned[cleaned["Confidence_score"] > float(conf_threshold)]
    kept = kept.set_geometry(kept.simplify(0.4))

    # save vector
    gpkg_path = os.path.join(workdir, "crowns_out.gpkg")
    kept.to_file(gpkg_path, driver="GPKG")

    # save overlay
    overlay_path = os.path.join(workdir, "crowns_out.png")
    _save_overlay(infile, kept, overlay_path)

    return overlay_path, gpkg_path, {"total_trees": int(len(kept))}, ""


def _find_predictions_dir(candidates):
    """Return the first path in *candidates* that is an existing directory, else None."""
    return next((d for d in candidates if os.path.isdir(d)), None)


def _save_overlay(raster_path, crowns, out_path):
    """Draw *crowns* outlines over the raster at *raster_path* and save a PNG to *out_path*.

    The figure is always closed, even if plotting fails, so repeated runs do
    not accumulate open matplotlib figures.
    """
    with rasterio.open(raster_path) as src:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            show(src, ax=ax)
            if not crowns.empty:
                crowns.plot(ax=ax, facecolor="none", edgecolor="cyan", linewidth=1.2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            fig.savefig(out_path, dpi=220, bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)
def infer(file_path, environment, confidence):
    """Gradio click handler: forward the three widget values to run_detectree2.

    Returns whatever run_detectree2 returns (overlay path, GPKG path, stats,
    message), which maps one-to-one onto the four output components.
    """
    outputs = run_detectree2(
        geotiff_path=file_path,
        environment=environment,
        conf_threshold=confidence,
    )
    return outputs
def save_to_persistent(path):
    """Copy an uploaded ``.pth`` checkpoint into WEIGHT_DIR.

    Args:
        path: Local filesystem path of the uploaded file, as passed by the
            Gradio ``File`` upload event; falsy when nothing was uploaded.

    Returns:
        A status message for the "Upload status" textbox. Failures are
        reported in the message (including a traceback) rather than raised.
    """
    import traceback

    if not path:
        return "No file uploaded."
    try:
        fname = os.path.basename(path)
        # The UI's file_types=[".pth"] may only be enforced client-side;
        # check the extension here as well.
        if not fname.lower().endswith(".pth"):
            return f"Rejected {fname}: only .pth weight files are accepted."
        os.makedirs(WEIGHT_DIR, exist_ok=True)
        dest = os.path.join(WEIGHT_DIR, fname)
        # Measure free space on the filesystem that actually holds WEIGHT_DIR
        # (it is configurable and need not live on /data).
        free = shutil.disk_usage(WEIGHT_DIR).free
        need = os.path.getsize(path)
        if need > free:
            return f"Not enough space: need {need/1024**2:.2f} MB, free {free/1024**2:.2f} MB."
        shutil.copy2(path, dest)
        return f"Saved {fname} to {WEIGHT_DIR}"
    except Exception:
        return "Upload failed:\n" + traceback.format_exc()
def list_persistent(root_dir="/data"):
    """List every file under *root_dir* with its size, one entry per line.

    Args:
        root_dir: Directory to walk recursively; defaults to the persistent
            ``/data`` volume.

    Returns:
        Newline-joined, sorted ``"<path> — <size> MB"`` entries (just the path
        when its size cannot be read), or ``"(empty)"`` when no files are found,
        which includes a *root_dir* that does not exist (``os.walk`` then
        yields nothing).
    """
    entries = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            try:
                size_mb = os.path.getsize(full_path) / (1024**2)
            except OSError:
                # e.g. a broken symlink or a file removed during the walk
                entries.append(full_path)
            else:
                entries.append(f"{full_path} — {size_mb:.2f} MB")
    return "\n".join(sorted(entries)) or "(empty)"
# gradio ui
with gr.Blocks(title="Detectree2 – Landscape Prediction") as demo:
    # Show the weight directory actually in use (configurable via DTR2_WEIGHT_DIR).
    gr.Markdown(f"# Detectree2\nModels load from `{WEIGHT_DIR}` (persistent).")
    with gr.Tabs():
        with gr.Tab("Predict"):
            with gr.Row():
                inp_file = gr.File(label="Input (.tif / .tiff)", file_types=[".tif", ".tiff"], type="filepath")
                env_dd = gr.Dropdown(label="Environment", choices=list(WEIGHTS.keys()), value=DEFAULT_ENV)
                conf = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence threshold")
            run_btn = gr.Button("Run prediction", variant="primary")
            out_img = gr.Image(label="Overlay (PNG)", type="filepath")
            out_gpkg = gr.File(label="Crowns (GPKG)")
            out_stats = gr.JSON(label="Statistics")
            out_err = gr.Textbox(label="Logs / Errors", interactive=False)
            # At most one prediction runs at a time (the handler changes the
            # process-wide working directory while it runs).
            run_btn.click(
                infer,
                inputs=[inp_file, env_dd, conf],
                outputs=[out_img, out_gpkg, out_stats, out_err],
                concurrency_limit=1,
            )
        with gr.Tab(f"Manage Weights ({WEIGHT_DIR})"):
            up = gr.File(label="Upload .pth", file_types=[".pth"])
            up_out = gr.Textbox(label="Upload status")
            up.upload(save_to_persistent, up, up_out)
            list_btn = gr.Button("List /data contents")
            list_out = gr.Textbox(label="Persistent files", lines=12)
            list_btn.click(list_persistent, None, list_out)
if __name__ == "__main__":
    # Listening port can be overridden through the PORT environment variable.
    server_port = int(os.getenv("PORT", "7860"))
    app = demo.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=server_port,
        max_threads=1,
        show_api=False,
    )