| |
|
| | |
| | """ |
| | !python -m pip -q install torchvision torch |
| | !python -m pip -q install rasterio |
| | !python -m pip -q install git+https://github.com/PatBall1/detectree2.git # in order for this to work, you must have installed gdal |
| | !python -m pip install opencv-python |
| | !python -m pip install requests |
| | """ |
| | from detectree2.preprocessing.tiling import tile_data |
| | from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns |
| | from detectree2.models.predict import predict_on_data |
| | from detectree2.models.train import setup_cfg |
| | from detectron2.engine import DefaultPredictor |
| | import rasterio |
| | import os |
| | import requests |
| |
|
| | |
| | |
| |
|
| | |
| |
|
def create_tiles(input_path, tile_width, tile_height, tile_buffer):
    """Cut the raster at *input_path* into tiles under ``./tiles/``.

    The tiles directory is created in the current working directory if it
    does not exist yet. The actual tiling is delegated to
    ``detectree2.preprocessing.tiling.tile_data``.

    Parameters
    ----------
    input_path : str
        Path of the input raster (e.g. a GeoTIFF) that ``rasterio`` can open.
    tile_width, tile_height
        Tile size, forwarded unchanged to ``tile_data`` (units are whatever
        ``tile_data`` expects — NOTE(review): presumably map units; confirm).
    tile_buffer
        Buffer around each tile, forwarded unchanged to ``tile_data``.

    Returns
    -------
    str
        Absolute path of the tiles directory. It always ends with ``/``,
        because callers build sub-folder paths by appending to it.
    """
    tiles_directory = os.path.join(os.getcwd(), "tiles/")
    # exist_ok avoids the exists()/makedirs() check-then-create race.
    os.makedirs(tiles_directory, exist_ok=True)

    # Use the dataset as a context manager so the file handle is released
    # once tiling has finished (the original never closed it).
    with rasterio.open(input_path) as data:
        tile_data(data, tiles_directory, tile_buffer, tile_width, tile_height, dtype_bool=True)

    return tiles_directory
| |
|
def download_file(url, local_filename, timeout=60):
    """Stream the resource at *url* into *local_filename*.

    The body is first written to ``<local_filename>.part``. It is renamed into
    place only after the whole transfer succeeded, so an interrupted download
    never leaves a truncated file at *local_filename* that looks complete.

    Parameters
    ----------
    url : str
        HTTP(S) URL to fetch.
    local_filename : str
        Destination path; overwritten if it already exists.
    timeout : float, optional
        Seconds ``requests`` waits for the connection / for the next chunk of
        data. Without a timeout a stalled server would block forever.

    Returns
    -------
    str
        *local_filename*, for call chaining.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status code.
    """
    partial_filename = local_filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either the old file or
        # the complete new one, never a half-written one.
        os.replace(partial_filename, local_filename)
    finally:
        # Only still present if something above failed; don't leave litter.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
    return local_filename
| |
|
def predict(
    tile_path,
    overlap_threshold,
    confidence_threshold,
    simplify_value,
    store_path,
    *,
    model_url="https://zenodo.org/records/10522461/files/230103_randresize_full.pth",
    model_path="./230103_randresize_full.pth",
):
    """Run detectree2 on a folder of tiles and save the cleaned crowns.

    Steps:

    1. Fetch the pretrained weights (only if *model_path* does not exist yet).
    2. Predict on every tile.
    3. Project the predictions to GeoJSON.
    4. Stitch them, clean overlapping crowns and simplify the geometries.

    The result is written to ``<store_path>/detectree2_delin.geojson``.

    Parameters
    ----------
    tile_path : str
        Directory containing the tiles (as returned by ``create_tiles``).
        Predictions are written to ``<tile_path>predictions/`` and
        ``<tile_path>predictions_geo/``.
    overlap_threshold : float
        Passed as the overlap threshold to ``clean_crowns``.
    confidence_threshold : float
        Minimum confidence passed to ``clean_crowns``.
    simplify_value : float
        Tolerance for ``GeoSeries.simplify`` on the cleaned crowns.
    store_path : str
        Output directory (also used as the detectron2 ``out_dir``). It is
        created if missing.
    model_url : str, optional
        Where to download the weights from when *model_path* is absent.
    model_path : str, optional
        Local path of the weights file. Delete it to force a fresh download.
    """
    # The sub-folder paths below are built by string concatenation (the
    # detectree2 convention), so make sure the tiles directory ends with a
    # separator. This is a no-op for the path create_tiles returns.
    if not tile_path.endswith(("/", os.sep)):
        tile_path += "/"

    # The weights are large; reuse a previous download instead of fetching
    # them again on every call.
    if not os.path.isfile(model_path):
        download_file(url=model_url, local_filename=model_path)

    # to_file() below fails if the output directory does not exist.
    os.makedirs(store_path, exist_ok=True)
    cfg = setup_cfg(update_model=model_path, out_dir=store_path)

    predict_on_data(tile_path, predictor=DefaultPredictor(cfg))

    predictions_dir = tile_path + "predictions/"
    predictions_geo_dir = tile_path + "predictions_geo/"
    project_to_geojson(tile_path, predictions_dir, predictions_geo_dir)
    crowns = stitch_crowns(predictions_geo_dir, 1)
    clean = clean_crowns(crowns, overlap_threshold, confidence=confidence_threshold)
    clean = clean.set_geometry(clean.simplify(simplify_value))
    clean.to_file(os.path.join(store_path, "detectree2_delin.geojson"))
| |
|
def run_detectree2(tif_input_path, store_path, tile_width=20, tile_height=20, tile_buffer=20, overlap_threshold=0.35, confidence_threshold=0.2, simplify_value=0.2):
    """Delineate tree crowns in *tif_input_path* end to end.

    Steps:

    1. Tile the raster into ``./tiles/`` (under the current working
       directory) and print that directory.
    2. Run prediction on the tiles.

    The cleaned crowns end up in ``<store_path>/detectree2_delin.geojson``.

    The tiling arguments are forwarded to ``create_tiles``. The overlap,
    confidence and simplification arguments are forwarded to ``predict``.
    """
    tiles_dir = create_tiles(
        input_path=tif_input_path,
        tile_width=tile_width,
        tile_height=tile_height,
        tile_buffer=tile_buffer,
    )
    print(tiles_dir)

    predict(
        tile_path=tiles_dir,
        store_path=store_path,
        overlap_threshold=overlap_threshold,
        confidence_threshold=confidence_threshold,
        simplify_value=simplify_value,
    )
| |
|
| |
|