File size: 7,878 Bytes
a16f583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt
import pandas as pd
import torch
from tqdm import tqdm
from pathlib import Path
from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader, TensorDataset
import os

from dataloader.dataloader import MultiRasterDataset 
from dataloader.dataloaderMapping import MultiRasterDatasetMapping
from dataloader.dataframe_loader import filter_dataframe, separate_and_add_data
from config import (TIME_BEGINNING, TIME_END, INFERENCE_TIME, seasons, years_padded, 
                   SamplesCoordinates_Yearly, MatrixCoordinates_1mil_Yearly, 
                   DataYearly, SamplesCoordinates_Seasonally, 
                   MatrixCoordinates_1mil_Seasonally, DataSeasonally, 
                   file_path_LUCAS_LFU_Lfl_00to23_Bavaria_OC, 
                   window_size, bands_list_order)
from mapping import create_prediction_visualizations
from modelCNN import SmallCNN

def load_cnn_model(model_path="/home/vfourel/SOCProject/SOCmapping/SimpleTimeModel/SimpleCNN_map/simpletimecnn_model_MAX_OC_150_TIME_BEGINNING_2007_TIME_END_2023.pth"):
    """Instantiate SmallCNN, load its weights from *model_path*, and return it in eval mode.

    Normalizes the 'module.' key prefix so checkpoints saved from a
    DataParallel/DDP wrapper load into a bare model (and vice versa).
    Any failure is printed and then re-raised to the caller.
    """
    print("1. Entering load_cnn_model")
    try:
        net = SmallCNN()
        print("2. Model initialized:", net.__class__.__name__)
        weights = torch.load(model_path, map_location="cpu", weights_only=True)
        print("3. State dict loaded with keys:", list(weights.keys())[:5])

        prefixed = [k.startswith('module.') for k in weights.keys()]
        if prefixed and all(prefixed):
            # Checkpoint came from a wrapped model: strip the prefix.
            print("4. Removing 'module.' prefix from state_dict keys")
            weights = {k.replace('module.', ''): v for k, v in weights.items()}
        elif not any(prefixed) and hasattr(net, 'module'):
            # Target model is wrapped but the checkpoint is not: add the prefix.
            print("5. Adding 'module.' prefix to state_dict keys")
            weights = {'module.' + k: v for k, v in weights.items()}

        net.load_state_dict(weights)
        print("6. State dict loaded into model")
        net.eval()
        print("7. Model set to evaluation mode")
        return net
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def separate_and_add_data_1mil_inference(TIME_END=INFERENCE_TIME, 
                                        SamplesCoordinates_Yearly=MatrixCoordinates_1mil_Yearly):
    """Build the list of raster paths to sample for the inference year.

    Elevation layers are time-invariant and kept as-is; every other base
    path gets the year appended as a subdirectory.
    """
    print("8. Entering separate_and_add_data_1mil_inference")

    def _yearly_path(base):
        # Elevation is static, so it carries no year suffix.
        return base if 'Elevation' in base else f"{base}/{TIME_END}"

    samples_paths = [_yearly_path(base) for base in SamplesCoordinates_Yearly]

    print("9. Returning samples paths:", len(samples_paths))
    return samples_paths

def flatten_paths(path_list):
    """Recursively flatten arbitrarily nested lists into a single flat list."""
    print("10. Entering flatten_paths with input length:", len(path_list))
    flattened = []
    for entry in path_list:
        if isinstance(entry, list):
            flattened.extend(flatten_paths(entry))
        else:
            flattened.append(entry)
    print("11. Flattened paths length:", len(flattened))
    return flattened

def accelerated_predict(df_full, cnn_model, subfolders, batch_size=512):
    """Run batched CNN inference over every coordinate in *df_full*.

    Uses HuggingFace Accelerate: each process predicts its shard of the
    DataLoader, results are gathered across processes, and only the main
    process accumulates and returns them.

    Args:
        df_full: DataFrame of target coordinates (read from the 1-mil CSV in main).
        cnn_model: trained model, e.g. the one returned by load_cnn_model().
        subfolders: flattened, de-duplicated list of raster directories.
        batch_size: DataLoader batch size (default 512).

    Returns:
        (coordinates, predictions) as NumPy arrays on the main process;
        (None, None) on non-main processes or when any exception occurs.
    """
    print("12. Entering accelerated_predict")
    print("13. df_full shape:", df_full.shape)
    print("14. subfolders:", subfolders[:5])
    
    try:
        accelerator = Accelerator()
        print("15. Accelerator initialized")

        print("16. Creating MultiRasterDatasetMapping")
        dataset = MultiRasterDatasetMapping(subfolders, df_full, bands_list_order=bands_list_order)
        print("17. Dataset length:", len(dataset))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        print("18. DataLoader created with batch_size:", batch_size)

        model, dataloader = accelerator.prepare(cnn_model, dataloader)
        print("19. Model and DataLoader prepared with accelerator")

        all_predictions = []
        all_coordinates = []

        with torch.no_grad():
            print("20. Starting inference loop")
            for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
                inputs, coords = batch  # inputs: (batch, num_bands, h, w), coords: (batch, 2)
                outputs = model(inputs)

                # Gather outputs and coordinates across all processes
                # NOTE(review): in multi-process runs, accelerator.gather pads the
                # last batch so every process contributes equally, which can
                # duplicate trailing samples — confirm or switch to
                # gather_for_metrics if exact counts matter.
                gathered_outputs = accelerator.gather(outputs)
                gathered_coords = accelerator.gather(coords)

                if accelerator.is_main_process:
                    # Flatten predictions if necessary (assuming outputs are 2D or higher)
                    gathered_outputs = gathered_outputs.view(-1).cpu().numpy()  # Flatten to 1D
                    gathered_coords = gathered_coords.view(-1, 2).cpu().numpy()  # Ensure coords are (N, 2)
                    all_predictions.append(gathered_outputs)
                    all_coordinates.append(gathered_coords)

        if accelerator.is_main_process:
            predictions = np.concatenate(all_predictions, axis=0)
            coordinates = np.concatenate(all_coordinates, axis=0)
            print("27. Final predictions shape:", predictions.shape)
            print("28. Final coordinates shape:", coordinates.shape)
            
            # Ensure lengths match
            if len(predictions) != len(coordinates):
                raise ValueError(f"Mismatch between predictions ({len(predictions)}) and coordinates ({len(coordinates)})")
            
            return coordinates, predictions
        else:
            print("29. Not main process, returning None")
            return None, None
            
    except Exception as e:
        # NOTE(review): errors are converted to a (None, None) return rather
        # than re-raised; main() treats that the same as a non-main process.
        print(f"Error in accelerated prediction: {e}")
        return None, None
    finally:
        # Release accelerator resources even on error; the locals() guard
        # covers the case where Accelerator() itself failed to construct.
        if 'accelerator' in locals():
            accelerator.free_memory()
            accelerator.wait_for_everyone()

def main():
    """Run the full 1-million-point inference pipeline: build paths, load
    coordinates and model, predict, then save arrays and visualizations."""
    print("30. Entering main")
    try:
        paths = separate_and_add_data_1mil_inference()
        print("31. Subfolders retrieved:", len(paths))

        paths = flatten_paths(paths)
        print("32. Paths flattened:", len(paths))

        # Drop duplicates while preserving first-seen order.
        paths = list(dict.fromkeys(paths))
        print("33. Duplicates removed:", len(paths))

        coords_csv = "/home/vfourel/SOCProject/SOCmapping/Data/Coordinates1Mil/coordinates_Bavaria_1mil.csv"
        coords_df = pd.read_csv(coords_csv)
        print("34. Coordinates loaded, head:\n", coords_df.head())

        model = load_cnn_model()
        print("35. CNN model loaded")

        coordinates, predictions = accelerated_predict(
            df_full=coords_df,
            cnn_model=model,
            subfolders=paths,
            batch_size=512,
        )
        print("36. Prediction completed")

        if coordinates is None or predictions is None:
            # Either inference failed or this is a non-main accelerate process.
            print("40. Prediction failed or not on main process")
        else:
            coords_out = "coordinates_1mil.npy"
            preds_out = "predictions_1mil.npy"
            np.save(coords_out, coordinates)
            np.save(preds_out, predictions)
            print("37. Results saved to:", coords_out, preds_out)

            plot_dir = '/home/vfourel/SOCProject/SOCmapping/predictions_plots/cnnsimple_plots'
            Path(plot_dir).mkdir(parents=True, exist_ok=True)
            print("38. Visualization directory created:", plot_dir)
            create_prediction_visualizations(INFERENCE_TIME, coordinates, predictions, plot_dir)
            print(f"39. Results saved and visualizations created at {plot_dir}")

    except Exception as e:
        print(f"Error in main: {e}")
        raise

# Script entry point: run the inference pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()