# Provenance (from Hugging Face upload page): ValerianFourel
# "Upload SOC mapping model weights and inference files" — commit a16f583 (verified)
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt
import pandas as pd
import torch
from tqdm import tqdm
from pathlib import Path
from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader, TensorDataset
import os
from dataloader.dataloader import MultiRasterDataset
from dataloader.dataloaderMapping import MultiRasterDatasetMapping
from dataloader.dataframe_loader import filter_dataframe, separate_and_add_data
from config import (TIME_BEGINNING, TIME_END, INFERENCE_TIME, seasons, years_padded,
SamplesCoordinates_Yearly, MatrixCoordinates_1mil_Yearly,
DataYearly, SamplesCoordinates_Seasonally,
MatrixCoordinates_1mil_Seasonally, DataSeasonally,
file_path_LUCAS_LFU_Lfl_00to23_Bavaria_OC,
window_size, bands_list_order)
from mapping import create_prediction_visualizations
from modelCNN import SmallCNN
def load_cnn_model(model_path="/home/vfourel/SOCProject/SOCmapping/SimpleTimeModel/SimpleCNN_map/simpletimecnn_model_MAX_OC_150_TIME_BEGINNING_2007_TIME_END_2023.pth"):
    """Load a SmallCNN checkpoint from disk and return it in eval mode.

    Reconciles the 'module.' key prefix that (Distributed)DataParallel
    adds when saving: the prefix is stripped if the checkpoint has it and
    the model does not, or added in the opposite situation.

    Args:
        model_path: Path to the serialized state dict (.pth file).

    Returns:
        A SmallCNN instance with weights loaded, switched to eval mode.

    Raises:
        Re-raises any exception hit during loading, after printing it.
    """
    print("1. Entering load_cnn_model")
    try:
        model = SmallCNN()
        print("2. Model initialized:", model.__class__.__name__)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        print("3. State dict loaded with keys:", list(state_dict.keys())[:5])
        key_names = list(state_dict.keys())
        if all(name.startswith('module.') for name in key_names):
            # Checkpoint was saved from a DataParallel wrapper; bare model needs no prefix.
            print("4. Removing 'module.' prefix from state_dict keys")
            state_dict = {name.replace('module.', ''): tensor
                          for name, tensor in state_dict.items()}
        elif not any(name.startswith('module.') for name in key_names) and hasattr(model, 'module'):
            # Opposite case: the model is wrapped but the checkpoint is not.
            print("5. Adding 'module.' prefix to state_dict keys")
            state_dict = {'module.' + name: tensor
                          for name, tensor in state_dict.items()}
        model.load_state_dict(state_dict)
        print("6. State dict loaded into model")
        model.eval()
        print("7. Model set to evaluation mode")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
def separate_and_add_data_1mil_inference(TIME_END=INFERENCE_TIME,
                                         SamplesCoordinates_Yearly=MatrixCoordinates_1mil_Yearly):
    """Build the list of raster directories to read for the inference year.

    Yearly band directories get the target year appended as a subdirectory;
    Elevation directories are returned unchanged because elevation data is
    static across years.

    Args:
        TIME_END: Target year appended to each non-Elevation path.
        SamplesCoordinates_Yearly: Base directories of the yearly rasters.

    Returns:
        List of directory path strings, one per input path.
    """
    print("8. Entering separate_and_add_data_1mil_inference")

    def year_qualified_paths(base_paths, year):
        # Elevation is time-invariant; every other band lives under a year folder.
        return [p if 'Elevation' in p else f"{p}/{year}" for p in base_paths]

    samples_paths = year_qualified_paths(SamplesCoordinates_Yearly, TIME_END)
    print("9. Returning samples paths:", len(samples_paths))
    return samples_paths
def flatten_paths(path_list):
    """Recursively flatten an arbitrarily nested list of paths.

    Order is preserved (depth-first, left to right).

    Args:
        path_list: A list whose items are paths or further nested lists.

    Returns:
        A flat list containing every non-list leaf item.
    """
    print("10. Entering flatten_paths with input length:", len(path_list))
    flattened = []
    for entry in path_list:
        if isinstance(entry, list):
            flattened.extend(flatten_paths(entry))
        else:
            flattened.append(entry)
    print("11. Flattened paths length:", len(flattened))
    return flattened
def accelerated_predict(df_full, cnn_model, subfolders, batch_size=512):
    """Run distributed batch inference over a coordinate dataframe.

    Uses HuggingFace Accelerate to shard the DataLoader across processes,
    gathers per-batch predictions and coordinates back, and accumulates
    them on the main process only.

    Args:
        df_full: DataFrame of target points consumed by
            MultiRasterDatasetMapping (presumably lon/lat columns —
            confirm against that class).
        cnn_model: Trained torch model producing one value per sample.
        subfolders: Raster directory paths the dataset reads bands from.
        batch_size: Samples per DataLoader batch.

    Returns:
        (coordinates, predictions) as numpy arrays on the main process;
        (None, None) on non-main processes or on any error (the error is
        printed and swallowed, not re-raised).
    """
    print("12. Entering accelerated_predict")
    print("13. df_full shape:", df_full.shape)
    print("14. subfolders:", subfolders[:5])
    try:
        accelerator = Accelerator()
        print("15. Accelerator initialized")
        print("16. Creating MultiRasterDatasetMapping")
        dataset = MultiRasterDatasetMapping(subfolders, df_full, bands_list_order=bands_list_order)
        print("17. Dataset length:", len(dataset))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        print("18. DataLoader created with batch_size:", batch_size)
        # prepare() moves the model to the right device and shards the loader
        # across processes when running multi-GPU/multi-process.
        model, dataloader = accelerator.prepare(cnn_model, dataloader)
        print("19. Model and DataLoader prepared with accelerator")
        all_predictions = []
        all_coordinates = []
        with torch.no_grad():
            print("20. Starting inference loop")
            for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
                inputs, coords = batch  # inputs: (batch, num_bands, h, w), coords: (batch, 2)
                outputs = model(inputs)
                # Gather outputs and coordinates across all processes.
                # NOTE(review): accelerator.gather() pads uneven final batches
                # per process, which can duplicate trailing samples — confirm
                # the dataset length divides evenly or dedupe downstream.
                gathered_outputs = accelerator.gather(outputs)
                gathered_coords = accelerator.gather(coords)
                if accelerator.is_main_process:
                    # Flatten predictions if necessary (assuming outputs are 2D or higher)
                    gathered_outputs = gathered_outputs.view(-1).cpu().numpy()  # Flatten to 1D
                    gathered_coords = gathered_coords.view(-1, 2).cpu().numpy()  # Ensure coords are (N, 2)
                    all_predictions.append(gathered_outputs)
                    all_coordinates.append(gathered_coords)
        if accelerator.is_main_process:
            predictions = np.concatenate(all_predictions, axis=0)
            coordinates = np.concatenate(all_coordinates, axis=0)
            print("27. Final predictions shape:", predictions.shape)
            print("28. Final coordinates shape:", coordinates.shape)
            # Ensure lengths match
            if len(predictions) != len(coordinates):
                raise ValueError(f"Mismatch between predictions ({len(predictions)}) and coordinates ({len(coordinates)})")
            return coordinates, predictions
        else:
            print("29. Not main process, returning None")
            return None, None
    except Exception as e:
        print(f"Error in accelerated prediction: {e}")
        return None, None
    finally:
        # Always release accelerator memory and re-sync processes, even on
        # the early returns above; guarded in case Accelerator() itself failed.
        if 'accelerator' in locals():
            accelerator.free_memory()
            accelerator.wait_for_everyone()
def main():
    """Run the full inference pipeline: resolve raster paths, load the
    coordinate grid and CNN, predict SOC over ~1M points, then save arrays
    and render visualizations (main process only)."""
    print("30. Entering main")
    try:
        raster_dirs = separate_and_add_data_1mil_inference()
        print("31. Subfolders retrieved:", len(raster_dirs))
        raster_dirs = flatten_paths(raster_dirs)
        print("32. Paths flattened:", len(raster_dirs))
        # dict.fromkeys drops duplicates while keeping first-seen order.
        raster_dirs = list(dict.fromkeys(raster_dirs))
        print("33. Duplicates removed:", len(raster_dirs))

        coords_csv = "/home/vfourel/SOCProject/SOCmapping/Data/Coordinates1Mil/coordinates_Bavaria_1mil.csv"
        points_df = pd.read_csv(coords_csv)
        print("34. Coordinates loaded, head:\n", points_df.head())

        model = load_cnn_model()
        print("35. CNN model loaded")

        coords, preds = accelerated_predict(
            df_full=points_df,
            cnn_model=model,
            subfolders=raster_dirs,
            batch_size=512
        )
        print("36. Prediction completed")

        if coords is None or preds is None:
            # Prediction failed, or this is a non-main accelerator process.
            print("40. Prediction failed or not on main process")
            return

        coords_out = "coordinates_1mil.npy"
        preds_out = "predictions_1mil.npy"
        np.save(coords_out, coords)
        np.save(preds_out, preds)
        print("37. Results saved to:", coords_out, preds_out)

        plot_dir = '/home/vfourel/SOCProject/SOCmapping/predictions_plots/cnnsimple_plots'
        Path(plot_dir).mkdir(parents=True, exist_ok=True)
        print("38. Visualization directory created:", plot_dir)
        create_prediction_visualizations(INFERENCE_TIME, coords, preds, plot_dir)
        print(f"39. Results saved and visualizations created at {plot_dir}")
    except Exception as e:
        print(f"Error in main: {e}")
        raise
# Script entry guard: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()