# torchmodel.py -- uploaded by Tan Gezerman (commit "Upload torchmodel.py",
# 8166961, verified); file size 13.7 kB.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import cv2
import os
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each with batch norm and ReLU.

    Both convolutions use ``kernel_size=3`` with ``padding=1``; the first maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages = []
        channels = in_channels
        # Layers are appended in conv -> norm -> activation order, twice; the
        # resulting Sequential indices 0..5 are the state_dict key prefixes.
        for _ in range(2):
            stages.extend((
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
            channels = out_channels
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages and return the result."""
        return self.double_conv(x)
class UpConv(nn.Module):
    """Transposed-convolution upsampling stage followed by batch norm and ReLU.

    The transposed convolution uses ``kernel_size=2`` and ``stride=2``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of the upsampled feature map.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2)
        normalize = nn.BatchNorm2d(out_channels)
        activate = nn.ReLU(inplace=True)
        self.up_conv = nn.Sequential(upsample, normalize, activate)

    def forward(self, x):
        """Upsample ``x`` and apply normalization and activation."""
        return self.up_conv(x)
class TLUNet(nn.Module):
    """U-Net style network with a frozen, ImageNet-pretrained VGG16 encoder.

    The encoder is the VGG16 ``features`` stack split into five blocks; the
    decoder upsamples with ``UpConv`` and fuses each result with the matching
    encoder output through a ``DoubleConv``. The head is a 3x3 convolution to
    one channel followed by a sigmoid.

    Args:
        input_shape: Accepted for interface compatibility; not read by this
            class.
        patch_size: Stored on the instance and used to fill ``self.dims``
            (``patch_size`` divided by 2, 4, 8, 16 and 32). It does not
            constrain the size of tensors passed to ``forward``.
    """

    def __init__(self, input_shape: Tuple[int, int, int], patch_size: int):
        super().__init__()
        self.patch_size = patch_size

        # Pretrained VGG16, weights requested explicitly; every parameter is
        # frozen so only the decoder is trainable.
        base_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        for weight in base_vgg.parameters():
            weight.requires_grad = False

        # Bookkeeping of the spatial size after each encoder block
        # (e.g. 128, 64, 32, 16, 8 for patch_size=256).
        self.dims = {'input': patch_size}
        for level in range(1, 6):
            self.dims[f'block{level}'] = patch_size // (2 ** level)

        # Encoder: consecutive slices of the VGG16 feature layers. Each slice
        # is wrapped in a fresh Sequential so its children are indexed from 0.
        vgg_layers = list(base_vgg.features.children())
        boundaries = (0, 5, 10, 17, 24, 31)
        for level, (start, stop) in enumerate(
                zip(boundaries, boundaries[1:]), start=1):
            setattr(self, f'block{level}', nn.Sequential(*vgg_layers[start:stop]))

        # Decoder: each upsampling stage is followed by a DoubleConv that
        # consumes the upsampled map concatenated with an encoder skip.
        self.upconv1 = UpConv(512, 512)
        self.conv6 = DoubleConv(1024, 512)
        self.upconv2 = UpConv(512, 256)
        self.conv7 = DoubleConv(512, 256)
        self.upconv3 = UpConv(256, 128)
        self.conv8 = DoubleConv(256, 128)
        self.upconv4 = UpConv(128, 64)
        self.conv9 = DoubleConv(128, 64)

        # Head: one more upsampling, then a single-channel sigmoid output.
        self.upconv_final = UpConv(64, 32)
        self.final_conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the output.

        Returns:
            The single-channel, sigmoid-activated output of ``final_conv``.
        """
        # Encoder pass, remembering every block output for the skips.
        skips = []
        features = x
        for encoder in (self.block1, self.block2, self.block3,
                        self.block4, self.block5):
            features = encoder(features)
            skips.append(features)

        # The deepest encoder output seeds the decoder; the remaining outputs
        # are consumed deepest-first as skip connections.
        decoded = skips.pop()
        decoder_stages = (
            (self.upconv1, self.conv6),
            (self.upconv2, self.conv7),
            (self.upconv3, self.conv8),
            (self.upconv4, self.conv9),
        )
        for upsample, fuse in decoder_stages:
            merged = torch.cat([upsample(decoded), skips.pop()], dim=1)
            decoded = fuse(merged)

        return self.final_conv(self.upconv_final(decoded))
class ImageMaskDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories with OpenCV.

    Files in each directory are listed and sorted by path; the image and the
    mask at the same sorted position form one sample.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        target_size: Stored on the instance as ``target_size``. Samples are
            NOT resized to it by this class.

    Raises:
        ValueError: If the two directories hold a different number of files,
            since positional pairing would then be wrong or run out of masks.
    """

    def __init__(self, image_dir: str, mask_dir: str, target_size: Tuple[int, int]):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.target_size = target_size
        self.image_paths = sorted(
            os.path.join(image_dir, fname) for fname in os.listdir(image_dir))
        self.mask_paths = sorted(
            os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir))
        # Fail early instead of mispairing samples or hitting an IndexError
        # in the middle of an epoch.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in '{image_dir}' but "
                f"{len(self.mask_paths)} masks in '{mask_dir}'; counts must match")

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` float tensors.

        The image is divided by 255.0 and permuted from HWC to CHW; the mask
        is read as grayscale, divided by 255.0 and given a leading channel
        dimension.

        Raises:
            ValueError: If OpenCV cannot read the image or the mask file.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # NOTE(review): OpenCV loads colour images in BGR order, while the
        # ImageNet VGG16 weights are usually fed RGB -- confirm whether a
        # channel swap is intended here.
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"Could not read image file: {image_path}")
        image = image / 255.0
        image = torch.FloatTensor(image).permute(2, 0, 1)  # HWC -> CHW

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Could not read mask file: {mask_path}")
        mask = mask / 255.0
        mask = torch.FloatTensor(mask).unsqueeze(0)  # add channel dimension

        return image, mask
class DiceLoss(nn.Module):
    """Dice loss computed over all elements of the prediction and target.

    The score is ``(2 * sum(p * t) + smooth) / (sum(p**gamma) + sum(t**gamma)
    + smooth)`` and the loss is ``1 - score``.

    Args:
        smooth: Small constant added to numerator and denominator so the
            ratio stays defined when both sums are zero.
        gamma: Exponent applied to predictions and targets in the denominator.
    """

    def __init__(self, smooth=1e-6, gamma=2):
        super().__init__()
        self.smooth = smooth
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        """Return ``1 - dice_score`` for ``y_pred`` against ``y_true``.

        Both tensors are flattened, so they must hold the same number of
        elements.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed tensors, are flattened instead of raising RuntimeError.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)
        intersection = torch.sum(y_pred * y_true)
        denominator = (torch.sum(y_pred.pow(self.gamma))
                       + torch.sum(y_true.pow(self.gamma)))
        dice_score = (2.0 * intersection + self.smooth) / \
            (denominator + self.smooth)
        return 1 - dice_score
def train_model(model: nn.Module,
                train_loader: DataLoader,
                val_loader: DataLoader,
                num_epochs: int,
                device: torch.device) -> dict:
    """Train ``model`` with binary cross-entropy and Adam, validating each epoch.

    Args:
        model: Network whose outputs are probabilities (``nn.BCELoss`` is
            applied directly to them).
        train_loader: Iterable of ``(images, masks)`` training batches that
            supports ``len()``.
        val_loader: Iterable of ``(images, masks)`` validation batches that
            supports ``len()``. It may be empty.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with ``'train_loss'`` and ``'val_loss'`` (one mean per epoch)
        and ``'batch_losses'`` (every training batch loss, in order). A mean
        over zero batches is recorded as NaN instead of raising.
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    history = {
        'train_loss': [],
        'val_loss': [],
        'batch_losses': []  # Track individual batch losses
    }
    total_batches = len(train_loader)
    total_val_batches = len(val_loader)
    # Progress is reported roughly every 10% of an epoch (at least every batch).
    report_every = max(1, total_batches // 10)
    print(f"Training on {total_batches} batches per epoch")

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        batch_losses = []
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 60)

        for batch_idx, (images, masks) in enumerate(train_loader):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Record batch loss
            batch_loss = loss.item()
            batch_losses.append(batch_loss)
            train_loss += batch_loss

            if (batch_idx + 1) % report_every == 0:
                current_loss = train_loss / (batch_idx + 1)
                progress = (batch_idx + 1) / total_batches * 100
                print(f"Batch {batch_idx + 1}/{total_batches} [{progress:.1f}%] - "
                      f"Current Loss: {current_loss:.4f}")

        # Validation phase: no gradients, model in eval mode.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, masks).item()

        # Average losses; an empty loader yields NaN rather than a
        # ZeroDivisionError that would abort the run.
        avg_train_loss = (train_loss / total_batches
                          if total_batches else float("nan"))
        avg_val_loss = (val_loss / total_val_batches
                        if total_val_batches else float("nan"))
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['batch_losses'].extend(batch_losses)

        # Print epoch summary
        print("\nEpoch Summary:")
        print(f"Average Train Loss: {avg_train_loss:.4f}")
        print(f"Average Val Loss: {avg_val_loss:.4f}")
        if batch_losses:  # min()/max() are undefined for an empty epoch
            print(f"Best batch loss: {min(batch_losses):.4f}")
            print(f"Worst batch loss: {max(batch_losses):.4f}")
        print("-" * 60)

    return history
def test_model(model: nn.Module, test_loader: DataLoader, device: torch.device) -> Tuple[float, float]:
model.eval()
total_loss = 0.0
criterion = nn.BCELoss()
with torch.no_grad():
for images, masks in test_loader:
images, masks = images.to(device), masks.to(device)
outputs = model(images)
loss = criterion(outputs, masks)
total_loss += loss.item()
avg_loss = total_loss / len(test_loader)
print(f"Test Loss: {avg_loss:.4f}")
return avg_loss
def get_hw():
    """Return the name of the torch backend to use.

    Preference order: ``"cuda"`` when CUDA is available, otherwise ``"mps"``
    when the MPS backend is both available and built, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return "mps"
    return "cpu"
def plot_training_history(history, save_path=None, show_plot=True):
    """
    Plot training metrics from model history.

    The figure has two stacked axes: per-epoch training/validation loss on
    top, and per-batch training loss (with an optional moving average) below.
    A text box with summary statistics is placed in the lower-right corner.

    Args:
        history (dict): Dictionary containing 'train_loss', 'val_loss', and 'batch_losses'.
            Any of the lists may be empty; the corresponding statistics are
            then shown as "n/a" instead of raising.
        save_path (str, optional): Path to save the plot. Defaults to None.
        show_plot (bool, optional): Whether to display the plot. When False the
            figure is closed instead. Defaults to True.
    """
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    batch_losses = history['batch_losses']

    def _final(values):
        # Last recorded value, or a placeholder when nothing was recorded.
        return f"{values[-1]:.4f}" if values else "n/a"

    def _best(values):
        # Smallest recorded value, or a placeholder when nothing was recorded.
        return f"{min(values):.4f}" if values else "n/a"

    # Create figure with subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    fig.suptitle('Training History', fontsize=16)

    # Top subplot: Training and validation loss per epoch
    epochs = range(1, len(train_losses) + 1)
    ax1.plot(epochs, train_losses, 'b-', label='Training Loss')
    ax1.plot(epochs, val_losses, 'r-', label='Validation Loss')
    ax1.set_title('Epoch-wise Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)
    ax1.legend()

    # Bottom subplot: Batch losses
    batches = range(1, len(batch_losses) + 1)
    ax2.plot(batches, batch_losses, 'g-', alpha=0.5, label='Batch Loss')

    # Moving average over batch losses; the window adapts to the amount of
    # data (a tenth of it, capped at 100) and is skipped when too small.
    window_size = min(100, len(batch_losses) // 10)
    if window_size > 1:
        moving_avg = np.convolve(batch_losses,
                                 np.ones(window_size) / window_size,
                                 mode='valid')
        # 'valid' mode yields len - window + 1 points; align each with the
        # last batch of its window.
        ax2.plot(range(window_size, len(batches) + 1),
                 moving_avg,
                 'r-',
                 label=f'Moving Average (window={window_size})')
    ax2.set_title('Batch-wise Training Loss')
    ax2.set_xlabel('Batch')
    ax2.set_ylabel('Loss')
    ax2.grid(True)
    ax2.legend()

    # Add summary statistics as text
    stats_text = (
        f"Final Training Loss: {_final(train_losses)}\n"
        f"Final Validation Loss: {_final(val_losses)}\n"
        f"Best Training Loss: {_best(train_losses)}\n"
        f"Best Validation Loss: {_best(val_losses)}\n"
        f"Best Batch Loss: {_best(batch_losses)}"
    )
    fig.text(0.95, 0.05, stats_text, fontsize=10, ha='right',
             bbox=dict(facecolor='white', alpha=0.8))

    # Adjust layout to prevent overlap
    plt.tight_layout()

    # Save plot if path is provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    # Show plot if requested
    if show_plot:
        plt.show()
    else:
        plt.close()
# Usage example: train TLUNet on pre-cut patches, evaluate it, save the
# weights and plot the loss curves.
if __name__ == "__main__":
    hw = get_hw()

    # Parameters
    patch_size = 256
    version = 0
    batch_size = 48
    num_epochs = 8
    device = torch.device(hw)

    # Paths: Data/Data_<patch_size>_<version>/{train,val,test}/{images,masks}
    path = f"Data/Data_{patch_size}_{version}"
    train_dir_images = os.path.join(path, "train", "images")
    train_dir_masks = os.path.join(path, "train", "masks")
    val_dir_images = os.path.join(path, "val", "images")
    val_dir_masks = os.path.join(path, "val", "masks")
    test_dir_images = os.path.join(path, "test", "images")
    test_dir_masks = os.path.join(path, "test", "masks")

    # Create datasets and dataloaders
    target_size = (patch_size, patch_size)
    input_shape = (patch_size, patch_size, 3)
    train_dataset = ImageMaskDataset(
        train_dir_images, train_dir_masks, target_size)
    val_dataset = ImageMaskDataset(val_dir_images, val_dir_masks, target_size)
    test_dataset = ImageMaskDataset(
        test_dir_images, test_dir_masks, target_size)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Create and train model
    model = TLUNet(input_shape, patch_size).to(device)
    history = train_model(model, train_loader, val_loader, num_epochs, device)

    # Test model
    test_loss = test_model(model, test_loader, device)

    # Save model. The output directory is created first so a missing
    # "Models" folder cannot discard the result of the whole training run.
    model_dir = "Models"
    os.makedirs(model_dir, exist_ok=True)
    model_name = f"model__{patch_size}_{batch_size}_{version}.pth"
    torch.save(model.state_dict(), os.path.join(model_dir, model_name))

    plot_training_history(
        history,
        save_path=f'training_history_{patch_size}_{batch_size}_{version}.png'
    )