Geoglyph_SAM2 / src /infer.py
JuanHernandez-uc
fix
bd5ab52
Raw
History Blame Contribute Delete
12.4 kB
# src/infer.py
# SAM2 geoglyph inference pipeline using a georeferenced crop GeoTIFF.
#
# Important:
# This file does NOT receive or open the full orthomosaic.
# It only receives a small crop GeoTIFF containing:
# - RGB pixels
# - CRS
# - affine transform
from pathlib import Path
import logging
import numpy as np
import torch
import rasterio
from rasterio.features import shapes
import geopandas as gpd
from shapely.geometry import shape
from src.preprocess import preprocess
# Module-wide logger. It uses the fixed name "pipeline" instead of __name__;
# presumably so it shares handlers/levels with the rest of the pipeline —
# TODO confirm, no logging configuration is visible in this file.
logger = logging.getLogger("pipeline")
# ---------------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------------
# Maps (device, points_per_side, points_per_batch, pred_iou_thresh,
# stability_score_thresh) -> the SAM2 mask generator built for that
# combination. Entries are added by load_sam2_model(); nothing in this file
# removes them.
_MODEL_CACHE: dict[tuple, object] = {}
# ---------------------------------------------------------------------------
# Device handling
# ---------------------------------------------------------------------------
def resolve_device(device: str | None) -> str:
if device is None:
return "cuda" if torch.cuda.is_available() else "cpu"
device = device.lower().strip()
if device not in {"cuda", "cpu"}:
raise ValueError(f"Invalid device: {device}. Expected 'cuda' or 'cpu'.")
if device == "cuda" and not torch.cuda.is_available():
raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")
return device
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_sam2_model(
    device: str = "cuda",
    points_per_side: int = 32,
    points_per_batch: int = 32,
    pred_iou_thresh: float = 0.35,
    stability_score_thresh: float = 0.65,
):
    """
    Return a SAM2 automatic mask generator, building it only when needed.

    The generator is created with
    SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2.1-hiera-large")
    and stored in the module-level cache under the tuple
    (device, points_per_side, points_per_batch, pred_iou_thresh,
    stability_score_thresh), so repeated calls with the same settings reuse
    the already-loaded model.

    Args:
        device: Device string forwarded to from_pretrained ("cuda" or "cpu").
        points_per_side: Forwarded to the mask generator.
        points_per_batch: Forwarded to the mask generator.
        pred_iou_thresh: Forwarded to the mask generator.
        stability_score_thresh: Forwarded to the mask generator.

    Returns:
        The (possibly cached) mask generator instance.

    Raises:
        RuntimeError: the loaded model's parameters are not on the requested
            device. In that case the generator is not cached.
    """
    # Imported inside the function, so sam2 is only required once a model is
    # actually requested.
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    # One dict feeds both the cache key and the constructor call, keeping the
    # two in sync. Insertion order fixes the order of the key tuple.
    sam2_settings = {
        "points_per_side": points_per_side,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
    }
    cache_key = (device, *sam2_settings.values())

    if cache_key in _MODEL_CACHE:
        logger.info("Reusing cached SAM2 model | device=%s", device)
        return _MODEL_CACHE[cache_key]

    logger.info(
        "Loading SAM2 model | device=%s PPS=%d PPB=%d IOU_thresh=%.2f Stability_thresh=%.2f",
        device,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
    )
    generator = SAM2AutomaticMaskGenerator.from_pretrained(
        "facebook/sam2.1-hiera-large",
        device=device,
        crop_n_layers=0,
        multimask_output=False,
        use_m2m=False,
        **sam2_settings,
    )

    # Verify where the weights actually ended up; getattr keeps this
    # best-effort if the generator does not expose predictor/model.
    sam2_predictor = getattr(generator, "predictor", None)
    sam2_model = getattr(sam2_predictor, "model", None)
    if sam2_model is None:
        logger.warning("Could not inspect actual SAM2 model device.")
    else:
        actual_device = next(sam2_model.parameters()).device
        logger.info("Actual SAM2 model device: %s", actual_device)
        placed_on = actual_device.type
        if device == "cpu" and placed_on != "cpu":
            raise RuntimeError(
                f"Expected SAM2 to run on CPU, but model is on {actual_device}."
            )
        if device == "cuda" and placed_on != "cuda":
            raise RuntimeError(
                f"Expected SAM2 to run on CUDA, but model is on {actual_device}."
            )

    # Only cache once the device check has passed.
    _MODEL_CACHE[cache_key] = generator
    return generator
# ---------------------------------------------------------------------------
# Crop GeoTIFF I/O
# ---------------------------------------------------------------------------
def read_georeferenced_crop(
    crop_tif_path: str,
    rgb_bands: tuple = (1, 2, 3),
    max_crop_side: int = 4096,
):
    """
    Load the requested bands of a georeferenced crop as a uint8 image.

    Only the crop file itself is opened; the original orthomosaic is never
    touched.

    Args:
        crop_tif_path: Path of the crop GeoTIFF (converted with str()).
        rgb_bands: Band indexes passed to the dataset's read() call.
        max_crop_side: Largest accepted width or height, in pixels.

    Returns:
        A 4-tuple (image, transform, crs, metadata):
        - image: uint8 array with the band axis moved last.
        - transform: the dataset's affine transform.
        - crs: the dataset's CRS object.
        - metadata: dict with "width", "height", "count", "crs" (string)
          and "bounds" (left/bottom/right/top floats).

    Raises:
        ValueError: the dataset has no CRS, no transform, fewer bands than
            the highest requested index, or a side above max_crop_side.
    """
    crop_tif_path = str(crop_tif_path)
    logger.info("Reading georeferenced crop: %s", crop_tif_path)

    with rasterio.open(crop_tif_path) as dataset:
        # Validate before reading any pixels so bad inputs fail cheaply.
        if dataset.crs is None:
            raise ValueError("The crop GeoTIFF has no CRS.")
        if dataset.transform is None:
            raise ValueError("The crop GeoTIFF has no affine transform.")
        if dataset.count < max(rgb_bands):
            raise ValueError(
                f"The crop has only {dataset.count} band(s), "
                f"but rgb_bands={rgb_bands} was requested."
            )
        if max(dataset.width, dataset.height) > max_crop_side:
            raise ValueError(
                f"Crop too large: {dataset.width}x{dataset.height} px. "
                f"Maximum allowed side is {max_crop_side} px."
            )

        pixels = dataset.read(rgb_bands)
        crop_transform = dataset.transform
        crop_crs = dataset.crs

        extent = dataset.bounds
        metadata = {
            "width": int(dataset.width),
            "height": int(dataset.height),
            "count": int(dataset.count),
            "crs": str(dataset.crs),
            "bounds": {
                side: float(getattr(extent, side))
                for side in ("left", "bottom", "right", "top")
            },
        }

    # Move the band axis last, then replace NaN/inf with finite values.
    image = np.nan_to_num(pixels.transpose(1, 2, 0))

    if image.dtype != np.uint8:
        # Values are clipped, not rescaled: anything above 255 saturates.
        logger.warning(
            "Crop dtype is %s, converting to uint8 by clipping to [0, 255].",
            image.dtype,
        )
        image = np.clip(image, 0, 255).astype(np.uint8)

    logger.info(
        "Crop loaded | shape=%s crs=%s",
        image.shape,
        metadata["crs"],
    )
    return image, crop_transform, crop_crs, metadata
# ---------------------------------------------------------------------------
# Mask → GeoDataFrame
# ---------------------------------------------------------------------------
def masks_to_geodataframe(
    masks_data: list,
    crop_transform,
    crop_crs,
    image_shape: tuple,
    min_area_px: int = 1000,
    max_area_frac: float = 0.20,
    min_iou: float = 0.35,
    min_stability: float = 0.65,
    border_margin: int = 10,
) -> gpd.GeoDataFrame:
    """
    Convert raw SAM2 masks to a filtered GeoDataFrame of polygons.

    A mask is dropped when any of the following holds:
    - its "area" is below min_area_px or above
      int(H * W * max_area_frac);
    - its "predicted_iou" is below min_iou;
    - its "stability_score" is below min_stability;
    - its "segmentation" has no set pixel;
    - any set pixel lies within border_margin pixels of an image edge
      (the same width is applied to all four sides).

    Surviving masks are polygonised with
        shapes(..., transform=crop_transform)
    which converts mask pixel coordinates into map coordinates using the
    crop GeoTIFF georeference.

    Args:
        masks_data: Mask dicts with the keys "area", "predicted_iou",
            "stability_score" and "segmentation" (an array).
        crop_transform: Affine transform of the crop.
        crop_crs: CRS assigned to the resulting GeoDataFrame.
        image_shape: Shape of the segmented image; only the first two
            entries (height, width) are used.
        min_area_px: Minimum accepted mask area in pixels.
        max_area_frac: Maximum accepted mask area as a fraction of H * W.
        min_iou: Minimum accepted "predicted_iou".
        min_stability: Minimum accepted "stability_score".
        border_margin: Width in pixels of the rejected border band.

    Returns:
        GeoDataFrame with the columns geometry, mask_id, predicted_iou,
        stability_score and area_px. One row is emitted per polygon, so a
        mask made of several disconnected parts yields several rows that
        share the same mask_id and area_px. When nothing survives the
        filters, an empty GeoDataFrame with the same columns is returned.
    """
    H, W = image_shape[:2]
    max_area_px = int(H * W * max_area_frac)
    logger.info(
        "Filtering masks | area=[%d, %d] px IOU>=%.2f Stability>=%.2f border_margin=%d",
        min_area_px,
        max_area_px,
        min_iou,
        min_stability,
        border_margin,
    )

    records = []
    for mask_id, m in enumerate(masks_data):
        area_px = int(m["area"])
        if area_px < min_area_px or area_px > max_area_px:
            continue
        if m["predicted_iou"] < min_iou:
            continue
        if m["stability_score"] < min_stability:
            continue

        mask_u8 = m["segmentation"].astype(np.uint8)
        rows, cols = np.where(mask_u8)
        if len(rows) == 0:
            continue

        # Valid indices run 0..H-1 / 0..W-1, so the far-side band starts at
        # H - border_margin (inclusive). Using ">=" makes the bottom/right
        # band exactly as wide as the top/left one; a strict ">" would let
        # masks touching the last row/column through when border_margin=1.
        touches_border = (
            rows.min() < border_margin
            or rows.max() >= H - border_margin
            or cols.min() < border_margin
            or cols.max() >= W - border_margin
        )
        if touches_border:
            continue

        for geom_dict, val in shapes(
            mask_u8,
            mask=mask_u8.astype(bool),
            transform=crop_transform,
        ):
            if val != 1:
                continue
            geom = shape(geom_dict)
            if geom.is_empty:
                continue
            records.append(
                {
                    "geometry": geom,
                    "mask_id": mask_id,
                    "predicted_iou": float(m["predicted_iou"]),
                    "stability_score": float(m["stability_score"]),
                    "area_px": area_px,
                }
            )

    if records:
        gdf = gpd.GeoDataFrame(records, geometry="geometry", crs=crop_crs)
    else:
        # An empty record list produces a frame without a "geometry" column,
        # and GeoDataFrame(..., geometry="geometry") then raises
        # "Unknown column geometry". Declare the columns explicitly so the
        # caller receives a well-formed empty result instead of an exception.
        gdf = gpd.GeoDataFrame(
            columns=[
                "geometry",
                "mask_id",
                "predicted_iou",
                "stability_score",
                "area_px",
            ],
            geometry="geometry",
            crs=crop_crs,
        )

    logger.info("Retained %d mask geometries after filtering.", len(gdf))
    return gdf
# ---------------------------------------------------------------------------
# Main orchestrator
# ---------------------------------------------------------------------------
def run_geoglyph_sam2_on_crop(
    crop_tif_path: str,
    output_gpkg: str,
    layer_name: str = "sam2_geoglyph_detections",
    device: str | None = None,
    # Preprocessing
    use_clahe: bool = True,
    clahe_clip: float = 4.0,
    clahe_grid: int = 6,
    # SAM2 hyperparameters
    sam2_points_per_side: int = 32,
    sam2_points_per_batch: int = 32,
    sam2_pred_iou_thresh: float = 0.35,
    sam2_stability_score_thresh: float = 0.65,
    # Postprocessing filters
    filter_min_area_px: int = 1000,
    filter_max_area_frac: float = 0.20,
    filter_min_iou: float = 0.35,
    filter_min_stability: float = 0.65,
    filter_border_margin: int = 10,
    # Safety
    max_crop_side: int = 4096,
) -> dict:
    """
    End-to-end geoglyph detection pipeline from a georeferenced crop.

    Steps: resolve the device, read the crop, preprocess it, run the SAM2
    automatic mask generator, filter/polygonise the masks and, if any
    geometry survives, write them to a GeoPackage layer.

    This function does NOT receive:
    - original orthomosaic path
    - bbox
    - bbox CRS
    It only receives a small crop GeoTIFF with CRS and transform.

    Args:
        crop_tif_path: Path of the crop GeoTIFF.
        output_gpkg: Destination GeoPackage path. Missing parent directories
            are created before writing.
        layer_name: Layer name used inside the GeoPackage.
        device: "cuda", "cpu" or None (auto-select), see resolve_device().
        use_clahe, clahe_clip, clahe_grid: Forwarded to preprocess().
        sam2_*: Forwarded to load_sam2_model().
        filter_*: Forwarded to masks_to_geodataframe().
        max_crop_side: Forwarded to read_georeferenced_crop().

    Returns:
        dict with the keys:
        - "output_gpkg": destination path as a string;
        - "layer_name": the layer name;
        - "n_masks": number of rows in the resulting GeoDataFrame (i.e.
          exported geometries);
        - "input_mode": always "georeferenced_crop";
        - "crop": metadata dict returned by read_georeferenced_crop();
        - "output_exists": True only when this call wrote the file.
    """
    device = resolve_device(device)
    logger.info("=" * 60)
    logger.info("STARTING GEOGLYPH SAM2 INFERENCE ON CROP")
    logger.info("=" * 60)
    logger.info("Input crop: %s", crop_tif_path)
    logger.info("Output GPKG: %s", output_gpkg)
    logger.info("Requested device: %s", device)

    crop_tif_path = str(crop_tif_path)
    output_gpkg = Path(output_gpkg)

    arr_raw, crop_transform, crop_crs, crop_metadata = read_georeferenced_crop(
        crop_tif_path=crop_tif_path,
        max_crop_side=max_crop_side,
    )

    if use_clahe:
        logger.info(
            "Applying CLAHE | clip=%.1f grid=%d",
            clahe_clip,
            clahe_grid,
        )
    # preprocess() is always called; use_clahe only toggles its CLAHE step.
    arr_processed = preprocess(
        arr_raw,
        use_clahe=use_clahe,
        clip=clahe_clip,
        grid=clahe_grid,
    )

    mask_generator = load_sam2_model(
        device=device,
        points_per_side=sam2_points_per_side,
        points_per_batch=sam2_points_per_batch,
        pred_iou_thresh=sam2_pred_iou_thresh,
        stability_score_thresh=sam2_stability_score_thresh,
    )

    logger.info("Generating masks...")
    with torch.inference_mode():
        masks_data = mask_generator.generate(arr_processed)
    logger.info("SAM2 generated %d raw masks.", len(masks_data))

    gdf = masks_to_geodataframe(
        masks_data=masks_data,
        crop_transform=crop_transform,
        crop_crs=crop_crs,
        image_shape=arr_processed.shape,
        min_area_px=filter_min_area_px,
        max_area_frac=filter_max_area_frac,
        min_iou=filter_min_iou,
        min_stability=filter_min_stability,
        border_margin=filter_border_margin,
    )

    if len(gdf) > 0:
        # Provenance columns describing the crop the detections came from.
        gdf["source_crop"] = crop_tif_path
        gdf["input_mode"] = "georeferenced_crop"
        gdf["crop_width"] = crop_metadata["width"]
        gdf["crop_height"] = crop_metadata["height"]
        gdf["crop_crs"] = crop_metadata["crs"]
        logger.info(
            "Exporting %d geometries → %s layer=%s",
            len(gdf),
            output_gpkg,
            layer_name,
        )
        # The GPKG writer does not create directories; without this the
        # export fails after the expensive inference has already run.
        output_gpkg.parent.mkdir(parents=True, exist_ok=True)
        gdf.to_file(output_gpkg, layer=layer_name, driver="GPKG")
        output_exists = True
    else:
        logger.warning("No geometries to export after filtering.")
        output_exists = False

    logger.info("=" * 60)
    logger.info("INFERENCE COMPLETED")
    logger.info("=" * 60)

    return {
        "output_gpkg": str(output_gpkg),
        "layer_name": layer_name,
        "n_masks": len(gdf),
        "input_mode": "georeferenced_crop",
        "crop": crop_metadata,
        "output_exists": output_exists,
    }