"""Train a ResNet-UNet binary segmentation model and report its metrics.

The script builds the dataset and loaders, trains with a combined Dice +
(focal or weighted cross-entropy) loss, validates after every epoch, writes
periodic and best-model checkpoints, and finally saves metric plots and a
short summary through ``save_report``.
"""
import os
import torch
import datetime
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import scripts.config as config
from torch.utils.data import DataLoader
from sklearn.metrics import jaccard_score
from torch.utils.data import random_split
import torchvision.transforms as transforms
import scripts.Segmentation.augment as augment
from scripts.Segmentation.models import ResNetUNet
from warmup_scheduler import GradualWarmupScheduler
from scripts.Segmentation.focalLoss import FocalLoss
from scripts.Segmentation.segDS import SegmentationDataset
from scripts.Segmentation.diceLossCriterion import DiceLoss


def save_report(row) -> None:
    """Print ``row`` and append its string form as one line to the report.

    The report path is read from ``config.report_file``; the file is opened
    in append mode with UTF-8 encoding on every call.
    """
    text = str(row)
    print(text)
    with open(config.report_file, 'a', encoding='utf-8') as f:
        f.write(text + '\n')


def compute_class_weights(dataset) -> torch.Tensor:
    """Return normalized inverse-frequency weights for classes 0 and 1.

    Args:
        dataset: Iterable yielding ``(image, mask)`` pairs where ``mask`` is
            a tensor; only mask values equal to 0 or 1 are counted.

    Returns:
        A tensor with one weight per class, normalized by their sum, so the
        less frequent class receives the larger weight.
    """
    class_counts = torch.zeros(2)
    for _, mask in dataset:
        pixels = mask.view(-1)
        for c in (0, 1):
            class_counts[c] += (pixels == c).sum()
    # The epsilon keeps the division finite when a class never appears.
    weights = class_counts.sum() / (2.0 * class_counts + 1e-6)
    return weights / weights.sum()


def _save_curve(values, title, ylabel, filename, color=None) -> None:
    """Plot one per-epoch metric curve and save it under ``config.source``."""
    fig = plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(values) + 1), values, marker='o', color=color)
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.grid()
    plt.tight_layout()
    plt.savefig(config.source + filename)
    # Release the figure so repeated plots do not accumulate in memory.
    plt.close(fig)


os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Environment information is best-effort: a failure here is reported but
# must not prevent training.
try:
    save_report(torch.__version__)
    save_report(torch.version.cuda)
    save_report(torch.cuda.get_arch_list())  # e.g. ['sm_75', 'sm_86']
except Exception as e:
    save_report(e)

config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Starting processing...')
os.makedirs(config.checkpoints, exist_ok=True)

transform = transforms.Compose([
    transforms.Resize((config.height, config.width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# 80/20 random train/validation split.
dataset = SegmentationDataset(transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False)

model = ResNetUNet(num_classes=2).to(config.device)

# Class weights require a full pass over the dataset, so they are computed
# only when the weighted cross-entropy criterion actually uses them.
class_weights = None
if config.USE_FOCAL_LOSS:
    criterion = FocalLoss(gamma=2.0)
else:
    class_weights = compute_class_weights(dataset).to(config.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# The scheduler is stepped once per batch (see the training loop), so the
# 5-epoch warm-up and the cosine horizon are expressed in iterations.
# Counting them in epochs while stepping per batch would end the warm-up
# after 5 batches and make the cosine schedule oscillate within an epoch.
steps_per_epoch = max(1, len(train_loader))
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.num_epochs * steps_per_epoch)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=1.0,
    total_epoch=5 * steps_per_epoch,
    after_scheduler=cosine_scheduler,
)

dataHoraInicial = datetime.datetime.now()
save_report('\n\n\nStarting training on: ' + str(dataHoraInicial))

accuracies = []
iou_history = []
loss_history = []
dice_history = []
val_accuracies = []
best_accuracy = 0.0
dice_loss = DiceLoss()
epochs_no_improve = 0

for epoch in range(config.num_epochs):
    # Validation below switches the model to eval mode, so training mode
    # must be restored at the start of every epoch, not just once.
    model.train()

    total_loss = 0
    correct_pixels = 0
    total_pixels = 0

    for images, masks in train_loader:
        images = images.to(config.device).float()
        masks = masks.to(config.device).long()

        # NOTE(review): TTA is applied to the training forward pass but not
        # to validation below -- confirm this is intended.
        if config.USE_TTA:
            output = []
            for img in images:
                preds = augment.predict_with_tta(model, img.unsqueeze(0))  # [1, C, H, W]
                output.append(preds)
            output = torch.cat(output, dim=0)  # [B, C, H, W]
        else:
            output = model(images)

        output = output.float()

        loss = 1.5 * dice_loss(output, masks) + 0.5 * criterion(output, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        preds = torch.argmax(output, dim=1)

        if config.USE_REFINEMENT:
            preds = augment.refine_mask(preds)

        correct_pixels += (preds == masks).sum().item()
        total_pixels += torch.numel(preds)
        total_loss += loss.item()

    # The stored loss is the sum of the batch losses of the epoch.
    loss_history.append(total_loss)
    epoch_accuracy = correct_pixels / total_pixels
    accuracies.append(epoch_accuracy)

    model.eval()
    val_correct = 0
    val_total = 0
    val_preds_all = []
    val_targets_all = []

    # No gradients are needed for validation; skipping them saves memory.
    with torch.no_grad():
        for val_images, val_masks in val_loader:
            val_images = val_images.to(config.device).float()
            val_masks = val_masks.to(config.device).long()

            val_outputs = model(val_images)
            val_preds = torch.argmax(val_outputs, dim=1)

            val_preds_all.append(val_preds.view(-1).cpu().numpy())
            val_targets_all.append(val_masks.view(-1).cpu().numpy())

            val_correct += (val_preds == val_masks).sum().item()
            val_total += torch.numel(val_preds)

    val_accuracy = val_correct / val_total
    val_accuracies.append(val_accuracy)

    val_preds_flat = np.concatenate(val_preds_all)
    val_targets_flat = np.concatenate(val_targets_all)

    # 'macro' averages the per-class IoU; 'binary' or 'weighted' are the
    # alternatives.
    iou = jaccard_score(val_targets_flat, val_preds_flat, average='macro')
    intersection = np.logical_and(val_preds_flat, val_targets_flat).sum()
    # The epsilon avoids a division by zero when both masks are empty.
    dice = (2 * intersection) / (val_preds_flat.sum() + val_targets_flat.sum() + 1e-6)

    iou_history.append(iou)
    dice_history.append(dice)

    save_report(f"Epoch {epoch+1}/{config.num_epochs}, Loss: {total_loss:.4f}, "
                f"Train Acc: {epoch_accuracy:.4f}, Val Acc: {val_accuracy:.4f}, "
                f"IoU: {iou:.4f}, Dice: {dice:.4f}")

    if (epoch + 1) % config.checkpoint_interval == 0:
        checkpoint_path = os.path.join(config.checkpoints, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss,
            'accuracy': epoch_accuracy,
        }, checkpoint_path)

    # NOTE(review): model selection and early stopping track the *training*
    # accuracy, not the validation accuracy -- confirm this is intended.
    if epoch_accuracy > best_accuracy:
        save_report(f"šŸ”ø New best model at epoch {epoch+1} (acc: {epoch_accuracy:.4f}) — saving best_model.pt")
        best_accuracy = epoch_accuracy
        torch.save(model.state_dict(), os.path.join(config.checkpoints, "best_model.pt"))
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= config.early_stop_patience:
        save_report(f"\nā›” Early stopping triggered at epoch {epoch+1}")
        break

dataHoraFinal = datetime.datetime.now()
save_report('Completing training on: ' + str(dataHoraFinal))
save_report('Total training execution time = ' + str((dataHoraFinal - dataHoraInicial)))

# The final-epoch model (not necessarily the best one) is saved whole.
model.eval()
torch.save(model, config.modelName)

# Plotting is best-effort: a failure is reported instead of being silently
# swallowed, and it does not abort the summary below.
try:
    _save_curve(loss_history, "Loss Evolution", "Loss", 'training_loss.png')
    _save_curve(val_accuracies, "Validation Accuracy", "Pixel Accuracy",
                'training_val_accuracy.png', color='green')
    _save_curve(iou_history, "IoU Evolution", "IoU Score",
                'iou_history.png', color='purple')
    _save_curve(dice_history, "Dice Score Evolution", "Dice Score",
                'dice_history.png', color='orange')
except Exception as e:
    save_report(f"Could not save training plots: {e}")

save_report("\nTraining Summary:")
# The histories are empty when no epoch ran (e.g. config.num_epochs == 0).
if loss_history:
    save_report(f"  Min Loss: {min(loss_history):.4f}")
    save_report(f"  Max Loss: {max(loss_history):.4f}")
    save_report(f"  Loss final: {loss_history[-1]:.4f}")
if val_accuracies:
    save_report(f"  Best Val Acc: {max(val_accuracies):.4f}")

print("\nCompleted āœ…")