import numpy as np
import geopandas as gpd
import hashlib
import cv2
import torch
from shapely.geometry import box
from shapely.ops import transform
from pyproj import Transformer, CRS
from rasterio.io import MemoryFile
from .grid_cell_fragment import *
from .models import *


class MajorTOM_Embedder(torch.nn.Module):
    """
    MajorTOM Embedder that applies a model to geospatial image fragments,
    computes embeddings, and returns metadata for each fragment.

    The input raster is fragmented into smaller tiles and an embedding is
    computed for each tile with the provided embedder model. The output is a
    GeoDataFrame containing spatial metadata and the corresponding embedding
    for each tile.

    Attributes:
        embedder: A model that generates embeddings for image fragments.
        frag_params: Dictionary of fragmentation parameters such as the
            target overlap and border shift.
        column_types: Dictionary specifying data types for the output
            GeoDataFrame columns.
    """

    def __init__(self, embedder, target_overlap=0.1, border_shift=True):
        """
        Initializes the MajorTOM Embedder with the given parameters.

        Args:
            embedder (torch.nn.Module): A model that generates embeddings for image fragments.
            target_overlap (float): The target overlap between image fragments. Default is 0.1.
            border_shift (bool): Whether to shift the borders of fragments to avoid edge artifacts. Default is True.
        """
        super().__init__()

        # Model
        self.embedder = embedder

        # Fragmentation settings
        self.frag_params = {
            'fragment_size': self.embedder.size[0],
            'target_overlap': target_overlap,
            'border_shift': border_shift,
        }

        # Data types for the output dataframe; columns not listed here
        # ('unique_id', 'embedding', 'timestamp', 'product_id', 'grid_cell',
        #  'utm_footprint', 'utm_crs', 'pixel_bbox') need no conversion.
        self.column_types = {
            'grid_row_u': 'int16',
            'grid_col_r': 'int16',
            'centre_lat': 'float32',
            'centre_lon': 'float32',
        }

    def bands(self):
        """
        Returns the input bands in the order expected by the embedder.

        Returns:
            list: List of input bands used by the embedder.
        """
        return self.embedder.bands

    def size(self):
        """
        Returns the input image size.

        Returns:
            tuple: Tuple representing the image size (height, width).
        """
        return self.embedder.size

    def calculate_checksum(self, geometry, timestamp, product_id, embedding):
        """
        Calculates a checksum for the given geometry, timestamp, product ID, and embedding.

        Args:
            geometry (shapely.geometry): The geometry object representing the fragment's footprint.
            timestamp (str): Timestamp of the data.
            product_id (str): Product identifier.
            embedding (np.ndarray): The embedding of the image fragment.

        Returns:
            str: A SHA256 checksum of the concatenated input parameters.
        """
        # Hash the raw embedding bytes rather than its string form: str(ndarray)
        # truncates large arrays ('...'), which would let distinct embeddings collide.
        combined = f"{geometry}_{timestamp}_{product_id}".encode() + embedding.tobytes()
        checksum = hashlib.sha256(combined).hexdigest()
        return checksum
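
    # The checksum above serves as each fragment's `unique_id` in `forward`:
    # identical (geometry, timestamp, product_id, embedding) tuples map to the
    # same id, so downstream consumers can use it to deduplicate fragments.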

    def _read_image(self, row):
        """
        Reads and processes the image bands for a given row, performs optional
        upsampling if the resolutions are mismatched, and returns the image
        data, footprint, and CRS.

        Args:
            row (pandas.Series): The input row containing the image bands.

        Returns:
            torch.Tensor: A tensor containing the stacked image bands.
            shapely.geometry: The footprint of the image.
            rasterio.crs.CRS: The CRS of the image.
        """
        # Read one GeoTIFF per band from the in-memory buffers
        img = []
        for band in self.embedder.bands:
            with MemoryFile(row[band][0].as_py()) as mem_f:
                with mem_f.open(driver='GTiff') as f:
                    crs = f.crs
                    footprint = box(*f.bounds)
                    img.append(f.read()[0])

        # Optional upsampling to the largest band resolution
        shapes = [layer.shape for layer in img]
        if any(el != shapes[0] for el in shapes):  # any resolution mismatch
            h, w = max(el[0] for el in shapes), max(el[1] for el in shapes)  # maximum size
            for layer_idx, layer in enumerate(img):
                if layer.shape != (h, w):
                    # cv2.resize expects dsize as (width, height)
                    img[layer_idx] = cv2.resize(layer, (w, h), interpolation=cv2.INTER_NEAREST)

        img = torch.from_numpy(np.stack(img, -1).astype(np.float32))
        return img, footprint, crs
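
    # NOTE: `_read_image` assumes each `row[band]` is a length-1 pyarrow binary
    # array whose payload is an encoded GeoTIFF, e.g. (hypothetical layout):
    #     row = {'B04': pa.array([geotiff_bytes]), 'B03': ..., 'B02': ...}
    # All bands are assumed to share one CRS; the footprint and CRS returned
    # are those of the last band read.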

    def forward(self, row, row_meta, device='cuda'):
        """
        Forward pass of the model: Reads the image, fragments it, computes
        embeddings for each fragment, and returns a GeoDataFrame with the
        spatial metadata and embeddings.

        Args:
            row (pandas.Series): The input row containing the image data.
            row_meta (pandas.Series): Metadata associated with the row (e.g., timestamp, product_id).
            device (str): The device to run the model on ('cpu' or 'cuda'). Default is 'cuda'.

        Returns:
            geopandas.GeoDataFrame: A GeoDataFrame containing metadata and embeddings for each fragment.
        """
        # Read file
        img, footprint, crs = self._read_image(row)

        # Fragment the sample
        fragments, xys = fragment_fn(img, **self.frag_params, return_indices=True, verbose=False)
        nrows, ncols, c, h, w = fragments.shape

        # Apply the model
        with torch.no_grad():
            embeddings = self.embedder(fragments.reshape(-1, c, h, w).to(device)).view(nrows, ncols, -1)

        # The main footprint is in WGS84 (needs to be consistent across the
        # parquet); the source CRS is shared by every fragment, so the
        # transformer is built once instead of per fragment.
        transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)

        # Pack rows for geoparquet
        df_rows = []
        for r_idx in range(nrows):
            for c_idx in range(ncols):
                embedding = embeddings[r_idx, c_idx].cpu().numpy()

                # Spatial features per fragment (fragments are square, so h == w)
                x_offset, y_offset = xys[r_idx, c_idx].int().tolist()
                pixel_bbox = [x_offset, y_offset, x_offset + w, y_offset + h]  # in pixels
                utm_footprint = crop_footprint(footprint, *img.shape[:2], pixel_bbox)
                geometry = transform(transformer.transform, utm_footprint)  # WGS84
                centre_lon, centre_lat = geometry.centroid.coords[0]

                row_dict = {
                    'unique_id': self.calculate_checksum(geometry, row_meta.timestamp.item(), row_meta.product_id.item(), embedding),
                    'embedding': embedding,
                    'timestamp': row_meta.timestamp.item(),
                    'product_id': row_meta.product_id.item(),
                    'grid_cell': row_meta.grid_cell.item(),
                    'grid_row_u': row_meta.grid_row_u.item(),
                    'grid_col_r': row_meta.grid_col_r.item(),
                    'geometry': geometry,
                    'centre_lat': centre_lat,
                    'centre_lon': centre_lon,
                    'utm_footprint': utm_footprint.wkt,
                    'utm_crs': crs.to_string(),
                    'pixel_bbox': pixel_bbox,
                }
                df_rows.append(row_dict)

        # Geometry is in WGS84 by construction, so tag the frame accordingly
        return gpd.GeoDataFrame(df_rows, geometry='geometry', crs='EPSG:4326').astype(self.column_types)
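

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline). `DummyEmbedder`,
# the band names and all metadata values below are hypothetical stand-ins; the
# only interface MajorTOM_Embedder relies on is an embedder exposing `.size`
# (fragment height/width), `.bands` (ordered band names), and a forward pass
# mapping an (N, C, H, W) batch to (N, D) embeddings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import pandas as pd
    import pyarrow as pa
    from rasterio.transform import from_origin

    class DummyEmbedder(torch.nn.Module):
        """Minimal model satisfying the interface assumed above."""
        size = (224, 224)                # fragment (height, width)
        bands = ['B04', 'B03', 'B02']    # hypothetical band names

        def forward(self, x):
            # Trivial (N, C, H, W) -> (N, C) "embedding": per-band means
            return x.mean(dim=(2, 3))

    def synthetic_band(h=448, w=448):
        # Encode one random single-band GeoTIFF (UTM zone 33N) into bytes
        with MemoryFile() as mem:
            with mem.open(driver='GTiff', height=h, width=w, count=1,
                          dtype='uint16', crs='EPSG:32633',
                          transform=from_origin(500_000, 5_000_000, 10, 10)) as dst:
                dst.write(np.random.randint(0, 4096, (1, h, w), dtype=np.uint16))
            return mem.read()

    model = MajorTOM_Embedder(DummyEmbedder())
    # `_read_image` expects one length-1 pyarrow binary array per band
    row = {band: pa.array([synthetic_band()]) for band in model.bands()}
    # Single-row metadata frame, so `.item()` resolves each field to a scalar
    row_meta = pd.DataFrame([{
        'timestamp': '2024-01-01T00:00:00Z', 'product_id': 'EXAMPLE_PRODUCT',
        'grid_cell': '499U_117R', 'grid_row_u': 499, 'grid_col_r': 117,
    }])

    gdf = model(row, row_meta, device='cpu')
    print(gdf[['unique_id', 'grid_cell', 'centre_lat', 'centre_lon']])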