import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader.dataloader import MultiRasterDataset
from dataloader.dataloaderMapping import MultiRasterDatasetMapping
from dataloader.dataframe_loader import filter_dataframe, separate_and_add_data
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import wandb
from accelerate import Accelerator
from config import (TIME_BEGINNING, TIME_END, INFERENCE_TIME, MAX_OC,
                    seasons, years_padded, num_epochs,
                    SamplesCoordinates_Yearly, MatrixCoordinates_1mil_Yearly,
                    DataYearly, SamplesCoordinates_Seasonally,
                    MatrixCoordinates_1mil_Seasonally, DataSeasonally,
                    file_path_LUCAS_LFU_Lfl_00to23_Bavaria_OC)
from torch.utils.data import Dataset, DataLoader
from modelCNN import SmallCNN
import argparse


def composite_l1_chi2_loss(outputs, targets, sigma=3.0, alpha=0.5):
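    """Blend an L1 term with a Gaussian-damped squared-error ("chi-squared"-style) term.

    The chi-squared term is rescaled (via a detached factor) to the magnitude of the
    L1 loss, and alpha weights the two contributions.
    """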
    errors = targets - outputs
    l1_loss = torch.mean(torch.abs(errors))
    squared_errors = errors ** 2
    # Squared errors damped by a Gaussian window whose width is controlled by sigma.
    chi2_unscaled = (1/4) * squared_errors * torch.exp(-squared_errors / (2 * sigma))
    chi2_unscaled_mean = torch.mean(chi2_unscaled)
    chi2_unscaled_mean = torch.clamp(chi2_unscaled_mean, min=1e-8)
    # Detach the scale factor so the chi-squared term keeps its own gradient while
    # being rescaled to the magnitude of the L1 loss (otherwise the product
    # collapses back to the L1 loss and the chi-squared term contributes nothing).
    scale_factor = (l1_loss / chi2_unscaled_mean).detach()
    chi2_scaled = scale_factor * chi2_unscaled_mean
    return alpha * l1_loss + (1 - alpha) * chi2_scaled


def composite_l2_chi2_loss(outputs, targets, sigma=3.0, alpha=0.5):
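    """Blend an MSE term with a sigma-normalized squared-error ("chi-squared"-style) term.

    The chi-squared term is rescaled (via a detached factor) to the magnitude of the
    MSE loss, and alpha weights the two contributions.
    """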
    errors = targets - outputs
    l2_loss = torch.mean(errors ** 2)
    chi2_loss = torch.mean((errors ** 2) / (sigma ** 2))
    chi2_loss = torch.clamp(chi2_loss, min=1e-8)
    # Detach the scale factor so the chi-squared term keeps its own gradient while
    # being rescaled to the magnitude of the L2 loss.
    scale_factor = (l2_loss / chi2_loss).detach()
    chi2_scaled = scale_factor * chi2_loss
    return alpha * l2_loss + (1 - alpha) * chi2_scaled


def parse_args():
    parser = argparse.ArgumentParser(description='Train SmallCNN model with customizable parameters')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--loss_type', type=str, default='mse', choices=['composite_l1', 'l1', 'mse', 'composite_l2'], help='Type of loss function')
    parser.add_argument('--loss_alpha', type=float, default=0.5, help='Weight of the L1/L2 term in the composite losses (if used)')
    parser.add_argument('--target_transform', type=str, default='log', choices=['none', 'log', 'normalize'], help='Transformation to apply to targets')
    # BooleanOptionalAction also generates --no-use_validation, so the flag can actually
    # be disabled (a plain store_true with default=True could never be turned off).
    parser.add_argument('--use_validation', action=argparse.BooleanOptionalAction, default=True, help='Whether to use a validation set')
    return parser.parse_args()


def create_balanced_dataset(df, use_validation=True, n_bins=128, min_ratio=3/4):
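    """Quantile-bin the OC target and balance the training set across bins.

    Bins with fewer samples than min_ratio times the largest bin (and at least 5)
    are oversampled with replacement. When use_validation is True, up to 13 samples
    per sufficiently populated bin are held out as a validation set.
    """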
    # Work on a copy so the caller's dataframe is not mutated by the temporary 'bin' column.
    df = df.copy()
    # Quantile-based binning of the target; duplicate bin edges are dropped.
    bins = pd.qcut(df['OC'], q=n_bins, labels=False, duplicates='drop')
    df['bin'] = bins
    bin_counts = df['bin'].value_counts()
    max_samples = bin_counts.max()
    min_samples = max(int(max_samples * min_ratio), 5)
    training_dfs = []
    if use_validation:
        validation_indices = []
        for bin_idx in range(len(bin_counts)):
            bin_data = df[df['bin'] == bin_idx]
            if len(bin_data) >= 4:
                # Hold out up to 13 samples per bin for validation.
                val_samples = bin_data.sample(n=min(13, len(bin_data)))
                validation_indices.extend(val_samples.index)
                train_samples = bin_data.drop(val_samples.index)
                if len(train_samples) > 0:
                    if len(train_samples) < min_samples:
                        # Oversample under-represented bins with replacement.
                        resampled = train_samples.sample(n=min_samples, replace=True)
                        training_dfs.append(resampled)
                    else:
                        training_dfs.append(train_samples)
        if not training_dfs or not validation_indices:
            raise ValueError("No training or validation data available after binning")
        training_df = pd.concat(training_dfs).drop('bin', axis=1)
        validation_df = df.loc[validation_indices].drop('bin', axis=1)
        print('Size of the training set: ', len(training_df))
        print('Size of the validation set: ', len(validation_df))
        return training_df, validation_df
    else:
        for bin_idx in range(len(bin_counts)):
            bin_data = df[df['bin'] == bin_idx]
            if len(bin_data) > 0:
                if len(bin_data) < min_samples:
                    resampled = bin_data.sample(n=min_samples, replace=True)
                    training_dfs.append(resampled)
                else:
                    training_dfs.append(bin_data)
        if not training_dfs:
            raise ValueError("No training data available after binning")
        training_df = pd.concat(training_dfs).drop('bin', axis=1)
        return training_df, None


def train_model(args, model, train_loader, val_loader, num_epochs, accelerator, loss_type='mse', target_transform='none'):
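    """Train the model, optionally evaluating on a validation loader each epoch.

    Validation metrics (R², RMSE, MAE, RPIQ) are computed on predictions mapped back
    to the original target scale when a target transform is used. Returns
    (model, None, None, best_model_state, best_r_squared).
    """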
    # Select the loss function.
    if loss_type == 'composite_l1':
        criterion = lambda outputs, targets: composite_l1_chi2_loss(outputs, targets, sigma=3.0, alpha=args.loss_alpha)
    elif loss_type == 'composite_l2':
        criterion = lambda outputs, targets: composite_l2_chi2_loss(outputs, targets, sigma=3.0, alpha=args.loss_alpha)
    elif loss_type == 'l1':
        criterion = nn.L1Loss()
    elif loss_type == 'mse':
        criterion = nn.MSELoss()
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    train_loader, model, optimizer = accelerator.prepare(
        train_loader, model, optimizer
    )
    if val_loader is not None:
        val_loader = accelerator.prepare(val_loader)

    # Estimate target statistics for the 'normalize' transform from the training data.
    if target_transform == 'normalize':
        all_targets = []
        for _, _, _, targets in train_loader:
            all_targets.append(targets)
        all_targets = torch.cat(all_targets)
        target_mean = all_targets.mean().item()
        target_std = all_targets.std().item()
        if accelerator.is_main_process:
            print(f"Target mean: {target_mean}, Target std: {target_std}")
    else:
        target_mean, target_std = 0.0, 1.0

    best_r_squared = -float('inf') if args.use_validation else 1.0
    best_model_state = None

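    # Main training loop: one optimization pass over the training set per epoch,
    # followed by an optional validation pass.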
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch_idx, (longitudes, latitudes, features, targets) in enumerate(tqdm(train_loader)):
            features = features.to(accelerator.device)
            targets = targets.to(accelerator.device).float()
            if target_transform == 'log':
                targets = torch.log(targets + 1e-10)
            elif target_transform == 'normalize':
                targets = (targets - target_mean) / (target_std + 1e-10)
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()
            running_loss += loss.item()

            if accelerator.is_main_process:
                wandb.log({
                    'train_loss': loss.item(),
                    'batch': batch_idx + 1 + epoch * len(train_loader),
                    'epoch': epoch + 1
                })

        train_loss = running_loss / len(train_loader)

        # Validation pass: compute loss and regression metrics on the original target scale.
        if args.use_validation and val_loader is not None:
            model.eval()
            val_loss = 0.0
            val_outputs_list = []
            val_targets_list = []

            with torch.no_grad():
                for longitudes, latitudes, features, targets in val_loader:
                    features = features.to(accelerator.device)
                    targets = targets.to(accelerator.device).float()
                    if target_transform == 'log':
                        targets = torch.log(targets + 1e-10)
                    elif target_transform == 'normalize':
                        targets = (targets - target_mean) / (target_std + 1e-10)
                    outputs = model(features)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()
                    val_outputs_list.append(outputs.cpu())
                    val_targets_list.append(targets.cpu())

            val_loss = val_loss / len(val_loader)

            val_outputs = torch.cat(val_outputs_list, dim=0).numpy()
            val_targets = torch.cat(val_targets_list, dim=0).numpy()

            # Gather predictions from all processes and flatten so the metric
            # computations below operate on 1-D arrays.
            val_outputs_all = torch.from_numpy(val_outputs).to(accelerator.device)
            val_targets_all = torch.from_numpy(val_targets).to(accelerator.device)
            val_outputs_all = accelerator.gather(val_outputs_all).cpu().numpy().flatten()
            val_targets_all = accelerator.gather(val_targets_all).cpu().numpy().flatten()

            if accelerator.is_main_process:
                # Map predictions and targets back to the original scale before computing metrics.
                if target_transform == 'log':
                    original_val_outputs = np.exp(val_outputs_all)
                    original_val_targets = np.exp(val_targets_all)
                elif target_transform == 'normalize':
                    original_val_outputs = val_outputs_all * target_std + target_mean
                    original_val_targets = val_targets_all * target_std + target_mean
                else:
                    original_val_outputs = val_outputs_all
                    original_val_targets = val_targets_all

                if len(original_val_outputs) > 1 and np.std(original_val_outputs) > 1e-6 and np.std(original_val_targets) > 1e-6:
                    correlation = np.corrcoef(original_val_outputs, original_val_targets)[0, 1]
                    r_squared = correlation ** 2
                    mse = np.mean((original_val_outputs - original_val_targets) ** 2)
                    rmse = np.sqrt(mse)
                    mae = np.mean(np.abs(original_val_outputs - original_val_targets))
                    iqr = np.percentile(original_val_targets, 75) - np.percentile(original_val_targets, 25)
                    rpiq = iqr / rmse if rmse > 0 else float('inf')
                else:
                    correlation = 0.0
                    r_squared = 0.0
                    mse = float('nan')
                    rmse = float('nan')
                    mae = float('nan')
                    rpiq = float('nan')

                # Keep the state dict of the best model seen so far (by validation R²).
                if r_squared > best_r_squared:
                    best_r_squared = r_squared
                    best_model_state = {k: v.cpu() for k, v in model.state_dict().items()}
                    wandb.run.summary['best_r_squared'] = best_r_squared

                wandb.log({
                    'epoch': epoch + 1,
                    'train_loss_avg': train_loss,
                    'val_loss': val_loss,
                    'correlation': correlation,
                    'r_squared': r_squared,
                    'mse': mse,
                    'rmse': rmse,
                    'mae': mae,
                    'rpiq': rpiq
                })

            accelerator.print(f'Epoch {epoch+1}:')
            accelerator.print(f'Training Loss: {train_loss:.4f}')
            accelerator.print(f'Validation Loss: {val_loss:.4f}')
            if accelerator.is_main_process:
                accelerator.print(f'RPIQ: {rpiq:.4f}\n')
        else:
            # Without a validation set there is no R² to track; keep the latest weights.
            best_r_squared = 1.0
            best_model_state = {k: v.cpu() for k, v in model.state_dict().items()}

            if accelerator.is_main_process:
                wandb.run.summary['best_r_squared'] = best_r_squared
                wandb.log({
                    'epoch': epoch + 1,
                    'train_loss_avg': train_loss,
                })

            accelerator.print(f'Epoch {epoch+1}:')
            accelerator.print(f'Training Loss: {train_loss:.4f}\n')

    return model, None, None, best_model_state, best_r_squared


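# Example invocation (illustrative; substitute this script's actual filename for train.py):
#   accelerate launch train.py --lr 0.001 --loss_type composite_l2 --target_transform log
#   accelerate launch train.py --loss_type mse --target_transform none --no-use_validation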
if __name__ == "__main__":
    args = parse_args()
    accelerator = Accelerator()

    # Initialize experiment tracking on the main process only, so multi-GPU launches
    # do not create duplicate wandb runs.
    if accelerator.is_main_process:
        wandb.init(
            project="socmapping-SimpleTimeCNN",
            config={
                "max_oc": MAX_OC,
                "time_beginning": TIME_BEGINNING,
                "time_end": TIME_END,
                "epochs": num_epochs,
                "batch_size": 256,
                "learning_rate": args.lr,
                "input_channels": 6,
                "loss_type": args.loss_type,
                "target_transform": args.target_transform,
                "use_validation": args.use_validation
            }
        )

    # Load the target dataframe and the raster path lists.
    df = filter_dataframe(TIME_BEGINNING, TIME_END, MAX_OC)
    samples_coordinates_array_path, data_array_path = separate_and_add_data()

    def flatten_paths(path_list):
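        """Flatten arbitrarily nested lists of paths into a single flat list."""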
        flattened = []
        for item in path_list:
            if isinstance(item, list):
                flattened.extend(flatten_paths(item))
            else:
                flattened.append(item)
        return flattened

    # Flatten the nested path lists and remove duplicates while preserving order.
    samples_coordinates_array_path = list(dict.fromkeys(flatten_paths(samples_coordinates_array_path)))
    data_array_path = list(dict.fromkeys(flatten_paths(data_array_path)))

    # create_balanced_dataset returns (train_df, None) when no validation set is requested.
    train_df, val_df = create_balanced_dataset(df, use_validation=args.use_validation)

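    # Build raster-backed datasets that pair each sample's coordinates and OC value
    # with the corresponding raster data (see dataloader.MultiRasterDataset).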
    train_dataset = MultiRasterDataset(samples_coordinates_array_path, data_array_path, train_df)
    val_dataset = MultiRasterDataset(samples_coordinates_array_path, data_array_path, val_df) if val_df is not None else None

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False) if val_dataset is not None else None

    model = SmallCNN(input_channels=6)

    if accelerator.is_main_process:
        wandb.run.summary["model_parameters"] = model.count_parameters()
        wandb.run.summary["train_size"] = len(train_df)
        wandb.run.summary["val_size"] = len(val_df) if val_df is not None else 0
        print(f"Model parameters: {model.count_parameters()}")
        print(f"Training set size: {len(train_df)}")
        print(f"Validation set size: {len(val_df) if val_df is not None else 0}")

    model, val_outputs, val_targets, best_model_state, best_r_squared = train_model(
        args, model, train_loader, val_loader,
        num_epochs=num_epochs,
        accelerator=accelerator,
        loss_type=args.loss_type,
        target_transform=args.target_transform
    )

    # Save the best checkpoint (by validation R²) and upload it to wandb.
    if accelerator.is_main_process and best_model_state is not None:
        final_model_path = (f'simpletimecnn_model_MAX_OC_{MAX_OC}_TIME_BEGINNING_{TIME_BEGINNING}_'
                            f'TIME_END_{TIME_END}_LOSS_{args.loss_type}_TRANSFORM_{args.target_transform}_'
                            f'BEST_R2_{best_r_squared:.4f}.pth')
        torch.save(best_model_state, final_model_path)
        wandb.save(final_model_path)
        print(f"Best model saved with R²: {best_r_squared:.4f}")

    if accelerator.is_main_process:
        wandb.finish()