# Segmentation / train5.py
# Uploaded by riha55 ("Upload 7 files", commit 6f774ac, verified).
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import pytorch_lightning as pl
# from losses import mIoULoss
# from torchvision import models
# class ASSP(nn.Module):
# def __init__(self, in_channels, out_channels=256, final_out_channels=4):
# super(ASSP, self).__init__()
# self.relu = nn.ReLU(inplace=True)
# # 1x1 convolution
# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=False)
# self.bn1 = nn.BatchNorm2d(out_channels)
# # 3x3 convolutions with different dilation rates
# self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=6, dilation=6, bias=False)
# self.bn2 = nn.BatchNorm2d(out_channels)
# self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=False)
# self.bn3 = nn.BatchNorm2d(out_channels)
# self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=18, dilation=18, bias=False)
# self.bn4 = nn.BatchNorm2d(out_channels)
# # 1x1 convolution after global average pooling
# self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
# self.bn5 = nn.BatchNorm2d(out_channels)
# # Final 1x1 convolution to combine features
# self.convf = nn.Conv2d(out_channels * 5, final_out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
# self.bnf = nn.BatchNorm2d(final_out_channels)
# # Global average pooling
# self.adapool = nn.AdaptiveAvgPool2d(1)
# def forward(self, x):
# # 1x1 convolution
# x1 = self.conv1(x)
# x1 = self.bn1(x1)
# x1 = self.relu(x1)
# # 3x3 convolution with dilation 6
# x2 = self.conv2(x)
# x2 = self.bn2(x2)
# x2 = self.relu(x2)
# # 3x3 convolution with dilation 12
# x3 = self.conv3(x)
# x3 = self.bn3(x3)
# x3 = self.relu(x3)
# # 3x3 convolution with dilation 18
# x4 = self.conv4(x)
# x4 = self.bn4(x4)
# x4 = self.relu(x4)
# # Global average pooling, 1x1 convolution, and upsample
# x5 = self.adapool(x)
# x5 = self.conv5(x5)
# x5 = self.bn5(x5)
# x5 = self.relu(x5)
# x5 = F.interpolate(x5, size=x4.shape[-2:], mode='bilinear')
# # Concatenate all feature maps
# x = torch.cat((x1, x2, x3, x4, x5), dim=1)
# # Final 1x1 convolution
# x = self.convf(x)
# x = self.bnf(x)
# x = self.relu(x)
# return x
# class ResNet_50(nn.Module):
# def __init__(self, in_channels=3): # Change default to 3 channels for RGB images
# super(ResNet_50, self).__init__()
# # Load the pre-trained ResNet-50 model
# self.resnet_50 = models.resnet50(weights='DEFAULT')
# # Modify the first convolutional layer to accept 3-channel input
# self.resnet_50.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
# # Use the layers up to the final layer before the fully connected layer
# self.resnet_50 = nn.Sequential(*list(self.resnet_50.children())[:-2])
# self.relu = nn.ReLU(inplace=True)
# def forward(self, x):
# x = self.resnet_50(x)
# return x
# class deeplabv3_encoder_decoder(pl.LightningModule):
# def __init__(self, input_channels=3, output_channels=4): # Use 4 channels for output
# super(deeplabv3_encoder_decoder, self).__init__()
# self.resnet = ResNet_50(in_channels=input_channels)
# self.aspp = ASSP(in_channels=2048, final_out_channels=4)
# self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
# self.criterion = mIoULoss(n_classes=4) # Set number of classes to 4
# def forward(self, x):
# _, _, h, w = x.shape
# x = self.resnet(x) # Output should be [batch_size, 2048, H/32, W/32]
# x = self.aspp(x)
# x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) # Upsample
# x = self.conv(x) # Apply final convolution
# return x
# def training_step(self, batch, batch_idx):
# images, masks = batch
# logits = self(images)
# loss = self.criterion(logits, masks)
# iou = calculate_iou(logits, masks)
# self.log('train_loss', loss)
# self.log('train_iou', iou)
# print(f'Training Loss: {loss}, IoU: {iou}')
# return loss
# def validation_step(self, batch, batch_idx):
# images, masks = batch
# logits = self(images)
# loss = self.criterion(logits, masks)
# iou = calculate_iou(logits, masks)
# self.log('val_loss', loss)
# self.log('val_iou', iou)
# print(f'Validation Loss: {loss}, IoU: {iou}')
# return loss
# def on_training_epoch_end(self, outputs):
# avg_iou = torch.stack([x['train_iou'] for x in outputs]).mean()
# self.log('avg_train_iou', avg_iou)
# def on_validation_epoch_end(self, outputs):
# avg_iou = torch.stack([x['val_iou'] for x in outputs]).mean()
# self.log('avg_val_iou', avg_iou)
# def configure_optimizers(self):
# optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# return optimizer
# def calculate_iou(logits, masks):
# # Calculate predictions from logits
# preds = torch.argmax(logits, dim=1)
# # Calculate intersection and union
# intersection = torch.sum(preds * masks)
# union = torch.sum((preds.bool() | masks.bool()).int())
# # Avoid division by zero
# iou = intersection / union if union != 0 else torch.tensor(0.0)
# return iou
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from losses import DiceLoss
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
class ASSP(nn.Module):
    """Atrous Spatial Pyramid Pooling head in the DeepLabv3 style.

    Five parallel branches read the same input feature map:

    * a 1x1 convolution,
    * three 3x3 atrous convolutions with dilation rates 6, 12 and 18,
    * global average pooling + 1x1 convolution, resized back to the
      spatial size of the other branches.

    Every branch is Conv -> BatchNorm -> ReLU and yields ``out_channels``
    maps. The five results are concatenated along the channel axis and fused
    by a final 1x1 Conv -> BatchNorm -> ReLU into ``final_out_channels`` maps.
    Height and width of the input are preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each individual branch.
        final_out_channels: channels of the fused output.
    """

    def __init__(self, in_channels, out_channels=256, final_out_channels=4):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # Branch 1: point-wise projection.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                               padding=0, dilation=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Branches 2-4: 3x3 atrous convolutions. With stride 1 and
        # padding == dilation the output keeps the input's height/width.
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=1, padding=6, dilation=6, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=1, padding=12, dilation=12, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=1, padding=18, dilation=18, bias=False)
        self.bn4 = nn.BatchNorm2d(out_channels)

        # Branch 5: image-level context (applied after global pooling).
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                               stride=1, padding=0, dilation=1, bias=False)
        self.bn5 = nn.BatchNorm2d(out_channels)

        # Fusion of the five concatenated branches.
        self.convf = nn.Conv2d(out_channels * 5, final_out_channels, kernel_size=1,
                               stride=1, padding=0, dilation=1, bias=False)
        self.bnf = nn.BatchNorm2d(final_out_channels)

        self.adapool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Run all five branches on ``x`` and return the fused, ReLU'd map."""
        conv_bn_pairs = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        branches = [self.relu(bn(conv(x))) for conv, bn in conv_bn_pairs]

        # Global context: pool to 1x1, project, then resize to branch size.
        context = self.relu(self.bn5(self.conv5(self.adapool(x))))
        target_hw = branches[-1].shape[-2:]
        branches.append(F.interpolate(context, size=target_hw, mode='bilinear'))

        fused = self.convf(torch.cat(branches, dim=1))
        return self.relu(self.bnf(fused))
class ResNet_50(nn.Module):
    """ResNet-50 feature extractor with the global pool and fc layer removed.

    The output is the last residual stage's feature map (2048 channels for the
    stock torchvision ResNet-50). The stem stride plus the four stages
    downsample by a total factor of 32.

    Args:
        in_channels: channels of the input image. For 3 the pretrained stem
            convolution is kept as-is. For any other value it is replaced by a
            freshly initialised 7x7/stride-2 conv with that many input channels.
        pretrained: load the ImageNet weights that the deprecated
            ``resnet50(pretrained=True)`` used to load (IMAGENET1K_V1). Pass
            ``False`` to build a randomly initialised backbone (no download).
    """

    def __init__(self, in_channels=3, pretrained=True):
        super().__init__()
        # ``pretrained=True`` is deprecated in torchvision; IMAGENET1K_V1 is the
        # exact checkpoint that flag selected, so behaviour is unchanged.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet50(weights=weights)
        # Only swap the stem when the channel count actually differs: replacing
        # it unconditionally would discard the pretrained first-layer weights
        # (and leave them mismatched with the pretrained bn1 statistics).
        if in_channels != 3:
            backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the trailing avgpool and fc children; keep everything before them.
        self.resnet_50 = nn.Sequential(*list(backbone.children())[:-2])
        # Not used in forward; kept so existing attribute access keeps working.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the backbone feature map for ``x``."""
        return self.resnet_50(x)
# class deeplabv3_encoder_decoder(pl.LightningModule):
# def __init__(self, input_channels=3, output_channels=4): # Use 4 channels for output
# super(deeplabv3_encoder_decoder, self).__init__()
# self.resnet = ResNet_50(in_channels=input_channels)
# self.aspp = ASSP(in_channels=2048, final_out_channels=4)
# self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
# self.criterion = mIoULoss(n_classes=4) # Set number of classes to 4
# def forward(self, x):
# _, _, h, w = x.shape
# x = self.resnet(x) # Output should be [batch_size, 2048, H/32, W/32]
# x = self.aspp(x)
# x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) # Upsample
# x = self.conv(x) # Apply final convolution
# return x
# def training_step(self, batch, batch_idx):
# images, masks = batch
# logits = self(images)
# loss = self.criterion(logits, masks)
# iou = calculate_iou(logits, masks)
# self.log('train_loss', loss)
# self.log('train_iou', iou)
# print(f'Training Loss: {loss}, IoU: {iou}')
# return loss
# def validation_step(self, batch, batch_idx):
# images, masks = batch
# logits = self(images)
# loss = self.criterion(logits, masks)
# iou = calculate_iou(logits, masks)
# self.log('val_loss', loss)
# self.log('val_iou', iou)
# print(f'Validation Loss: {loss}, IoU: {iou}')
# return loss
# def on_training_epoch_end(self, outputs):
# avg_iou = torch.stack([x['train_iou'] for x in outputs]).mean()
# self.log('avg_train_iou', avg_iou)
# def on_validation_epoch_end(self, outputs):
# avg_iou = torch.stack([x['val_iou'] for x in outputs]).mean()
# self.log('avg_val_iou', avg_iou)
# def configure_optimizers(self):
# optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# return optimizer
class deeplabv3_encoder_decoder(pl.LightningModule):
    """DeepLabv3-style segmentation model: ResNet-50 encoder + ASPP head.

    The ASPP head fuses features down to 4 channels. A final 1x1 conv maps
    those to ``output_channels`` logits at the input's spatial resolution
    (bilinear upsampling). The loss is ``DiceLoss`` from the project's
    ``losses`` module. The metric is ``compute_iou``.

    Args:
        input_channels: channels of the input images.
        output_channels: number of output logit channels.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
    """

    def __init__(self, input_channels=3, output_channels=4, lr=1e-3):
        super().__init__()
        self.resnet = ResNet_50(in_channels=input_channels)
        # The ASPP head always emits 4 channels; self.conv adapts them to
        # ``output_channels``.
        self.aspp = ASSP(in_channels=2048, final_out_channels=4)
        self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
        self.criterion = DiceLoss()
        self.lr = lr
        # Per-epoch (batch IoU, batch size) records. They give a true,
        # batch-size-weighted epoch mean instead of the last logged value.
        self._epoch_iou = {"train": [], "val": []}

    def forward(self, x):
        """Return per-pixel logits with the same height/width as ``x``."""
        _, _, h, w = x.shape
        x = self.resnet(x)
        x = self.aspp(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)  # back to input size
        x = self.conv(x)
        return x

    def _record_iou(self, stage, iou, batch_size):
        """Store one batch's IoU for the running epoch average of ``stage``."""
        self._epoch_iou[stage].append((iou.detach(), batch_size))

    def _pop_epoch_iou(self, stage):
        """Return the batch-size-weighted mean IoU for ``stage`` and reset it.

        Returns ``None`` if no batch was recorded (e.g. an empty dataloader).
        """
        records = self._epoch_iou[stage]
        self._epoch_iou[stage] = []
        if not records:
            return None
        total = sum(iou * n for iou, n in records)
        count = sum(n for _, n in records)
        return total / count

    def training_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.criterion(logits, masks)
        iou = compute_iou(logits, masks)
        self.log('train_loss', loss)
        self.log('train_iou', iou)
        self._record_iou("train", iou, images.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.criterion(logits, masks)
        iou = compute_iou(logits, masks)
        self.log('val_loss', loss)
        self.log('val_iou', iou)
        self._record_iou("val", iou, images.shape[0])
        return loss

    def on_train_epoch_end(self):
        # callback_metrics['train_iou'] only holds the *last* step's value
        # (training-step logs default to on_step=True, on_epoch=False), so the
        # epoch average is computed from the recorded batches instead.
        avg_iou = self._pop_epoch_iou("train")
        train_loss = self.trainer.logged_metrics.get('train_loss')
        if avg_iou is not None:
            self.log('avg_train_iou', avg_iou)
        print("avg train iou", avg_iou)
        print("loss", train_loss)

    def on_validation_epoch_end(self):
        avg_iou = self._pop_epoch_iou("val")
        val_loss = self.trainer.logged_metrics.get('val_loss')
        if avg_iou is not None:
            self.log('avg_val_iou', avg_iou)
        print("avg val iou", avg_iou)
        print("val loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
# def calculate_iou(logits, masks):
# # Calculate predictions from logits
# preds = torch.argmax(logits, dim=1)
# # Calculate intersection and union
# intersection = torch.sum(preds * masks)
# union = torch.sum((preds.bool() | masks.bool()).int())
# # Avoid division by zero
# iou = intersection / union if union != 0 else torch.tensor(0.0)
# return iou
def compute_iou(preds, labels, threshold=0.5, epsilon=torch.finfo(torch.float).eps):
    """Mean per-class IoU of thresholded sigmoid predictions.

    ``preds`` are raw logits. They are passed through a sigmoid and binarised
    with ``> threshold``. Both tensors are indexed as ``[:, class, :, :]``
    (rank 4, class axis 1). ``labels`` presumably holds 0/1 per-class masks;
    TODO confirm against the dataset.

    For each class the IoU is computed per sample over the last two axes as
    ``(intersection + epsilon) / (union + epsilon)``. A class that is empty in
    both prediction and label therefore scores 1. The per-sample values are
    averaged over the batch, and the per-class means are averaged over
    classes.

    Returns:
        A 0-d tensor with the mean IoU.
    """
    binary = (torch.sigmoid(preds) > threshold).float()
    num_classes = binary.shape[1]

    def _class_iou(c):
        # Batch-averaged IoU for channel ``c``.
        pred_c = binary[:, c, :, :]
        true_c = labels[:, c, :, :]
        inter = (pred_c * true_c).sum((1, 2))
        union = (pred_c + true_c).sum((1, 2)) - inter
        return ((inter + epsilon) / (union + epsilon)).mean()

    return sum(_class_iou(c) for c in range(num_classes)) / num_classes