ML4RS-Anonymous's picture
Upload all files
eb1aec4 verified
import numpy as np
import geopandas as gpd
import hashlib
from rasterio.io import MemoryFile
from .grid_cell_fragment import *
from .models import *
import cv2
class MajorTOM_Embedder(torch.nn.Module):
    """
    MajorTOM Embedder class that applies a model to geospatial image fragments,
    computes embeddings, and returns metadata for each fragment.

    The raster image of one sample is split into tiles (fragments) whose size
    matches the embedder input, an embedding is computed for every tile, and the
    result is returned as a GeoDataFrame with one row per tile containing spatial
    metadata and the corresponding embedding.

    Attributes:
        embedder: A model that generates embeddings for image fragments. It is
            expected to expose ``size`` (input size; ``size[0]`` is used as the
            fragment size) and ``bands`` (ordered list of band column names).
        frag_params: Dictionary of fragmentation parameters passed to
            ``fragment_fn`` (fragment size, target overlap, border shift).
        column_types: Dictionary specifying dtypes enforced on the output
            GeoDataFrame columns.
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Initializes the MajorTOM Embedder with the given parameters.

        Args:
            embedder (torch.nn.Module): A model that generates embeddings for image fragments.
            target_overlap (float): The target overlap between image fragments. Default is 0.1.
            border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True.
        """
        super().__init__()
        # Model
        self.embedder = embedder
        # Fragmentation settings: fragments are as large as the embedder's
        # first input dimension.
        self.frag_params = {
            'fragment_size' : self.embedder.size[0],
            'target_overlap' : target_overlap,
            'border_shift' : border_shift
        }
        # Dtypes enforced on the output dataframe. Columns not listed here
        # (unique_id, embedding, timestamp, product_id, grid_cell, geometry,
        # utm_footprint, utm_crs, pixel_bbox) keep the dtype pandas infers.
        self.column_types = {
            'grid_row_u' : 'int16',
            'grid_col_r' : 'int16',
            'centre_lat' : 'float32',
            'centre_lon' : 'float32',
        }

    def bands(self):
        """
        Returns the set of input bands in the correct order.

        Returns:
            list: List of input bands used by the embedder.
        """
        return self.embedder.bands

    def size(self):
        """
        Returns the input image size.

        Returns:
            tuple: The embedder's ``size`` attribute (input image size).
        """
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """
        Calculates a checksum for the given geometry, timestamp, product ID, and embedding.

        Args:
            geometry (shapely.geometry): The geometry object representing the fragment's footprint.
            timestamp (str): Timestamp of the data.
            product_id (str): Product identifier.
            embedding (np.ndarray): The embedding of the image fragment.

        Returns:
            str: Hex digest of SHA256 over the underscore-joined string forms of the inputs.
        """
        # NOTE(review): the embedding is hashed through its str() form, which
        # follows numpy's print options. Arrays above the print threshold
        # (1000 elements by default) are summarized with '...', so not every
        # value contributes to the hash. The formula is kept unchanged so that
        # IDs stay reproducible with data generated earlier.
        combined = f"{geometry}_{timestamp}_{product_id}_{embedding}"
        checksum = hashlib.sha256(combined.encode()).hexdigest()
        return checksum

    def _read_image(self, row):
        """
        Reads and processes the image bands for a given row, performs optional upsampling
        if the resolution is mismatched, and returns the image data, footprint, and CRS.

        Args:
            row (pandas.Series): The input row containing one in-memory GeoTIFF per band.

        Returns:
            torch.Tensor: float32 tensor of shape (H, W, n_bands) with the stacked bands.
            shapely.geometry: The footprint (bounding box) of the image in its native CRS.
            rasterio.crs.CRS: The CRS of the image.
        """
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    # CRS/footprint are overwritten per band; the values of the
                    # last band read are the ones returned.
                    crs = f.crs
                    footprint = box(*f.bounds)
                    # Only the first raster layer of each band file is used.
                    img.append(f.read()[0])

        # Optional upsampling: bring every band to the largest height/width present.
        shapes = [layer.shape for layer in img]
        if any(shape != shapes[0] for shape in shapes):  # resolution mismatch
            h = max(shape[0] for shape in shapes)
            w = max(shape[1] for shape in shapes)
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h, w):
                    # cv2.resize takes dsize as (width, height), not (height, width).
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)

        img = torch.from_numpy(np.stack(img, -1).astype(np.float32))
        return img, footprint, crs

    def forward(self, row, row_meta, device='cuda'):
        """
        Forward pass of the model: Reads the image, fragments it, computes embeddings
        for each fragment, and returns a GeoDataFrame with the spatial metadata and
        embeddings.

        Args:
            row (pandas.Series): The input row containing the image data.
            row_meta (pandas.Series): Metadata associated with the row (timestamp, product_id,
                grid_cell, grid_row_u, grid_col_r).
            device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'.

        Returns:
            geopandas.GeoDataFrame: One row per fragment with metadata and its embedding.
        """
        # Read file
        img, footprint, crs = self._read_image(row)

        # Fragment the sample
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)
        nrows, ncols, c, h, w = fragments.shape

        # Apply the model to all fragments as a single batch.
        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1, c, h, w).to(device)).reshape(nrows, ncols, -1)
        # One device-to-host transfer for the whole grid instead of one per fragment.
        embeddings = embeddings.cpu().numpy()

        # Loop-invariant values. The main footprint is in WGS84, which must be
        # consistent across the parquet files, so one transformer is built here.
        transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
        img_h, img_w = img.shape[:2]
        utm_crs = crs.to_string()
        timestamp = row_meta.timestamp.item()
        product_id = row_meta.product_id.item()
        grid_cell = row_meta.grid_cell.item()
        grid_row_u = row_meta.grid_row_u.item()
        grid_col_r = row_meta.grid_col_r.item()

        # Pack rows for geoparquet
        df_rows = []
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx]

                # Spatial features per fragment.
                x_offset, y_offset = xys[r_idx, c_idx].int().tolist()
                # NOTE(review): h/w pairing here assumes square fragments
                # (h == w); confirm the axis convention of xys if that changes.
                pixel_bbox = [x_offset, y_offset, x_offset + h, y_offset + w]  # in pixels
                utm_footprint = crop_footprint(footprint, img_h, img_w, pixel_bbox)
                geometry = transform(transformer.transform, utm_footprint)  # WGS84
                centre_lon, centre_lat = geometry.centroid.coords[0]

                df_rows.append({
                    'unique_id' : self.calculate_checksum(geometry, timestamp, product_id, embedding),
                    'embedding' : embedding,
                    'timestamp' : timestamp,
                    'product_id' : product_id,
                    'grid_cell' : grid_cell,
                    'grid_row_u' : grid_row_u,
                    'grid_col_r' : grid_col_r,
                    'geometry' : geometry,
                    'centre_lat' : centre_lat,
                    'centre_lon' : centre_lon,
                    'utm_footprint' : utm_footprint.wkt,
                    'utm_crs' : utm_crs,
                    'pixel_bbox' : pixel_bbox,
                })

        # NOTE(review): no CRS is attached to the GeoDataFrame although the
        # geometries are WGS84; left as-is to keep the output unchanged.
        return gpd.GeoDataFrame(df_rows).astype(self.column_types)