Spaces:
Paused
Paused
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import random | |
| from scipy.ndimage import gaussian_filter, map_coordinates # Add this line | |
| import PIL | |
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm layers with a residual (skip) connection.

    When ``in_channels != out_channels`` a 1x1 convolution projects the input
    so it can be added to the block's output; otherwise the input is added
    unchanged. LeakyReLU is applied after the first conv and after the sum.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
        # 1x1 projection for the residual path when channel counts differ.
        self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        # Explicit None check: relying on an nn.Module's truthiness is fragile.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)
class AttentionGate(nn.Module):
    """Additive attention gate (Attention U-Net style).

    Projects the gating signal ``g`` and the skip features ``x`` to a common
    ``F_int``-channel space, combines them, and produces a single-channel
    sigmoid mask that rescales ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()

        def _project(channels):
            # 1x1 conv + InstanceNorm projection into the intermediate space.
            return nn.Sequential(
                nn.Conv2d(channels, F_int, kernel_size=1, stride=1, padding=0, bias=True),
                nn.InstanceNorm2d(F_int)
            )

        self.W_g = _project(F_g)
        self.W_x = _project(F_l)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(combined)
        return x * mask
class EnhancedUNet(nn.Module):
    """U-Net with residual conv blocks, attention-gated skip connections and a
    dilated bottleneck.

    Encoder halves the spatial size four times (64 -> 1024 channels); the
    decoder mirrors it with transposed convs, gating each skip connection
    before concatenation. Dropout(0.5) is applied after every encoder stage
    past the stem and after every decoder fusion except the last.
    """

    def __init__(self, n_channels, n_classes):
        super(EnhancedUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Encoder path.
        self.inc = ResidualConvBlock(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))
        # Dilated bottleneck enlarges the receptive field without pooling.
        self.dilation = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True)
        )
        # Decoder path (attribute names kept for state_dict compatibility).
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.up_conv4 = ResidualConvBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.up_conv3 = ResidualConvBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.up_conv2 = ResidualConvBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.up_conv1 = ResidualConvBlock(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # --- encoder ---
        x1 = self.inc(x)
        x2 = self.dropout(self.down1(x1))
        x3 = self.dropout(self.down2(x2))
        x4 = self.dropout(self.down3(x3))
        x5 = self.dropout(self.dilation(self.down4(x4)))
        # --- decoder: (upsample, attention gate, fusion block, skip, dropout?) ---
        stages = (
            (self.up4, self.att4, self.up_conv4, x4, True),
            (self.up3, self.att3, self.up_conv3, x3, True),
            (self.up2, self.att2, self.up_conv2, x2, True),
            (self.up1, self.att1, self.up_conv1, x1, False),
        )
        out = x5
        for upsample, attention, fuse, skip, apply_dropout in stages:
            out = upsample(out)
            gated_skip = attention(g=out, x=skip)
            out = fuse(torch.cat([gated_skip, out], dim=1))
            if apply_dropout:
                out = self.dropout(out)
        return self.outc(out)
class MoS2Dataset(Dataset):
    """Dataset of grayscale MoS2 images with per-pixel integer label maps.

    Expects ``root_dir/images/*.png`` paired with ``root_dir/labels/*.npy``;
    a file named ``image_<id>.png`` is matched to ``image_<id>.npy``.
    Unreadable PNGs are skipped at construction time.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        self.labels_dir = os.path.join(root_dir, 'labels')
        self.image_files = []
        for f in sorted(os.listdir(self.images_dir)):
            if f.endswith('.png'):
                try:
                    # verify() detects truncated/corrupt files without a full decode.
                    Image.open(os.path.join(self.images_dir, f)).verify()
                    self.image_files.append(f)
                except Exception:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
                    print(f"Skipping unreadable image: {f}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)``: image float32 tensor (1, H, W) scaled to
        [0, 1]; label int64 tensor (H, W).

        On a load failure returns ``(None, None)`` — NOTE(review): the default
        DataLoader collate cannot batch None; confirm failures are filtered
        upstream or rare enough not to matter.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        if not os.path.exists(img_path):
            print(f"Image file does not exist: {img_path}")
            return None, None
        # 'image_0001.png' -> 'image_0001.npy' (assumes exactly one underscore
        # in the file name — TODO confirm naming scheme).
        label_name = f"image_{img_name.split('_')[1].replace('.png', '.npy')}"
        label_path = os.path.join(self.labels_dir, label_name)
        try:
            image = np.array(Image.open(img_path).convert('L'), dtype=np.float32) / 255.0
            label = np.load(label_path).astype(np.int64)
        except (PIL.UnidentifiedImageError, FileNotFoundError, IOError) as e:
            print(f"Error loading image {img_path}: {str(e)}")
            return None, None
        if self.transform:
            # Transform operates on the numpy pair (image, label) jointly.
            image, label = self.transform(image, label)
        image = torch.from_numpy(image).float().unsqueeze(0)
        label = torch.from_numpy(label).long()
        return image, label
class AugmentationTransform:
    """Joint (image, label) augmentation pipeline.

    Each of the four augmentations fires independently with probability 0.5.
    Photometric ops leave the label untouched; the elastic deformation warps
    image (bilinear) and label (nearest-neighbour) with the same displacement
    field. Images are assumed to be float arrays in [0, 1].
    """

    def __init__(self):
        self.aug_functions = [
            self.random_brightness_contrast,
            self.random_gamma,
            self.random_noise,
            self.random_elastic_deform
        ]

    def __call__(self, image, label):
        for augment in self.aug_functions:
            # 50% chance per augmentation.
            if random.random() < 0.5:
                image, label = augment(image, label)
        # Downstream tensor conversion expects float32.
        return image.astype(np.float32), label

    def random_brightness_contrast(self, image, label):
        brightness = random.uniform(0.7, 1.3)
        contrast = random.uniform(0.7, 1.3)
        # NOTE(review): brightness multiplies the image while contrast is also
        # applied around 0.5 — unconventional formula, preserved as-is.
        adjusted = brightness * image + contrast * (image - 0.5) + 0.5
        return np.clip(adjusted, 0, 1), label

    def random_gamma(self, image, label):
        return np.power(image, random.uniform(0.7, 1.3)), label

    def random_noise(self, image, label):
        noisy = image + np.random.normal(0, 0.05, image.shape)
        return np.clip(noisy, 0, 1), label

    def random_elastic_deform(self, image, label):
        alpha = random.uniform(10, 20)   # displacement magnitude
        sigma = random.uniform(3, 5)     # smoothing of the displacement field
        shape = image.shape
        # Draw both raw fields first so RNG consumption order is stable.
        raw_dx = np.random.rand(*shape) * 2 - 1
        raw_dy = np.random.rand(*shape) * 2 - 1
        dx = gaussian_filter(raw_dx, sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter(raw_dy, sigma, mode="constant", cval=0) * alpha
        grid_x, grid_y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(grid_y + dy, (-1, 1)), np.reshape(grid_x + dx, (-1, 1))
        warped_image = map_coordinates(image, indices, order=1).reshape(shape)
        # order=0 keeps labels categorical (no interpolation between classes).
        warped_label = map_coordinates(label, indices, order=0).reshape(shape)
        return warped_image, warped_label
def focal_loss(output, target, alpha=0.25, gamma=2):
    """Focal loss averaged over all pixels.

    Down-weights easy examples via the (1 - p_t)**gamma modulating factor
    applied to the per-pixel cross-entropy (Lin et al., 2017).
    """
    per_pixel_ce = nn.functional.cross_entropy(output, target, reduction='none')
    p_t = torch.exp(-per_pixel_ce)
    modulated = alpha * (1 - p_t) ** gamma * per_pixel_ce
    return modulated.mean()
def dice_loss(output, target, smooth=1e-5):
    """Multi-class soft Dice loss.

    Computes one Dice coefficient per class from the softmax probabilities
    against the integer label map, then returns 1 minus their mean. `smooth`
    avoids 0/0 for classes absent from both prediction and target.
    """
    probs = torch.softmax(output, dim=1)
    num_classes = probs.shape[1]
    per_class = []
    for c in range(num_classes):
        p = probs[:, c, :, :]
        t = (target == c).float()
        overlap = (p * t).sum()
        total = p.sum() + t.sum()
        per_class.append((2. * overlap + smooth) / (total + smooth))
    return 1 - sum(per_class) / num_classes
def combined_loss(output, target):
    """Equal-weight (0.5/0.5) combination of focal loss and Dice loss."""
    focal_term = focal_loss(output, target)
    dice_term = dice_loss(output, target)
    return 0.5 * focal_term + 0.5 * dice_term
def iou_score(output, target, smooth=1e-5):
    """Mean intersection-over-union across classes and batch.

    Bug fixed: the original applied bitwise ``&``/``|`` directly to the
    predicted and true class *indices*, which is only meaningful for binary
    {0, 1} masks and silently produces garbage for multi-class labels
    (e.g. 1 & 2 == 0, 1 | 2 == 3). IoU is now computed per class on boolean
    masks and averaged.

    Args:
        output: logits of shape (N, C, H, W).
        target: integer labels of shape (N, H, W).
        smooth: small constant avoiding 0/0 when a class is absent.

    Returns:
        0-dim tensor with the mean IoU.
    """
    preds = torch.argmax(output, dim=1)
    num_classes = output.shape[1]
    ious = []
    for c in range(num_classes):
        pred_mask = preds == c
        true_mask = target == c
        intersection = (pred_mask & true_mask).float().sum((1, 2))
        union = (pred_mask | true_mask).float().sum((1, 2))
        ious.append((intersection + smooth) / (union + smooth))
    return torch.stack(ious).mean()
def pixel_accuracy(output, target):
    """Fraction of pixels whose argmax prediction equals the label.

    Args:
        output: logits of shape (N, C, H, W).
        target: integer labels of shape (N, H, W).

    Returns:
        Python float in [0, 1].
    """
    matches = torch.eq(torch.argmax(output, dim=1), target).int()
    return float(matches.sum()) / float(matches.numel())
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass over `dataloader`.

    Returns:
        Tuple of (mean loss, mean IoU, mean pixel accuracy) over all batches.
    """
    model.train()
    running = {'loss': 0.0, 'iou': 0.0, 'acc': 0.0}
    progress = tqdm(dataloader, desc='Training')
    # `step` is the 1-based count of completed batches (same as pbar.n + 1
    # inside the loop body), used for the running averages shown on the bar.
    for step, (images, labels) in enumerate(progress, start=1):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running['loss'] += loss.item()
        running['iou'] += iou_score(outputs, labels)
        running['acc'] += pixel_accuracy(outputs, labels)
        progress.set_postfix({'Loss': running['loss'] / step,
                              'IoU': running['iou'] / step,
                              'Accuracy': running['acc'] / step})
    n_batches = len(dataloader)
    return running['loss'] / n_batches, running['iou'] / n_batches, running['acc'] / n_batches
def validate(model, dataloader, criterion, device):
    """Evaluate the model on `dataloader` without gradient tracking.

    Returns:
        Tuple of (mean loss, mean IoU, mean pixel accuracy) over all batches.
    """
    model.eval()
    running_loss = 0.0
    running_iou = 0.0
    running_acc = 0.0
    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')
        for step, (images, labels) in enumerate(progress, start=1):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            running_iou += iou_score(outputs, labels)
            running_acc += pixel_accuracy(outputs, labels)
            progress.set_postfix({'Loss': running_loss / step,
                                  'IoU': running_iou / step,
                                  'Accuracy': running_acc / step})
    n_batches = len(dataloader)
    return running_loss / n_batches, running_iou / n_batches, running_acc / n_batches
def main():
    """Train EnhancedUNet on the MoS2 segmentation dataset, checkpointing and
    visualizing predictions each epoch into ``enhanced_training_results/``."""
    # Hyperparameters
    num_classes = 4
    batch_size = 64
    num_epochs = 100
    learning_rate = 1e-4
    weight_decay = 1e-5
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Create datasets and data loaders.
    # NOTE(review): `transform` is built but NOT passed to the dataset below,
    # so no augmentation is actually applied — confirm this is intentional
    # (the commented-out line suggests it once was passed).
    transform = AugmentationTransform()
    # dataset = MoS2Dataset('MoS2_dataset_advanced_v2', transform=transform)
    dataset = MoS2Dataset('dataset_with_noise_npy')
    # 80/20 random train/validation split
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    # Create model (single-channel grayscale input)
    model = EnhancedUNet(n_channels=1, n_classes=num_classes).to(device)
    # Loss and optimizer; LR drops 10x when validation IoU plateaus (mode='max')
    criterion = combined_loss
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10, verbose=True)
    # Create directory for saving models and visualizations
    save_dir = 'enhanced_training_results'
    os.makedirs(save_dir, exist_ok=True)
    # Training loop
    best_val_iou = 0.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_iou, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_iou, val_accuracy = validate(model, val_loader, criterion, device)
        print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Accuracy: {train_accuracy:.4f}")
        print(f"Val - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Accuracy: {val_accuracy:.4f}")
        scheduler.step(val_iou)
        # Keep the weights achieving the best validation IoU so far
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            print(f"New best model saved with IoU: {best_val_iou:.4f}")
        # Save a full resumable checkpoint every epoch (one file per epoch —
        # NOTE(review): disk usage grows linearly; consider pruning old ones)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
        }, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))
        # NOTE(review): despite the original "every 5 epochs" comment this runs
        # every epoch — confirm the intended cadence.
        visualize_prediction(model, val_loader, device, epoch, save_dir)
    print("Training completed!")
def visualize_prediction(model, val_loader, device, epoch, save_dir):
    """Save a figure of (input, true label, prediction) for up to two samples
    from the first validation batch to ``save_dir/prediction_epoch_<epoch>.png``.

    Bug fixed: the original unconditionally indexed ``images[1]`` and crashed
    whenever the first validation batch held a single sample; now it plots
    ``min(2, batch_size)`` rows.
    """
    model.eval()
    images, labels = next(iter(val_loader))
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(images)
    images = images.cpu().numpy()
    labels = labels.cpu().numpy()
    predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    n_rows = min(2, images.shape[0])
    # squeeze=False keeps axs two-dimensional even when n_rows == 1.
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for row in range(n_rows):
        axs[row, 0].imshow(images[row, 0], cmap='gray')
        axs[row, 1].imshow(labels[row], cmap='viridis')
        axs[row, 2].imshow(predictions[row], cmap='viridis')
    axs[0, 0].set_title('Input Image')
    axs[0, 1].set_title('True Label')
    axs[0, 2].set_title('Prediction')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'prediction_epoch_{epoch}.png'))
    plt.close()
# Run training only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()