| | import os |
| | import torch |
| | import yaml |
| | import json |
| | import rasterio |
| | from rasterio.windows import Window |
| | from rasterio.transform import rowcol |
| | from pyproj import Transformer |
| | from torchvision import transforms |
| | import numpy as np |
| | from rasterio.features import shapes |
| | from shapely.geometry import shape |
| | import geopandas as gpd |
| | from dotenv import load_dotenv |
| |
|
| | from messis.messis import LogConfusionMatrix |
| |
|
| | |
# Load environment variables from a local .env file (presumably dataset/HF
# access credentials — confirm against deployment setup).
load_dotenv()
| |
|
class InferenceDataLoader:
    """Load feature/label/field-id windows from GeoTIFFs and prepare them for model inference.

    Given a WGS84 (lon, lat) click position, extracts a ``window_size`` x
    ``window_size`` pixel window (clamped to the raster bounds) from the
    features, labels, and field-id rasters, normalizes the features with
    per-fold statistics, and reshapes them for the model.
    """

    def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
        """
        :param features_path: Path/URL to the stacked-features GeoTIFF
        :param labels_path: Path/URL to the labels GeoTIFF
        :param field_ids_path: Path/URL to the field-ids GeoTIFF
        :param stats_path: Path to the YAML file with per-fold normalization stats
        :param window_size: Side length (pixels) of the square inference window
        :param n_timesteps: Number of timesteps stacked in the features file
        :param fold_indices: Folds whose stats are pooled for normalization
        :param debug: Print diagnostic information
        """
        self.features_path = features_path
        self.labels_path = labels_path
        self.field_ids_path = field_ids_path
        self.stats_path = stats_path
        self.window_size = window_size
        self.n_timesteps = n_timesteps
        self.fold_indices = fold_indices if fold_indices is not None else []
        self.debug = debug

        self.means, self.stds = self.load_stats()

        # Click coordinates arrive as WGS84 lon/lat; the rasters are in UTM zone 32N.
        self.transformer = Transformer.from_crs("EPSG:4326", "EPSG:32632", always_xy=True)

    def load_stats(self):
        """Load normalization statistics for dataset from YAML file.

        Pools per-fold channel means (simple average) and variances
        (Bessel-weighted by chip counts) across the selected folds, then
        repeats the per-channel stats ``n_timesteps`` times so they line up
        with the timestep-stacked feature bands.

        :return: Tuple ``(means, stds)`` — lists of per-band scalar tensors
        :raises ValueError: If a requested fold is missing from the stats file
        """
        if self.debug:
            print(f"Loading mean/std stats from {self.stats_path}")
        assert os.path.exists(self.stats_path), f"Mean/std stats file not found at {self.stats_path}"

        with open(self.stats_path, 'r') as file:
            stats = yaml.safe_load(file)

        mean_list, std_list, n_list = [], [], []
        for fold in self.fold_indices:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"Mean/std stats for fold {fold} not found in {self.stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {self.n_timesteps} timesteps.")
            mean_list.append(torch.tensor(stats[key]['mean']))
            std_list.append(torch.tensor(stats[key]['std']))
            n_list.append(stats[key]['n_chips'])

        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            # Pooled variance across folds: sum((n_i - 1) * var_i) / (sum(n_i) - k)
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        # Repeat per-channel stats for every timestep (bands are stacked T times).
        return means * self.n_timesteps, stds * self.n_timesteps

    def identify_window(self, path, lon, lat):
        """Identify the 224x224 window centered on the clicked coordinates (lon, lat) from the specified GeoTIFF.

        :return: Tuple ``(window, window_transform, crs)``
        :raises ValueError: If the coordinates fall outside the raster
        """
        with rasterio.open(path) as src:
            utm_x, utm_y = self.transformer.transform(lon, lat)
            if self.debug:
                print("Source Transform", src.transform)
                print(f"UTM X: {utm_x}, UTM Y: {utm_y}")

            try:
                # rasterio.transform.rowcol returns (row, col) in that order.
                row, col = rowcol(src.transform, utm_x, utm_y)
            except ValueError:
                raise ValueError("Coordinates out of bounds for this raster.")

            if self.debug:
                # BUG FIX: labels were previously swapped (column printed as Row).
                print(f"Row: {row}, Column: {col}")

            half_window_size = self.window_size // 2

            row_off = row - half_window_size
            col_off = col - half_window_size

            # Clamp the window so it stays fully inside the raster.
            # BUG FIX: rows must be clamped against height and columns against
            # width; the original compared row_off to src.width and vice versa.
            if row_off < 0:
                row_off = 0
            if col_off < 0:
                col_off = 0
            if row_off + self.window_size > src.height:
                row_off = src.height - self.window_size
            if col_off + self.window_size > src.width:
                col_off = src.width - self.window_size

            # rasterio.windows.Window takes (col_off, row_off, width, height).
            window = Window(col_off, row_off, self.window_size, self.window_size)
            window_transform = src.window_transform(window)
            if self.debug:
                print(f"Window: {window}")
                print(f"Window Transform: {window_transform}")
            crs = src.crs

        return window, window_transform, crs

    def extract_window(self, path, window):
        """Extract data from the specified window from the GeoTIFF.

        :return: NumPy array of shape (bands, window_height, window_width)
        """
        with rasterio.open(path) as src:
            window_data = src.read(window=window)

        if self.debug:
            print(f"Extracted window data from {path}")
            print(f"Min: {window_data.min()}, Max: {window_data.max()}")

        return window_data

    def prepare_data_for_model(self, features_data):
        """Prepare the window data for model inference.

        Normalizes the stacked bands, regroups them from (T*6, H, W) into
        (6, T, H, W) — 6 spectral bands per timestep is hard-coded — and adds
        a leading batch dimension.

        :param features_data: Array of shape (n_timesteps * 6, H, W)
        :return: Float tensor of shape (1, 6, n_timesteps, H, W)
        """
        features_data = torch.tensor(features_data, dtype=torch.float32)

        normalize = transforms.Normalize(mean=self.means, std=self.stds)
        features_data = normalize(features_data)

        # Split stacked bands into (timesteps, bands, H, W), then move bands first.
        height, width = features_data.shape[-2:]
        features_data = features_data.view(self.n_timesteps, 6, height, width).permute(1, 0, 2, 3)

        # Add batch dimension.
        features_data = features_data.unsqueeze(0)

        return features_data

    def get_data(self, lon, lat):
        """Extract, normalize, and prepare data for inference, including labels and field IDs.

        :param lon: Longitude (EPSG:4326) of the clicked point
        :param lat: Latitude (EPSG:4326) of the clicked point
        :return: Tuple (features tensor (1, 6, T, H, W), labels tensor,
                 field-ids tensor, window transform, raster CRS)
        """
        # The same window is reused for all three rasters (assumes aligned grids).
        window, features_transform, features_crs = self.identify_window(self.features_path, lon, lat)

        features_data = self.extract_window(self.features_path, window)
        label_data = self.extract_window(self.labels_path, window)
        field_ids_data = self.extract_window(self.field_ids_path, window)

        prepared_features_data = self.prepare_data_for_model(features_data)

        label_data = torch.tensor(label_data, dtype=torch.long)
        field_ids_data = torch.tensor(field_ids_data, dtype=torch.long)

        return prepared_features_data, label_data, field_ids_data, features_transform, features_crs
| | |
def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, class_names):
    """
    Convert field_ids, targets, and predictions tensors to field polygons with corresponding class reference.

    :param field_ids: PyTorch tensor of shape (1, 224, 224) representing individual fields
    :param targets: PyTorch tensor of shape (1, 224, 224) for targets
    :param predictions: PyTorch tensor of shape (1, 224, 224) for predictions
    :param transform: Affine transform for georeferencing
    :param crs: Coordinate reference system (CRS) of the data
    :param class_names: Dictionary mapping class indices to class names
    :return: GeoPandas DataFrame with polygons, prediction class labels, and target class labels
    """
    # Collapse the leading dimension and move everything to NumPy.
    field_arr = field_ids.squeeze().cpu().numpy().astype(np.int32)
    target_arr = targets.squeeze().cpu().numpy().astype(np.int8)
    pred_arr = predictions.squeeze().cpu().numpy().astype(np.int8)

    # Vectorize each connected region of equal field id into a polygon.
    regions = list(shapes(field_arr, transform=transform))

    polygons = [shape(geom) for geom, _ in regions]
    field_values = [value for _, value in regions]
    # All pixels of a field share one class, so the first masked pixel suffices.
    target_values = [target_arr[field_arr == value][0] for value in field_values]
    pred_values = [pred_arr[field_arr == value][0] for value in field_values]

    gdf = gpd.GeoDataFrame({
        'geometry': polygons,
        'field_id': field_values,
        'target': target_values,
        'prediction': pred_values
    }, crs=crs)

    # Resolve numeric class indices to human-readable labels.
    gdf['prediction_class'] = gdf['prediction'].apply(lambda x: class_names[x])
    gdf['target_class'] = gdf['target'].apply(lambda x: class_names[x])

    gdf['correct'] = gdf['target'] == gdf['prediction']

    # Drop sliver polygons (area in CRS units, i.e. m^2 for UTM).
    gdf = gdf[gdf.geometry.area > 250]

    return gdf
| |
|
def perform_inference(lon, lat, model, config, debug=False):
    """Run crop-classification inference for the field at (lon, lat).

    :param lon: Longitude (EPSG:4326) of the clicked point
    :param lat: Latitude (EPSG:4326) of the clicked point
    :param model: Trained messis model; called with the feature tensor and
                  expected to return a dict containing 'tier3_refinement_head'
    :param config: Config object whose ``hparams['heads_spec']`` maps head
                   names to specs with an 'is_metrics_tier' flag
    :param debug: Print intermediate shapes and loader diagnostics
    :return: GeoDataFrame with columns ['Prediction', 'Target', 'Correct', 'geometry']
    """
    features_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/stacked_features_cog.tif"
    labels_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/labels_cog.tif"
    field_ids_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/field_ids_cog.tif"

    stats_path = "./data/chips_stats.yaml"
    dataset_info_path = "./data/dataset_info.json"

    # BUG FIX: debug was hard-coded to True here, ignoring the caller's flag.
    loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=debug)

    satellite_data, label_data, field_ids_data, features_transform, features_crs = loader.get_data(lon, lat)

    if debug:
        print(satellite_data.shape)
        print(label_data.shape)
        print(field_ids_data.shape)

    with open(dataset_info_path, 'r') as file:
        dataset_info = json.load(file)

    class_names = dataset_info['tier3']

    # Tiers flagged for metrics reporting in the model's head specification.
    tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
    tiers = list(tiers_dict.keys())

    model.eval()
    with torch.no_grad():
        output = model(satellite_data)['tier3_refinement_head']

    # Only the per-field majority outputs are used; index 2 selects tier3.
    _, majority_outputs_stacked = LogConfusionMatrix.get_pixelwise_and_majority_outputs(output, tiers, field_ids=field_ids_data, dataset_info=dataset_info)
    majority_tier3_predictions = majority_outputs_stacked[2]

    gdf = crop_predictions_to_gdf(field_ids_data, label_data, majority_tier3_predictions, features_transform, features_crs, class_names)

    # Keep only display columns, renamed for the UI.
    gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
    gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']

    return gdf