|
|
import numpy as np |
|
|
import geopandas as gpd |
|
|
import hashlib |
|
|
from rasterio.io import MemoryFile |
|
|
|
|
|
from .grid_cell_fragment import * |
|
|
from .models import * |
|
|
import cv2 |
|
|
|
|
|
class MajorTOM_Embedder(torch.nn.Module): |
|
|
""" |
|
|
MajorTOM Embedder class that applies a model to geospatial image fragments, |
|
|
computes embeddings, and returns metadata for each fragment. |
|
|
|
|
|
This class is designed to work with raster data, where the image is fragmented |
|
|
into smaller tiles, and embeddings are computed for each tile using the provided |
|
|
embedder model. The output is a GeoDataFrame containing spatial metadata and |
|
|
the corresponding embeddings for each tile. |
|
|
|
|
|
Attributes: |
|
|
embedder: A model that generates embeddings for image fragments. |
|
|
frag_params: Dictionary containing fragmentation parameters such as the |
|
|
target overlap and border shift. |
|
|
column_types: Dictionary specifying data types for the output GeoDataFrame columns. |
|
|
""" |
|
|
|
|
|
def __init__(self, embedder, target_overlap=0.1, border_shift=True): |
|
|
""" |
|
|
Initializes the MajorTOM Embedder with the given parameters. |
|
|
|
|
|
Args: |
|
|
embedder (torch.nn.Module): A model that generates embeddings for image fragments. |
|
|
target_overlap (float): The target overlap between image fragments. Default is 0.1. |
|
|
border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True. |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.embedder = embedder |
|
|
|
|
|
|
|
|
self.frag_params = params = { |
|
|
'fragment_size' : self.embedder.size[0], |
|
|
'target_overlap' : target_overlap, |
|
|
'border_shift' : border_shift |
|
|
} |
|
|
|
|
|
|
|
|
self.column_types = { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'grid_row_u' : 'int16', |
|
|
'grid_col_r' : 'int16', |
|
|
'centre_lat' : 'float32', |
|
|
'centre_lon' : 'float32', |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
def bands(self): |
|
|
""" |
|
|
Returns the set of input bands in the correct order. |
|
|
|
|
|
Returns: |
|
|
list: List of input bands used by the embedder. |
|
|
""" |
|
|
return self.embedder.bands |
|
|
|
|
|
def size(self): |
|
|
""" |
|
|
Returns the input image size. |
|
|
|
|
|
Returns: |
|
|
tuple: Tuple representing the image size (height, width). |
|
|
""" |
|
|
return self.embedder.size |
|
|
|
|
|
def calculate_checksum(self, geometry, timestamp, product_id, embedding): |
|
|
""" |
|
|
Calculates a checksum for the given geometry, timestamp, product ID, and embedding. |
|
|
|
|
|
Args: |
|
|
geometry (shapely.geometry): The geometry object representing the fragment's footprint. |
|
|
timestamp (str): Timestamp of the data. |
|
|
product_id (str): Product identifier. |
|
|
embedding (np.ndarray): The embedding of the image fragment. |
|
|
|
|
|
Returns: |
|
|
str: A SHA256 checksum of the concatenated input parameters. |
|
|
""" |
|
|
combined = f"{geometry}_{timestamp}_{product_id}_{embedding}" |
|
|
checksum = hashlib.sha256(combined.encode()).hexdigest() |
|
|
return checksum |
|
|
|
|
|
def _read_image(self, row): |
|
|
""" |
|
|
Reads and processes the image bands for a given row, performs optional upsampling |
|
|
if the resolution is mismatched, and returns the image data, footprint, and CRS. |
|
|
|
|
|
Args: |
|
|
row (pandas.Series): The input row containing the image bands. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: A tensor containing the stacked image bands. |
|
|
shapely.geometry: The footprint of the image. |
|
|
rasterio.crs.CRS: The CRS of the image. |
|
|
""" |
|
|
|
|
|
|
|
|
img = [] |
|
|
for band in self.embedder.bands: |
|
|
with MemoryFile(row[band][0].as_py()) as mem_f: |
|
|
with mem_f.open(driver='GTiff') as f: |
|
|
crs = f.crs |
|
|
footprint = box(*f.bounds) |
|
|
img.append(f.read()[0]) |
|
|
|
|
|
|
|
|
shapes = [layer.shape for layer in img] |
|
|
if any([el!=shapes[0] for el in shapes]): |
|
|
h, w = max([el[0] for el in shapes]), max([el[1] for el in shapes]) |
|
|
for layer_idx, layer in enumerate(img): |
|
|
if layer.shape != (h,w): |
|
|
img[layer_idx] = cv2.resize(layer, (h,w), interpolation=cv2.INTER_NEAREST) |
|
|
img = torch.from_numpy(np.stack(img,-1).astype(np.float32)) |
|
|
|
|
|
return img, footprint, crs |
|
|
|
|
|
|
|
|
def forward(self, row, row_meta, device='cuda'): |
|
|
""" |
|
|
Forward pass of the model: Reads the image, fragments it, computes embeddings |
|
|
for each fragment, and returns a GeoDataFrame with the spatial metadata and |
|
|
embeddings. |
|
|
|
|
|
Args: |
|
|
row (pandas.Series): The input row containing the image data. |
|
|
row_meta (pandas.Series): Metadata associated with the row (e.g., timestamp, product_id). |
|
|
device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'. |
|
|
|
|
|
Returns: |
|
|
geopandas.GeoDataFrame: A GeoDataFrame containing metadata and embeddings for each fragment. |
|
|
""" |
|
|
|
|
|
img, footprint, crs = self._read_image(row) |
|
|
|
|
|
|
|
|
fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False) |
|
|
|
|
|
nrows, ncols, c, h, w = fragments.shape |
|
|
|
|
|
with torch.no_grad(): |
|
|
embeddings = self.embedder(fragments.reshape(-1,c,h,w).to(device)).view(nrows, ncols, -1) |
|
|
|
|
|
df_rows = [] |
|
|
|
|
|
|
|
|
for r_idx in range(nrows): |
|
|
for c_idx in range(ncols): |
|
|
embedding = embeddings[r_idx, c_idx].cpu().numpy() |
|
|
|
|
|
x_offset,y_offset=xys[r_idx,c_idx].int().tolist() |
|
|
pixel_bbox = [x_offset, y_offset, x_offset + h,y_offset + w] |
|
|
utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox) |
|
|
|
|
|
transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True) |
|
|
geometry = transform(transformer.transform, utm_footprint) |
|
|
centre_lon, centre_lat = geometry.centroid.coords[0] |
|
|
|
|
|
row_dict = { |
|
|
'unique_id' : self.calculate_checksum(geometry, row_meta.timestamp.item(), row_meta.product_id.item(), embedding), |
|
|
'embedding' : embedding, |
|
|
'timestamp' : row_meta.timestamp.item(), |
|
|
'product_id' : row_meta.product_id.item(), |
|
|
'grid_cell' : row_meta.grid_cell.item(), |
|
|
'grid_row_u' : row_meta.grid_row_u.item(), |
|
|
'grid_col_r' : row_meta.grid_col_r.item(), |
|
|
'geometry' : geometry, |
|
|
'centre_lat' : centre_lat, |
|
|
'centre_lon' : centre_lon, |
|
|
'utm_footprint' : utm_footprint.wkt, |
|
|
'utm_crs' : crs.to_string(), |
|
|
'pixel_bbox' : pixel_bbox, |
|
|
} |
|
|
df_rows.append(row_dict) |
|
|
|
|
|
return gpd.GeoDataFrame(df_rows).astype(self.column_types) |