# Detectree2 / app.py
# TeddyUW's picture
# revert
# 4298555
# Raw
# History Blame Contribute Delete
# 7.51 kB
# environment setup
import os
# These variables are set before the heavy third-party imports that follow in
# this file (matplotlib, torch, ...) — presumably so those libraries pick them
# up when they initialise.

# Fall back to a single OpenMP thread when OMP_NUM_THREADS is unset, empty, or
# not parseable as an integer. `val` is always a str here, so ValueError is the
# only failure int() can produce.
val = os.environ.get("OMP_NUM_THREADS", "")
try:
    int(val)
except ValueError:
    os.environ["OMP_NUM_THREADS"] = "1"

# Cache/config locations; setdefault keeps any value the host already exported.
os.environ.setdefault("MPLCONFIGDIR", "/var/tmp/matplotlib")
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("TORCH_HOME", "/data/.cache/torch")
import tempfile, shutil, glob
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from affine import Affine
import rasterio
from rasterio.plot import show
import torch
import gradio as gr
from detectron2.engine import DefaultPredictor
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectree2.models.predict import predict_on_data
from detectree2.models.train import setup_cfg
# config
# Directory holding the model checkpoints; override with the DTR2_WEIGHT_DIR
# environment variable. Created eagerly so uploads and lookups have a target.
WEIGHT_DIR = os.getenv("DTR2_WEIGHT_DIR", "/data/weights")
os.makedirs(WEIGHT_DIR, exist_ok=True)

# Checkpoint path for each selectable environment. The keys double as the
# choices of the "Environment" dropdown in the UI, so their order matters.
WEIGHTS = {
    "Flexible": os.path.join(WEIGHT_DIR, "250312_flexi.pth"),
    "Forest": os.path.join(WEIGHT_DIR, "250711_tropical_closed_canopy.pth"),
    "Urban": os.path.join(WEIGHT_DIR, "urban_trees_Cambridge_20230630.pth"),
}
# Environment used when the requested one is unknown or empty.
DEFAULT_ENV = "Flexible"
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keep PyTorch from oversubscribing CPU
try:
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
except RuntimeError:
    # Best-effort: PyTorch refuses to change the inter-op thread count once
    # parallel work has started or the value was already set. Keep going with
    # whatever is configured rather than failing at import time.
    pass

# Tiling defaults passed to tile_data as (buffer, tile width, tile height).
# NOTE(review): presumably expressed in the raster's CRS units — confirm
# against the detectree2 tiling documentation.
BUFFER, TILE_W, TILE_H = 30, 40, 40
def _find_predictions_dir(candidates):
    """Return the first path in *candidates* that is an existing directory, else None."""
    return next((d for d in candidates if os.path.isdir(d)), None)


def _save_overlay(raster_path, crowns, overlay_path):
    """Draw *crowns* as cyan outlines over the raster and write a PNG to *overlay_path*."""
    with rasterio.open(raster_path) as src:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            show(src, ax=ax)
            if not crowns.empty:
                crowns.plot(ax=ax, facecolor="none", edgecolor="cyan", linewidth=1.2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            fig.savefig(overlay_path, dpi=220, bbox_inches="tight", pad_inches=0)
        finally:
            # Always release the figure, even when rendering fails, so repeated
            # requests do not accumulate open matplotlib figures.
            plt.close(fig)


def _predict_crowns(geotiff_path, workdir, environment, conf_threshold):
    """Tile, predict, stitch and export crowns for one GeoTIFF inside *workdir*.

    Returns the same 4-tuple as ``run_detectree2``. Exceptions are not handled
    here; the caller converts them into an error message.
    """
    tiles_dir = os.path.join(workdir, "tilespred")
    os.makedirs(tiles_dir, exist_ok=True)

    # Work on a private copy so the uploaded file is never touched.
    infile = os.path.join(workdir, os.path.basename(geotiff_path))
    shutil.copy2(geotiff_path, infile)

    # Must be georeferenced
    with rasterio.open(infile) as src:
        if src.crs is None or src.transform == Affine.identity():
            return None, None, None, "Error: input GeoTIFF has no CRS/transform."

    tile_data(infile, tiles_dir, BUFFER, TILE_W, TILE_H)

    # Unknown or empty environment names fall back to the default checkpoint.
    weights = WEIGHTS.get(environment) or WEIGHTS[DEFAULT_ENV]
    cfg = setup_cfg(update_model=weights)
    cfg.MODEL.DEVICE = DEVICE
    # Detection-time score cut-off; independent of the UI confidence slider,
    # which is applied to the cleaned crowns further down.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(os.getenv("DTR2_SCORE_THRESH", "0.25"))
    cfg.OUTPUT_DIR = os.path.join(workdir, "train_outputs")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    predictor = DefaultPredictor(cfg)
    with torch.inference_mode():
        predict_on_data(tiles_dir, predictor=predictor)

    # The predictions folder is looked up in several candidate locations, in
    # order of preference.
    preds_dir = _find_predictions_dir((
        os.path.join(tiles_dir, "predictions"),
        os.path.join(cfg.OUTPUT_DIR, "predictions"),
        os.path.join(workdir, "predictions"),
    ))
    if preds_dir is None:
        return None, None, None, "Error: no predictions folder found."
    if not glob.glob(os.path.join(preds_dir, "*.json")):
        return None, None, None, (
            f"Error: no prediction JSONs in {preds_dir}. "
            f"Try lowering SCORE_THRESH_TEST (now {cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST})."
        )

    # project to geojson
    preds_geo_dir = os.path.join(tiles_dir, "predictions_geo")
    os.makedirs(preds_geo_dir, exist_ok=True)
    project_to_geojson(tiles_dir, preds_dir, preds_geo_dir)
    if not glob.glob(os.path.join(preds_geo_dir, "*.geojson")):
        return None, None, None, f"Error: No geojson files found in {preds_geo_dir}."

    # stitch + clean + filter
    # NOTE(review): 1 and 0.6 are presumably the stitch shift and the overlap
    # threshold used by detectree2 — confirm against its API docs.
    crowns = stitch_crowns(preds_geo_dir, 1)
    cleaned = clean_crowns(crowns, 0.6, confidence=0)
    # Filter first, then simplify only the surviving rows: this avoids
    # simplifying geometries that are about to be dropped and does not rely on
    # index alignment between the filtered frame and an unfiltered series.
    cleaned = cleaned[cleaned["Confidence_score"] > float(conf_threshold)]
    cleaned = cleaned.set_geometry(cleaned.simplify(0.4))

    # save vector
    gpkg_path = os.path.join(workdir, "crowns_out.gpkg")
    cleaned.to_file(gpkg_path, driver="GPKG")

    # save overlay
    overlay_path = os.path.join(workdir, "crowns_out.png")
    _save_overlay(infile, cleaned, overlay_path)

    stats = {"total_trees": int(len(cleaned))}
    return overlay_path, gpkg_path, stats, ""


def run_detectree2(geotiff_path: str, environment: str, conf_threshold: float):
    """Detect tree crowns in a georeferenced GeoTIFF.

    Args:
        geotiff_path: Path of the input GeoTIFF; it is copied into a fresh
            temporary working directory before processing.
        environment: Key into ``WEIGHTS`` selecting the checkpoint; unknown or
            empty values fall back to ``DEFAULT_ENV``.
        conf_threshold: Crowns whose ``Confidence_score`` is not strictly
            greater than this value are discarded after cleaning.

    Returns:
        ``(overlay_png_path, gpkg_path, stats_dict, "")`` on success, or
        ``(None, None, None, "Error: ...")`` on failure. This function does not
        raise for pipeline errors; they are reported through the message.

    The working directory becomes the process cwd only for the duration of the
    run and is restored afterwards. On failure the temporary directory is
    removed; on success it is kept because the returned paths live inside it.
    """
    if not geotiff_path:
        # Nothing uploaded: fail before creating any temporary state.
        return None, None, None, "Error: no input GeoTIFF provided."

    try:
        prev_cwd = os.getcwd()
    except OSError:
        # The current directory may no longer exist; fall back to a safe one.
        prev_cwd = tempfile.gettempdir()

    workdir = tempfile.mkdtemp(prefix="dtr2_")
    keep_workdir = False
    try:
        os.chdir(workdir)
        result = _predict_crowns(geotiff_path, workdir, environment, conf_threshold)
        keep_workdir = result[0] is not None
        return result
    except Exception as e:
        # Top-level boundary for the UI: surface any failure as a message.
        return None, None, None, f"Error: {e}"
    finally:
        # Restore the process-wide cwd so later requests are not left running
        # inside (possibly deleted) temporary directories.
        try:
            os.chdir(prev_cwd)
        except OSError:
            os.chdir(tempfile.gettempdir())
        if not keep_workdir:
            shutil.rmtree(workdir, ignore_errors=True)
def infer(file_path, environment, confidence):
    """Click handler for the "Run prediction" button.

    Forwards the uploaded file path, the selected environment and the
    confidence slider value to ``run_detectree2`` and returns its 4-tuple
    unchanged.
    """
    outcome = run_detectree2(
        geotiff_path=file_path,
        environment=environment,
        conf_threshold=confidence,
    )
    return outcome
def save_to_persistent(path):
    """Copy an uploaded weights file into ``WEIGHT_DIR``.

    Args:
        path: Filesystem path of the uploaded file.

    Returns:
        A human-readable status string: a success message, a "not enough
        space" message, or the traceback text when the copy fails. Never
        raises; every failure is reported through the returned string.
    """
    import traceback

    if not path:
        # No upload value (e.g. the component was cleared): nothing to copy.
        return "No file uploaded."
    try:
        os.makedirs(WEIGHT_DIR, exist_ok=True)
        fname = os.path.basename(path)
        dest = os.path.join(WEIGHT_DIR, fname)
        # Measure free space on the filesystem that actually receives the
        # file; WEIGHT_DIR is configurable and need not live under /data.
        free = shutil.disk_usage(WEIGHT_DIR).free
        need = os.path.getsize(path)
        if need > free:
            return f"Not enough space: need {need/1024**2:.2f} MB, free {free/1024**2:.2f} MB."
        shutil.copy2(path, dest)
        return f"Saved {fname} to {WEIGHT_DIR}"
    except Exception:
        # UI boundary: show the full traceback in the status box instead of
        # letting the event handler fail.
        return "Upload failed:\n" + traceback.format_exc()
def list_persistent(root="/data"):
    """List every file below *root* with its size in MB.

    Args:
        root: Directory to walk recursively. Defaults to ``/data``.

    Returns:
        One line per file (``"<path>  <size> MB"``), sorted, joined with
        newlines; ``"(empty)"`` when no files are found or *root* does not
        exist. Files whose size cannot be read are listed by path only.
    """
    out = []
    for folder, _subdirs, files in os.walk(root):
        for fname in files:
            p = os.path.join(folder, fname)
            try:
                sz = os.path.getsize(p) / (1024**2)
            except OSError:
                # e.g. broken symlink or a file removed while walking
                out.append(p)
            else:
                # Separate path and size; the two were previously fused
                # together with no delimiter.
                out.append(f"{p}  {sz:.2f} MB")
    return "\n".join(sorted(out)) or "(empty)"
# gradio ui
# Builds the two-tab interface and binds it to `demo`, which the entry point
# at the bottom of the file launches.
# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting of the layout containers below (in particular which components
# sit inside the Row) is a best-effort reconstruction — confirm against the
# deployed layout.
with gr.Blocks(title="Detectree2 – Landscape Prediction") as demo:
    gr.Markdown("# Detectree2\nModels load from `/data/weights` (persistent).")
    with gr.Tabs():
        # Tab 1: upload a GeoTIFF, choose a model and threshold, run inference.
        with gr.Tab("Predict"):
            with gr.Row():
                inp_file = gr.File(label="Input (.tif / .tiff)", file_types=[".tif", ".tiff"], type="filepath")
                # Choices are the keys of WEIGHTS, i.e. one entry per checkpoint.
                env_dd = gr.Dropdown(label="Environment", choices=list(WEIGHTS.keys()), value=DEFAULT_ENV)
                conf = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence threshold")
            run_btn = gr.Button("Run prediction", variant="primary")
            # Outputs, in the same order as the 4-tuple returned by infer().
            out_img = gr.Image(label="Overlay (PNG)", type="filepath")
            out_gpkg = gr.File(label="Crowns (GPKG)")
            out_stats = gr.JSON(label="Statistics")
            out_err = gr.Textbox(label="Logs / Errors", interactive=False)
            # concurrency_limit=1 asks Gradio to run at most one prediction
            # at a time for this event.
            run_btn.click(
                infer,
                inputs=[inp_file, env_dd, conf],
                outputs=[out_img, out_gpkg, out_stats, out_err],
                concurrency_limit=1,
            )
        # Tab 2: upload checkpoints into persistent storage and inspect it.
        with gr.Tab("Manage Weights (/data/weights)"):
            up = gr.File(label="Upload .pth", file_types=[".pth"])
            up_out = gr.Textbox(label="Upload status")
            # Each completed upload is copied into WEIGHT_DIR by save_to_persistent.
            up.upload(save_to_persistent, up, up_out)
            list_btn = gr.Button("List /data contents")
            list_out = gr.Textbox(label="Persistent files", lines=12)
            list_btn.click(lambda: list_persistent(), None, list_out)
if __name__ == "__main__":
    # Script entry point: serve the UI on all interfaces. The port comes from
    # the PORT environment variable and defaults to 7860.
    port = int(os.environ.get("PORT", "7860"))
    queued_demo = demo.queue()
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        max_threads=1,
        show_api=False,
    )