Spaces:
Sleeping
Sleeping
| # src/infer.py | |
| # SAM2 geoglyph inference pipeline using a georeferenced crop GeoTIFF. | |
| # | |
| # Important: | |
| # This file does NOT receive or open the full orthomosaic. | |
| # It only receives a small crop GeoTIFF containing: | |
| # - RGB pixels | |
| # - CRS | |
| # - affine transform | |
| from pathlib import Path | |
| import logging | |
| import numpy as np | |
| import torch | |
| import rasterio | |
| from rasterio.features import shapes | |
| import geopandas as gpd | |
| from shapely.geometry import shape | |
| from src.preprocess import preprocess | |
| logger = logging.getLogger("pipeline") | |
| # --------------------------------------------------------------------------- | |
| # Model cache | |
| # --------------------------------------------------------------------------- | |
| _MODEL_CACHE = {} | |
| # --------------------------------------------------------------------------- | |
| # Device handling | |
| # --------------------------------------------------------------------------- | |
| def resolve_device(device: str | None) -> str: | |
| if device is None: | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
| device = device.lower().strip() | |
| if device not in {"cuda", "cpu"}: | |
| raise ValueError(f"Invalid device: {device}. Expected 'cuda' or 'cpu'.") | |
| if device == "cuda" and not torch.cuda.is_available(): | |
| raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.") | |
| return device | |
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
def load_sam2_model(
    device: str = "cuda",
    points_per_side: int = 32,
    points_per_batch: int = 32,
    pred_iou_thresh: float = 0.35,
    stability_score_thresh: float = 0.65,
):
    """
    Return a SAM2 automatic mask generator, building it on first use.

    Generators are memoized in the module-level ``_MODEL_CACHE`` under the
    tuple of all five arguments, so repeated calls with the same settings
    reuse the already-loaded model.

    Args:
        device: device string forwarded to ``from_pretrained``.
        points_per_side: forwarded to the mask generator.
        points_per_batch: forwarded to the mask generator.
        pred_iou_thresh: forwarded to the mask generator.
        stability_score_thresh: forwarded to the mask generator.

    Returns:
        The (possibly cached) ``SAM2AutomaticMaskGenerator`` instance.

    Raises:
        RuntimeError: the loaded model's parameters live on a device type
            different from the requested "cpu" / "cuda".
    """
    # Imported lazily so the module can be imported without sam2 installed.
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    cache_key = (
        device,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
    )

    try:
        cached_generator = _MODEL_CACHE[cache_key]
    except KeyError:
        pass
    else:
        logger.info("Reusing cached SAM2 model | device=%s", device)
        return cached_generator

    logger.info(
        "Loading SAM2 model | device=%s PPS=%d PPB=%d IOU_thresh=%.2f Stability_thresh=%.2f",
        device,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
    )

    generator_options = {
        "device": device,
        "points_per_side": points_per_side,
        "points_per_batch": points_per_batch,
        "crop_n_layers": 0,
        "multimask_output": False,
        "use_m2m": False,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
    }
    generator = SAM2AutomaticMaskGenerator.from_pretrained(
        "facebook/sam2.1-hiera-large",
        **generator_options,
    )

    # Verify where the weights actually ended up; the attribute chain is
    # probed defensively because it may be absent.
    sam_model = getattr(getattr(generator, "predictor", None), "model", None)

    if sam_model is None:
        logger.warning("Could not inspect actual SAM2 model device.")
    else:
        placed_on = next(sam_model.parameters()).device
        logger.info("Actual SAM2 model device: %s", placed_on)

        if placed_on.type != device:
            if device == "cpu":
                raise RuntimeError(
                    f"Expected SAM2 to run on CPU, but model is on {placed_on}."
                )
            if device == "cuda":
                raise RuntimeError(
                    f"Expected SAM2 to run on CUDA, but model is on {placed_on}."
                )

    # Only a generator that passed the device check is cached.
    _MODEL_CACHE[cache_key] = generator
    return generator
| # --------------------------------------------------------------------------- | |
| # Crop GeoTIFF I/O | |
| # --------------------------------------------------------------------------- | |
def read_georeferenced_crop(
    crop_tif_path: str,
    rgb_bands: tuple = (1, 2, 3),
    max_crop_side: int = 4096,
):
    """
    Read a small georeferenced GeoTIFF crop as an RGB uint8 image.

    This function does NOT need access to the original orthomosaic.
    The crop GeoTIFF must contain RGB pixel data, a CRS and an affine
    transform.

    Args:
        crop_tif_path: path to the crop GeoTIFF (anything ``str()`` accepts).
        rgb_bands: 1-based band indexes to read, in output channel order.
        max_crop_side: maximum allowed width and height in pixels.

    Returns:
        A 4-tuple ``(image, transform, crs, metadata)``:
        - image: ``(height, width, len(rgb_bands))`` uint8 array
        - transform: the crop's affine transform
        - crs: the crop's CRS
        - metadata: dict with ``width``, ``height``, ``count``, ``crs`` (str)
          and ``bounds`` (left / bottom / right / top floats)

    Raises:
        ValueError: invalid ``rgb_bands``, missing CRS, missing (identity)
            transform, too few bands, or crop larger than ``max_crop_side``.
    """
    crop_tif_path = str(crop_tif_path)
    rgb_bands = tuple(rgb_bands)

    # Band indexes are 1-based; reject bad requests up front with a clear
    # message instead of an opaque error from max() or the raster reader.
    if not rgb_bands or min(rgb_bands) < 1:
        raise ValueError(
            f"rgb_bands must contain 1-based band indexes, got {rgb_bands}."
        )

    logger.info("Reading georeferenced crop: %s", crop_tif_path)

    with rasterio.open(crop_tif_path) as src:
        if src.crs is None:
            raise ValueError("The crop GeoTIFF has no CRS.")

        # rasterio reports an identity Affine (not None) for files without a
        # geotransform, so the identity case must be rejected explicitly;
        # otherwise polygons would come out in pixel coordinates.
        if src.transform is None or src.transform.is_identity:
            raise ValueError("The crop GeoTIFF has no affine transform.")

        if src.count < max(rgb_bands):
            raise ValueError(
                f"The crop has only {src.count} band(s), "
                f"but rgb_bands={rgb_bands} was requested."
            )

        if src.width > max_crop_side or src.height > max_crop_side:
            raise ValueError(
                f"Crop too large: {src.width}x{src.height} px. "
                f"Maximum allowed side is {max_crop_side} px."
            )

        arr = src.read(rgb_bands)
        crop_transform = src.transform
        crop_crs = src.crs
        crop_bounds = src.bounds

        metadata = {
            "width": int(src.width),
            "height": int(src.height),
            "count": int(src.count),
            "crs": str(src.crs),
            "bounds": {
                "left": float(crop_bounds.left),
                "bottom": float(crop_bounds.bottom),
                "right": float(crop_bounds.right),
                "top": float(crop_bounds.top),
            },
        }

    # (bands, rows, cols) -> (rows, cols, bands)
    arr = np.transpose(arr, (1, 2, 0))
    arr = np.nan_to_num(arr)

    if arr.dtype != np.uint8:
        # Values are clipped, not rescaled.
        logger.warning(
            "Crop dtype is %s, converting to uint8 by clipping to [0, 255].",
            arr.dtype,
        )
        arr = np.clip(arr, 0, 255).astype(np.uint8)

    logger.info(
        "Crop loaded | shape=%s crs=%s",
        arr.shape,
        metadata["crs"],
    )

    return arr, crop_transform, crop_crs, metadata
| # --------------------------------------------------------------------------- | |
| # Mask → GeoDataFrame | |
| # --------------------------------------------------------------------------- | |
def masks_to_geodataframe(
    masks_data: list,
    crop_transform,
    crop_crs,
    image_shape: tuple,
    min_area_px: int = 1000,
    max_area_frac: float = 0.20,
    min_iou: float = 0.35,
    min_stability: float = 0.65,
    border_margin: int = 10,
) -> gpd.GeoDataFrame:
    """
    Convert raw SAM2 masks to a filtered GeoDataFrame of polygons.

    The important call is ``shapes(..., transform=crop_transform)``: it
    converts mask pixel coordinates into real map coordinates using the crop
    GeoTIFF georeference.

    Args:
        masks_data: mask dicts; each must provide ``area``,
            ``predicted_iou``, ``stability_score`` and ``segmentation``.
        crop_transform: affine transform of the crop.
        crop_crs: CRS assigned to the resulting GeoDataFrame.
        image_shape: shape of the segmented image; only the first two
            entries (height, width) are used.
        min_area_px: masks with a smaller ``area`` are dropped.
        max_area_frac: masks larger than this fraction of the image area
            are dropped.
        min_iou: minimum ``predicted_iou``.
        min_stability: minimum ``stability_score``.
        border_margin: masks with any pixel within this many pixels of an
            image edge are dropped; 0 disables the check.

    Returns:
        GeoDataFrame with columns ``geometry``, ``mask_id``,
        ``predicted_iou``, ``stability_score`` and ``area_px``. It is empty
        (but still has these columns and the CRS) when nothing survives.
    """
    H, W = image_shape[:2]
    max_area_px = int(H * W * max_area_frac)

    logger.info(
        "Filtering masks | area=[%d, %d] px IOU>=%.2f Stability>=%.2f border_margin=%d",
        min_area_px,
        max_area_px,
        min_iou,
        min_stability,
        border_margin,
    )

    records = []

    for mask_id, m in enumerate(masks_data):
        area_px = int(m["area"])

        if area_px < min_area_px or area_px > max_area_px:
            continue
        if m["predicted_iou"] < min_iou:
            continue
        if m["stability_score"] < min_stability:
            continue

        mask_u8 = m["segmentation"].astype(np.uint8)
        rows, cols = np.where(mask_u8)

        if len(rows) == 0:
            continue

        # The last valid index is H-1 / W-1, so the far edges need ">=" to
        # mirror the "<" used for the near edges. With ">" the bottom/right
        # margin was one pixel narrower, and border_margin=1 never matched.
        touches_border = (
            rows.min() < border_margin
            or rows.max() >= H - border_margin
            or cols.min() < border_margin
            or cols.max() >= W - border_margin
        )
        if touches_border:
            continue

        # One mask may yield several polygons (disconnected components);
        # each becomes its own record sharing the same mask_id.
        for geom_dict, val in shapes(
            mask_u8,
            mask=mask_u8.astype(bool),
            transform=crop_transform,
        ):
            if val != 1:
                continue

            geom = shape(geom_dict)
            if geom.is_empty:
                continue

            records.append(
                {
                    "geometry": geom,
                    "mask_id": mask_id,
                    "predicted_iou": float(m["predicted_iou"]),
                    "stability_score": float(m["stability_score"]),
                    "area_px": area_px,
                }
            )

    if records:
        gdf = gpd.GeoDataFrame(records, geometry="geometry", crs=crop_crs)
    else:
        # An empty record list has no "geometry" column, and naming it as the
        # geometry then raises. Declare the columns explicitly so callers get
        # a valid empty frame.
        gdf = gpd.GeoDataFrame(
            columns=[
                "geometry",
                "mask_id",
                "predicted_iou",
                "stability_score",
                "area_px",
            ],
            geometry="geometry",
            crs=crop_crs,
        )

    logger.info("Retained %d mask geometries after filtering.", len(gdf))
    return gdf
| # --------------------------------------------------------------------------- | |
| # Main orchestrator | |
| # --------------------------------------------------------------------------- | |
def run_geoglyph_sam2_on_crop(
    crop_tif_path: str,
    output_gpkg: str,
    layer_name: str = "sam2_geoglyph_detections",
    device: str | None = None,
    # Preprocessing
    use_clahe: bool = True,
    clahe_clip: float = 4.0,
    clahe_grid: int = 6,
    # SAM2 hyperparameters
    sam2_points_per_side: int = 32,
    sam2_points_per_batch: int = 32,
    sam2_pred_iou_thresh: float = 0.35,
    sam2_stability_score_thresh: float = 0.65,
    # Postprocessing filters
    filter_min_area_px: int = 1000,
    filter_max_area_frac: float = 0.20,
    filter_min_iou: float = 0.35,
    filter_min_stability: float = 0.65,
    filter_border_margin: int = 10,
    # Safety
    max_crop_side: int = 4096,
) -> dict:
    """
    End-to-end geoglyph detection pipeline from a georeferenced crop.

    This function does NOT receive the original orthomosaic path, a bbox or
    a bbox CRS. It only receives a small crop GeoTIFF with CRS and transform.

    Steps: read the crop, preprocess it, run the SAM2 mask generator,
    filter and vectorize the masks, then write the surviving polygons to a
    GeoPackage layer.

    Args:
        crop_tif_path: path to the crop GeoTIFF.
        output_gpkg: path of the GeoPackage to write.
        layer_name: layer name inside the GeoPackage.
        device: "cuda", "cpu" or None (auto); see ``resolve_device``.
        use_clahe, clahe_clip, clahe_grid: forwarded to ``preprocess``.
        sam2_*: forwarded to ``load_sam2_model``.
        filter_*: forwarded to ``masks_to_geodataframe``.
        max_crop_side: forwarded to ``read_georeferenced_crop``.

    Returns:
        dict with keys ``output_gpkg`` (str), ``layer_name``, ``n_masks``
        (number of exported geometries), ``input_mode``, ``crop`` (crop
        metadata) and ``output_exists`` (True only when this call wrote
        the file).
    """
    device = resolve_device(device)

    logger.info("=" * 60)
    logger.info("STARTING GEOGLYPH SAM2 INFERENCE ON CROP")
    logger.info("=" * 60)
    logger.info("Input crop: %s", crop_tif_path)
    logger.info("Output GPKG: %s", output_gpkg)
    logger.info("Requested device: %s", device)

    crop_tif_path = str(crop_tif_path)
    output_gpkg = Path(output_gpkg)

    arr_raw, crop_transform, crop_crs, crop_metadata = read_georeferenced_crop(
        crop_tif_path=crop_tif_path,
        max_crop_side=max_crop_side,
    )

    if use_clahe:
        logger.info(
            "Applying CLAHE | clip=%.1f grid=%d",
            clahe_clip,
            clahe_grid,
        )

    # preprocess() is always called; use_clahe is passed through so it
    # decides what to do when CLAHE is off.
    arr_processed = preprocess(
        arr_raw,
        use_clahe=use_clahe,
        clip=clahe_clip,
        grid=clahe_grid,
    )

    mask_generator = load_sam2_model(
        device=device,
        points_per_side=sam2_points_per_side,
        points_per_batch=sam2_points_per_batch,
        pred_iou_thresh=sam2_pred_iou_thresh,
        stability_score_thresh=sam2_stability_score_thresh,
    )

    logger.info("Generating masks...")
    with torch.inference_mode():
        masks_data = mask_generator.generate(arr_processed)
    logger.info("SAM2 generated %d raw masks.", len(masks_data))

    gdf = masks_to_geodataframe(
        masks_data=masks_data,
        crop_transform=crop_transform,
        crop_crs=crop_crs,
        image_shape=arr_processed.shape,
        min_area_px=filter_min_area_px,
        max_area_frac=filter_max_area_frac,
        min_iou=filter_min_iou,
        min_stability=filter_min_stability,
        border_margin=filter_border_margin,
    )

    if len(gdf) > 0:
        # Provenance columns, repeated on every exported row.
        gdf["source_crop"] = crop_tif_path
        gdf["input_mode"] = "georeferenced_crop"
        gdf["crop_width"] = crop_metadata["width"]
        gdf["crop_height"] = crop_metadata["height"]
        gdf["crop_crs"] = crop_metadata["crs"]

        logger.info(
            "Exporting %d geometries → %s layer=%s",
            len(gdf),
            output_gpkg,
            layer_name,
        )

        # Create the destination directory so a missing folder does not
        # throw away the inference result at the final step.
        output_gpkg.parent.mkdir(parents=True, exist_ok=True)

        gdf.to_file(output_gpkg, layer=layer_name, driver="GPKG")
        output_exists = True
    else:
        logger.warning("No geometries to export after filtering.")
        output_exists = False

    logger.info("=" * 60)
    logger.info("INFERENCE COMPLETED")
    logger.info("=" * 60)

    return {
        "output_gpkg": str(output_gpkg),
        "layer_name": layer_name,
        "n_masks": len(gdf),
        "input_mode": "georeferenced_crop",
        "crop": crop_metadata,
        "output_exists": output_exists,
    }