File size: 6,460 Bytes
d72a22b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | """
DeepLabV3+ Training Script for Coffee Leaf Rust
===============================================
This script trains a DeepLabV3+ model (ResNet50 encoder) for binary semantic segmentation:
- Class 0: Background / Healthy
- Class 1: Lesion (Rust)
Requirements:
- torch
- albumentations
- segmentation-models-pytorch (smp)
- opencv-python
"""
import os
import cv2
import torch
import numpy as np
import albumentations as A
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# ================= Configuration =================
# Dataset Paths
TRAIN_IMG_DIR = "./data/dataset/images/train"
TRAIN_MASK_DIR = "./data/dataset/masks/train"
VAL_IMG_DIR = "./data/dataset/images/valid"
VAL_MASK_DIR = "./data/dataset/masks/valid"
# Output Paths
SAVE_LAST = "./checkpoints/deeplab_binary_last.pth"
SAVE_BEST = "./checkpoints/deeplab_binary_best.pth"
# Hyperparameters
BATCH_SIZE = 8
EPOCHS = 25
IMG_SIZE = 512
LEARNING_RATE = 1e-4
# =================================================
class LeafDataset(Dataset):
def __init__(self, img_dir, mask_dir, transform=None):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.images = sorted(os.listdir(img_dir))
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_name = self.images[idx]
img_path = os.path.join(self.img_dir, img_name)
# Assumption: mask has same basename + _mask.png
# Adjust this logic if your naming convention differs
base = os.path.splitext(img_name)[0]
mask_name = base + "_mask.png"
mask_path = os.path.join(self.mask_dir, mask_name)
img = cv2.imread(img_path)
if img is None:
raise RuntimeError(f"Cannot read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path, 0) # Read as grayscale
if mask is None:
# Fallback: try reading with original name if _mask suffix not used
mask_name = img_name
mask_path = os.path.join(self.mask_dir, mask_name)
mask = cv2.imread(mask_path, 0)
if mask is None:
raise RuntimeError(f"Cannot read mask for: {img_name}")
# Binary mask: lesion = 1, background = 0
mask = (mask > 0).astype(np.float32)
if self.transform:
augmented = self.transform(image=img, mask=mask)
img = augmented["image"]
mask = augmented["mask"]
# To Tensor: (H, W, C) -> (C, H, W)
img = torch.tensor(img).permute(2, 0, 1).float()
mask = torch.tensor(mask).unsqueeze(0).float()
return img, mask
def get_transforms():
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
train_tf = A.Compose([
A.Resize(IMG_SIZE, IMG_SIZE),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Normalize(mean=mean, std=std),
], is_check_shapes=False)
val_tf = A.Compose([
A.Resize(IMG_SIZE, IMG_SIZE),
A.Normalize(mean=mean, std=std),
], is_check_shapes=False)
return train_tf, val_tf
def train_epoch(model, loader, criterion, optimizer, device, epoch):
model.train()
running = 0.0
pbar = tqdm(loader, desc=f"Train Epoch {epoch}")
for imgs, masks in pbar:
imgs, masks = imgs.to(device), masks.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, masks)
loss.backward()
optimizer.step()
running += loss.item()
pbar.set_postfix(loss=loss.item())
return running / len(loader)
def val_epoch(model, loader, criterion, device, epoch):
model.eval()
running = 0.0
with torch.no_grad():
pbar = tqdm(loader, desc=f"Val Epoch {epoch}")
for imgs, masks in pbar:
imgs, masks = imgs.to(device), masks.to(device)
logits = model(imgs)
loss = criterion(logits, masks)
running += loss.item()
pbar.set_postfix(loss=loss.item())
return running / len(loader)
def main():
print("\n=== Binary DeepLabV3+ Training ===")
# Check Directories
if not os.path.exists(TRAIN_IMG_DIR):
print(f"Error: Training directory not found: {TRAIN_IMG_DIR}")
return
os.makedirs(os.path.dirname(SAVE_BEST), exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
# Transforms & Datasets
train_tf, val_tf = get_transforms()
train_ds = LeafDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_tf)
val_ds = LeafDataset(VAL_IMG_DIR, VAL_MASK_DIR, val_tf)
print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
# Model
model = smp.DeepLabV3Plus(
encoder_name="resnet50",
encoder_weights="imagenet",
in_channels=3,
classes=1
).to(device)
# Loss & Optimizer
# Combine Dice Loss and BCE for robust segmentation training
dice_loss = smp.losses.DiceLoss(mode="binary")
bce_loss = torch.nn.BCEWithLogitsLoss()
def criterion(logits, targets):
return dice_loss(logits, targets) + bce_loss(logits, targets)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Training Loop
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
val_loss = val_epoch(model, val_loader, criterion, device, epoch)
print(f"[Epoch {epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), SAVE_BEST)
print(f"✔ Best model saved to: {SAVE_BEST}")
torch.save(model.state_dict(), SAVE_LAST)
print("Training finished.")
if __name__ == "__main__":
main()
|