import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import models import cv2 import os from typing import Tuple import matplotlib.pyplot as plt import numpy as np class DoubleConv(nn.Module): def __init__(self, in_channels: int, out_channels: int): super(DoubleConv, self).__init__() self.double_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class UpConv(nn.Module): def __init__(self, in_channels: int, out_channels: int): super(UpConv, self).__init__() self.up_conv = nn.Sequential( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.up_conv(x) class TLUNet(nn.Module): def __init__(self, input_shape: Tuple[int, int, int], patch_size: int): super(TLUNet, self).__init__() self.patch_size = patch_size # Load pretrained VGG16 with explicit weights specification base_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1) # Freeze VGG16 parameters for param in base_vgg.parameters(): param.requires_grad = False # Extract and store intermediate dimensions self.dims = { 'input': patch_size, 'block1': patch_size // 2, # 128 'block2': patch_size // 4, # 64 'block3': patch_size // 8, # 32 'block4': patch_size // 16, # 16 'block5': patch_size // 32 # 8 } # Encoder blocks from VGG16 self.block1 = nn.Sequential( *list(base_vgg.features.children())[:5]) # 256->128 self.block2 = nn.Sequential( *list(base_vgg.features.children())[5:10]) # 128->64 self.block3 = nn.Sequential( *list(base_vgg.features.children())[10:17]) # 64->32 self.block4 = nn.Sequential( *list(base_vgg.features.children())[17:24]) # 32->16 self.block5 = nn.Sequential( *list(base_vgg.features.children())[24:31]) # 16->8 # Decoder path with reusable components self.upconv1 = UpConv(512, 512) # 8->16 self.conv6 = DoubleConv(1024, 512) # After concatenation self.upconv2 = UpConv(512, 256) # 16->32 self.conv7 = DoubleConv(512, 256) self.upconv3 = UpConv(256, 128) # 32->64 self.conv8 = DoubleConv(256, 128) self.upconv4 = UpConv(128, 64) # 64->128 self.conv9 = DoubleConv(128, 64) # Final upsampling and convolution self.upconv_final = UpConv(64, 32) # 128->256 self.final_conv = nn.Sequential( nn.Conv2d(32, 1, kernel_size=3, padding=1), nn.Sigmoid() ) def forward(self, x): # Encoder path with skip connections e1 = self.block1(x) e2 = self.block2(e1) e3 = self.block3(e2) e4 = self.block4(e3) e5 = self.block5(e4) # Decoder path with skip connections d1 = self.upconv1(e5) d1 = torch.cat([d1, e4], dim=1) d1 = self.conv6(d1) d2 = self.upconv2(d1) d2 = torch.cat([d2, e3], dim=1) d2 = self.conv7(d2) d3 = self.upconv3(d2) d3 = torch.cat([d3, e2], dim=1) d3 = self.conv8(d3) d4 = self.upconv4(d3) d4 = torch.cat([d4, e1], dim=1) d4 = self.conv9(d4) # Final upsampling and convolution out = self.upconv_final(d4) out = self.final_conv(out) return out class ImageMaskDataset(Dataset): def __init__(self, image_dir: str, mask_dir: str, target_size: Tuple[int, int]): self.image_dir = image_dir self.mask_dir = mask_dir self.target_size = target_size self.image_paths = sorted( [os.path.join(image_dir, fname) for fname in os.listdir(image_dir)]) self.mask_paths = sorted([os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir)]) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): # Load image image = cv2.imread(self.image_paths[idx]) image = image / 255.0 image = torch.FloatTensor(image).permute( 2, 0, 1) # Convert to CHW format # Load mask mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE) mask = mask / 255.0 mask = torch.FloatTensor(mask).unsqueeze(0) # Add channel dimension return image, mask class DiceLoss(nn.Module): def __init__(self, smooth=1e-6, gamma=2): super(DiceLoss, self).__init__() self.smooth = smooth self.gamma = gamma def forward(self, y_pred, y_true): y_pred = y_pred.view(-1) y_true = y_true.view(-1) intersection = torch.sum(y_pred * y_true) denominator = torch.sum(y_pred.pow(self.gamma)) + \ torch.sum(y_true.pow(self.gamma)) dice_score = (2.0 * intersection + self.smooth) / \ (denominator + self.smooth) return 1 - dice_score def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, num_epochs: int, device: torch.device) -> dict: criterion = nn.BCELoss() optimizer = optim.Adam(model.parameters()) history = { 'train_loss': [], 'val_loss': [], 'batch_losses': [] # Track individual batch losses } total_batches = len(train_loader) print(f"Training on {total_batches} batches per epoch") for epoch in range(num_epochs): model.train() train_loss = 0.0 # Progress tracking variables batch_losses = [] print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 60) for batch_idx, (images, masks) in enumerate(train_loader): images, masks = images.to(device), masks.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, masks) loss.backward() optimizer.step() # Record batch loss batch_loss = loss.item() batch_losses.append(batch_loss) train_loss += batch_loss # Print progress every 10% of batches if (batch_idx + 1) % max(1, total_batches // 10) == 0: current_loss = train_loss / (batch_idx + 1) progress = (batch_idx + 1) / total_batches * 100 print(f"Batch {batch_idx + 1}/{total_batches} [{progress:.1f}%] - " f"Current Loss: {current_loss:.4f}") # Validation phase model.eval() val_loss = 0.0 with torch.no_grad(): for val_batch_idx, (images, masks) in enumerate(val_loader): images, masks = images.to(device), masks.to(device) outputs = model(images) loss = criterion(outputs, masks) val_loss += loss.item() # Calculate and record average losses avg_train_loss = train_loss / len(train_loader) avg_val_loss = val_loss / len(val_loader) history['train_loss'].append(avg_train_loss) history['val_loss'].append(avg_val_loss) history['batch_losses'].extend(batch_losses) # Print epoch summary print("\nEpoch Summary:") print(f"Average Train Loss: {avg_train_loss:.4f}") print(f"Average Val Loss: {avg_val_loss:.4f}") print(f"Best batch loss: {min(batch_losses):.4f}") print(f"Worst batch loss: {max(batch_losses):.4f}") print("-" * 60) return history def test_model(model: nn.Module, test_loader: DataLoader, device: torch.device) -> Tuple[float, float]: model.eval() total_loss = 0.0 criterion = nn.BCELoss() with torch.no_grad(): for images, masks in test_loader: images, masks = images.to(device), masks.to(device) outputs = model(images) loss = criterion(outputs, masks) total_loss += loss.item() avg_loss = total_loss / len(test_loader) print(f"Test Loss: {avg_loss:.4f}") return avg_loss def get_hw(): device = "cpu" if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): device = "mps" return device def plot_training_history(history, save_path=None, show_plot=True): """ Plot training metrics from model history. Args: history (dict): Dictionary containing 'train_loss', 'val_loss', and 'batch_losses' save_path (str, optional): Path to save the plot. Defaults to None. show_plot (bool, optional): Whether to display the plot. Defaults to True. """ # Create figure with subplots fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10)) fig.suptitle('Training History', fontsize=16) # Plot epoch-wise losses epochs = range(1, len(history['train_loss']) + 1) # Top subplot: Training and validation loss per epoch ax1.plot(epochs, history['train_loss'], 'b-', label='Training Loss') ax1.plot(epochs, history['val_loss'], 'r-', label='Validation Loss') ax1.set_title('Epoch-wise Training and Validation Loss') ax1.set_xlabel('Epoch') ax1.set_ylabel('Loss') ax1.grid(True) ax1.legend() # Bottom subplot: Batch losses batches = range(1, len(history['batch_losses']) + 1) ax2.plot(batches, history['batch_losses'], 'g-', alpha=0.5, label='Batch Loss') # Add moving average line for batch losses window_size = min(100, len(history['batch_losses']) // 10) # Adaptive window size if window_size > 1: moving_avg = np.convolve(history['batch_losses'], np.ones(window_size)/window_size, mode='valid') ax2.plot(range(window_size, len(batches) + 1), moving_avg, 'r-', label=f'Moving Average (window={window_size})') ax2.set_title('Batch-wise Training Loss') ax2.set_xlabel('Batch') ax2.set_ylabel('Loss') ax2.grid(True) ax2.legend() # Add summary statistics as text stats_text = ( f"Final Training Loss: {history['train_loss'][-1]:.4f}\n" f"Final Validation Loss: {history['val_loss'][-1]:.4f}\n" f"Best Training Loss: {min(history['train_loss']):.4f}\n" f"Best Validation Loss: {min(history['val_loss']):.4f}\n" f"Best Batch Loss: {min(history['batch_losses']):.4f}" ) fig.text(0.95, 0.05, stats_text, fontsize=10, ha='right', bbox=dict(facecolor='white', alpha=0.8)) # Adjust layout to prevent overlap plt.tight_layout() # Save plot if path is provided if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"Plot saved to {save_path}") # Show plot if requested if show_plot: plt.show() else: plt.close() # Usage example: if __name__ == "__main__": hw = get_hw() # Parameters patch_size = 256 version = 0 batch_size = 48 num_epochs = 8 device = torch.device(hw) # Paths path = f"Data/Data_{patch_size}_{version}" train_dir_images = os.path.join(path, "train", "images") train_dir_masks = os.path.join(path, "train", "masks") val_dir_images = os.path.join(path, "val", "images") val_dir_masks = os.path.join(path, "val", "masks") test_dir_images = os.path.join(path, "test", "images") test_dir_masks = os.path.join(path, "test", "masks") # Create datasets and dataloaders target_size = (patch_size, patch_size) input_shape = (patch_size, patch_size, 3) train_dataset = ImageMaskDataset( train_dir_images, train_dir_masks, target_size) val_dataset = ImageMaskDataset(val_dir_images, val_dir_masks, target_size) test_dataset = ImageMaskDataset( test_dir_images, test_dir_masks, target_size) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=batch_size) # Create and train model model = TLUNet(input_shape, patch_size).to(device) history = train_model(model, train_loader, val_loader, num_epochs, device) # Test model test_loss = test_model(model, test_loader, device) # Save model model_name = f"model__{patch_size}_{batch_size}_{version}.pth" torch.save(model.state_dict(), os.path.join("Models", model_name)) plot_training_history( history, save_path=f'training_history_{patch_size}_{batch_size}_{version}.png' )