# src/infer.py
# SAM2 geoglyph inference pipeline using a georeferenced crop GeoTIFF.
#
# Important:
# This file does NOT receive or open the full orthomosaic.
# It only receives a small crop GeoTIFF containing:
# - RGB pixels
# - CRS
# - affine transform

from pathlib import Path
import logging

import numpy as np
import torch
import rasterio
from rasterio.features import shapes
import geopandas as gpd
from shapely.geometry import shape

from src.preprocess import preprocess

logger = logging.getLogger("pipeline")

# ---------------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------------

# Maps (device, points_per_side, points_per_batch, pred_iou_thresh,
# stability_score_thresh) -> loaded mask generator.
_MODEL_CACHE = {}

# Column layout of the GeoDataFrame built by masks_to_geodataframe. Declared
# once so that an empty result still carries the same schema.
_MASK_COLUMNS = [
    "geometry",
    "mask_id",
    "predicted_iou",
    "stability_score",
    "area_px",
]


# ---------------------------------------------------------------------------
# Device handling
# ---------------------------------------------------------------------------


def resolve_device(device: str | None) -> str:
    """
    Normalise the requested compute device to "cuda" or "cpu".

    Args:
        device: "cuda", "cpu" (case/whitespace-insensitive), or None to
            auto-select CUDA when torch reports it as available.

    Returns:
        Either "cuda" or "cpu".

    Raises:
        ValueError: if the name is neither "cuda" nor "cpu".
        RuntimeError: if "cuda" is requested but CUDA is not available.
    """
    if device is None:
        return "cuda" if torch.cuda.is_available() else "cpu"

    device = device.lower().strip()

    if device not in {"cuda", "cpu"}:
        raise ValueError(f"Invalid device: {device}. Expected 'cuda' or 'cpu'.")

    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")

    return device


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------


def load_sam2_model(
    device: str = "cuda",
    points_per_side: int = 32,
    points_per_batch: int = 32,
    pred_iou_thresh: float = 0.35,
    stability_score_thresh: float = 0.65,
):
    """
    Load the SAM2 automatic mask generator from Hugging Face.

    The model is cached by device and SAM2 hyperparameters so the API
    does not reload the model for every task.

    Args:
        device: device name forwarded to SAM2 ("cuda" or "cpu").
        points_per_side: forwarded to SAM2AutomaticMaskGenerator.
        points_per_batch: forwarded to SAM2AutomaticMaskGenerator.
        pred_iou_thresh: forwarded to SAM2AutomaticMaskGenerator.
        stability_score_thresh: forwarded to SAM2AutomaticMaskGenerator.

    Returns:
        The (possibly cached) SAM2AutomaticMaskGenerator instance.

    Raises:
        RuntimeError: if the loaded model's parameters live on a different
            device type than the one requested.
    """
    # Imported lazily so that importing this module does not require sam2.
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    key = (
        device,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
    )

    if key in _MODEL_CACHE:
        logger.info("Reusing cached SAM2 model | device=%s", device)
        return _MODEL_CACHE[key]

    logger.info(
        "Loading SAM2 model | device=%s PPS=%d PPB=%d IOU_thresh=%.2f Stability_thresh=%.2f",
        device,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
    )

    mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
        "facebook/sam2.1-hiera-large",
        device=device,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        crop_n_layers=0,
        multimask_output=False,
        use_m2m=False,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )

    # Verify where the weights actually ended up; the attributes are looked up
    # defensively because they are not part of an interface this file controls.
    predictor = getattr(mask_generator, "predictor", None)
    model = getattr(predictor, "model", None)

    if model is not None:
        actual_device = next(model.parameters()).device
        logger.info("Actual SAM2 model device: %s", actual_device)

        if device == "cpu" and actual_device.type != "cpu":
            raise RuntimeError(
                f"Expected SAM2 to run on CPU, but model is on {actual_device}."
            )

        if device == "cuda" and actual_device.type != "cuda":
            raise RuntimeError(
                f"Expected SAM2 to run on CUDA, but model is on {actual_device}."
            )
    else:
        logger.warning("Could not inspect actual SAM2 model device.")

    # Only cache after the device check has passed.
    _MODEL_CACHE[key] = mask_generator
    return mask_generator


# ---------------------------------------------------------------------------
# Crop GeoTIFF I/O
# ---------------------------------------------------------------------------


def read_georeferenced_crop(
    crop_tif_path: str,
    rgb_bands: tuple = (1, 2, 3),
    max_crop_side: int = 4096,
):
    """
    Read a small georeferenced GeoTIFF crop.

    This function does NOT need access to the original orthomosaic.
    The crop GeoTIFF must contain:
    - RGB pixel data
    - CRS
    - affine transform

    Args:
        crop_tif_path: path to the crop GeoTIFF.
        rgb_bands: 1-based band indexes to read, in output channel order.
        max_crop_side: maximum allowed width and height in pixels.

    Returns:
        A 4-tuple of:
        - image as a uint8 array of shape (height, width, len(rgb_bands))
        - crop transform
        - crop CRS
        - crop metadata dict (width, height, count, crs, bounds)

    Raises:
        ValueError: if rgb_bands is empty or contains an index below 1, if
            the crop has no CRS or transform, has too few bands, or exceeds
            max_crop_side on either axis.
    """
    crop_tif_path = str(crop_tif_path)

    # Band indexes are 1-based; reject bad requests with a clear message
    # instead of an opaque max()/reader error further down.
    if not rgb_bands or min(rgb_bands) < 1:
        raise ValueError(
            f"rgb_bands must contain 1-based band indexes, got {rgb_bands}."
        )

    logger.info("Reading georeferenced crop: %s", crop_tif_path)

    with rasterio.open(crop_tif_path) as src:
        if src.crs is None:
            raise ValueError("The crop GeoTIFF has no CRS.")

        if src.transform is None:
            raise ValueError("The crop GeoTIFF has no affine transform.")

        if src.count < max(rgb_bands):
            raise ValueError(
                f"The crop has only {src.count} band(s), "
                f"but rgb_bands={rgb_bands} was requested."
            )

        # Size guard runs before any pixel data is read.
        if src.width > max_crop_side or src.height > max_crop_side:
            raise ValueError(
                f"Crop too large: {src.width}x{src.height} px. "
                f"Maximum allowed side is {max_crop_side} px."
            )

        arr = src.read(list(rgb_bands))
        crop_transform = src.transform
        crop_crs = src.crs
        crop_bounds = src.bounds

        metadata = {
            "width": int(src.width),
            "height": int(src.height),
            "count": int(src.count),
            "crs": str(src.crs),
            "bounds": {
                "left": float(crop_bounds.left),
                "bottom": float(crop_bounds.bottom),
                "right": float(crop_bounds.right),
                "top": float(crop_bounds.top),
            },
        }

    # (bands, rows, cols) -> (rows, cols, bands)
    arr = np.transpose(arr, (1, 2, 0))
    arr = np.nan_to_num(arr)

    if arr.dtype != np.uint8:
        # Values are clipped, not rescaled: anything above 255 saturates.
        logger.warning(
            "Crop dtype is %s, converting to uint8 by clipping to [0, 255].",
            arr.dtype,
        )
        arr = np.clip(arr, 0, 255).astype(np.uint8)

    logger.info(
        "Crop loaded | shape=%s crs=%s",
        arr.shape,
        metadata["crs"],
    )

    return arr, crop_transform, crop_crs, metadata


# ---------------------------------------------------------------------------
# Mask → GeoDataFrame
# ---------------------------------------------------------------------------


def masks_to_geodataframe(
    masks_data: list,
    crop_transform,
    crop_crs,
    image_shape: tuple,
    min_area_px: int = 1000,
    max_area_frac: float = 0.20,
    min_iou: float = 0.35,
    min_stability: float = 0.65,
    border_margin: int = 10,
) -> gpd.GeoDataFrame:
    """
    Convert raw SAM2 masks to a filtered GeoDataFrame of polygons.

    The important line is:
        shapes(..., transform=crop_transform)
    This converts mask pixel coordinates into real map coordinates
    using the crop GeoTIFF georeference.

    Args:
        masks_data: list of mask dicts; each must provide "area",
            "predicted_iou", "stability_score" and "segmentation".
        crop_transform: affine transform of the crop.
        crop_crs: CRS assigned to the resulting GeoDataFrame.
        image_shape: shape of the image the masks were generated on; only
            the first two entries (height, width) are used.
        min_area_px: masks with a smaller pixel area are dropped.
        max_area_frac: masks covering more than this fraction of the image
            are dropped.
        min_iou: minimum "predicted_iou" to keep a mask.
        min_stability: minimum "stability_score" to keep a mask.
        border_margin: masks with any pixel within this many pixels of an
            image edge are dropped (0 disables the border filter).

    Returns:
        GeoDataFrame with one row per polygon and the columns geometry,
        mask_id, predicted_iou, stability_score and area_px. It has zero
        rows (but the same columns and CRS) when nothing passes the filters.
    """
    H, W = image_shape[:2]
    max_area_px = int(H * W * max_area_frac)

    logger.info(
        "Filtering masks | area=[%d, %d] px IOU>=%.2f Stability>=%.2f border_margin=%d",
        min_area_px,
        max_area_px,
        min_iou,
        min_stability,
        border_margin,
    )

    records = []

    for mask_id, m in enumerate(masks_data):
        area_px = int(m["area"])

        if area_px < min_area_px or area_px > max_area_px:
            continue

        if m["predicted_iou"] < min_iou:
            continue

        if m["stability_score"] < min_stability:
            continue

        mask_u8 = m["segmentation"].astype(np.uint8)
        rows, cols = np.where(mask_u8)

        if len(rows) == 0:
            continue

        # The margin band is `border_margin` pixels wide on every side:
        # indexes [0, margin) on the near edges and [size - margin, size)
        # on the far edges, hence `>=` for the max tests.
        touches_border = (
            rows.min() < border_margin
            or rows.max() >= H - border_margin
            or cols.min() < border_margin
            or cols.max() >= W - border_margin
        )

        if touches_border:
            continue

        for geom_dict, val in shapes(
            mask_u8,
            mask=mask_u8.astype(bool),
            transform=crop_transform,
        ):
            if val != 1:
                continue

            geom = shape(geom_dict)

            if geom.is_empty:
                continue

            records.append(
                {
                    "geometry": geom,
                    "mask_id": mask_id,
                    "predicted_iou": float(m["predicted_iou"]),
                    "stability_score": float(m["stability_score"]),
                    "area_px": area_px,
                }
            )

    # Explicit columns keep the constructor working when `records` is empty;
    # without them there is no "geometry" column to set as the active geometry.
    gdf = gpd.GeoDataFrame(
        records,
        columns=_MASK_COLUMNS,
        geometry="geometry",
        crs=crop_crs,
    )

    logger.info("Retained %d mask geometries after filtering.", len(gdf))

    return gdf


# ---------------------------------------------------------------------------
# Main orchestrator
# ---------------------------------------------------------------------------


def run_geoglyph_sam2_on_crop(
    crop_tif_path: str,
    output_gpkg: str,
    layer_name: str = "sam2_geoglyph_detections",
    device: str | None = None,
    # Preprocessing
    use_clahe: bool = True,
    clahe_clip: float = 4.0,
    clahe_grid: int = 6,
    # SAM2 hyperparameters
    sam2_points_per_side: int = 32,
    sam2_points_per_batch: int = 32,
    sam2_pred_iou_thresh: float = 0.35,
    sam2_stability_score_thresh: float = 0.65,
    # Postprocessing filters
    filter_min_area_px: int = 1000,
    filter_max_area_frac: float = 0.20,
    filter_min_iou: float = 0.35,
    filter_min_stability: float = 0.65,
    filter_border_margin: int = 10,
    # Safety
    max_crop_side: int = 4096,
) -> dict:
    """
    End-to-end geoglyph detection pipeline from a georeferenced crop.

    This function does NOT receive:
    - original orthomosaic path
    - bbox
    - bbox CRS

    It only receives a small crop GeoTIFF with CRS and transform.

    Steps: read the crop, preprocess it, run the SAM2 mask generator,
    filter/vectorise the masks and, if any geometry remains, write it to
    a GeoPackage layer.

    Args:
        crop_tif_path: path to the crop GeoTIFF.
        output_gpkg: path of the GeoPackage to write.
        layer_name: layer name inside the GeoPackage.
        device: "cuda", "cpu" or None for auto-selection.
        use_clahe, clahe_clip, clahe_grid: forwarded to preprocess().
        sam2_*: forwarded to load_sam2_model().
        filter_*: forwarded to masks_to_geodataframe().
        max_crop_side: maximum crop width/height accepted, in pixels.

    Returns:
        Dict with keys "output_gpkg", "layer_name", "n_masks",
        "input_mode", "crop" (crop metadata) and "output_exists" (True only
        when this call wrote the GeoPackage).
    """
    device = resolve_device(device)

    logger.info("=" * 60)
    logger.info("STARTING GEOGLYPH SAM2 INFERENCE ON CROP")
    logger.info("=" * 60)
    logger.info("Input crop: %s", crop_tif_path)
    logger.info("Output GPKG: %s", output_gpkg)
    logger.info("Requested device: %s", device)

    crop_tif_path = str(crop_tif_path)
    output_gpkg = Path(output_gpkg)

    arr_raw, crop_transform, crop_crs, crop_metadata = read_georeferenced_crop(
        crop_tif_path=crop_tif_path,
        max_crop_side=max_crop_side,
    )

    if use_clahe:
        logger.info(
            "Applying CLAHE | clip=%.1f grid=%d",
            clahe_clip,
            clahe_grid,
        )

    # preprocess() is always called; the use_clahe flag is forwarded to it.
    arr_processed = preprocess(
        arr_raw,
        use_clahe=use_clahe,
        clip=clahe_clip,
        grid=clahe_grid,
    )

    mask_generator = load_sam2_model(
        device=device,
        points_per_side=sam2_points_per_side,
        points_per_batch=sam2_points_per_batch,
        pred_iou_thresh=sam2_pred_iou_thresh,
        stability_score_thresh=sam2_stability_score_thresh,
    )

    logger.info("Generating masks...")

    with torch.inference_mode():
        masks_data = mask_generator.generate(arr_processed)

    logger.info("SAM2 generated %d raw masks.", len(masks_data))

    gdf = masks_to_geodataframe(
        masks_data=masks_data,
        crop_transform=crop_transform,
        crop_crs=crop_crs,
        image_shape=arr_processed.shape,
        min_area_px=filter_min_area_px,
        max_area_frac=filter_max_area_frac,
        min_iou=filter_min_iou,
        min_stability=filter_min_stability,
        border_margin=filter_border_margin,
    )

    if len(gdf) > 0:
        # Provenance columns describing the crop the detections came from.
        gdf["source_crop"] = crop_tif_path
        gdf["input_mode"] = "georeferenced_crop"
        gdf["crop_width"] = crop_metadata["width"]
        gdf["crop_height"] = crop_metadata["height"]
        gdf["crop_crs"] = crop_metadata["crs"]

        logger.info(
            "Exporting %d geometries → %s layer=%s",
            len(gdf),
            output_gpkg,
            layer_name,
        )

        # Make sure the destination directory exists before writing.
        output_gpkg.parent.mkdir(parents=True, exist_ok=True)
        gdf.to_file(output_gpkg, layer=layer_name, driver="GPKG")
        output_exists = True
    else:
        logger.warning("No geometries to export after filtering.")
        output_exists = False

    logger.info("=" * 60)
    logger.info("INFERENCE COMPLETED")
    logger.info("=" * 60)

    return {
        "output_gpkg": str(output_gpkg),
        "layer_name": layer_name,
        "n_masks": len(gdf),
        "input_mode": "georeferenced_crop",
        "crop": crop_metadata,
        "output_exists": output_exists,
    }