# Uploaded by MaryPazRB ("Upload 23 files", d72a22b, verified)
"""
DeepLabV3+ Training Script for Coffee Leaf Rust
===============================================
This script trains a DeepLabV3+ model (ResNet50 encoder) for binary semantic segmentation:
- Class 0: Background / Healthy
- Class 1: Lesion (Rust)
Requirements:
- torch
- numpy
- albumentations
- segmentation-models-pytorch (smp)
- opencv-python
- tqdm
"""
import os
import cv2
import torch
import numpy as np
import albumentations as A
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# ================= Configuration =================
# All tunables for this script live here; edit them before running.

# Dataset Paths
# Image and mask directories for the training and validation splits.
# Masks are looked up by file name relative to each image (see LeafDataset).
TRAIN_IMG_DIR = "./data/dataset/images/train"
TRAIN_MASK_DIR = "./data/dataset/masks/train"
VAL_IMG_DIR = "./data/dataset/images/valid"
VAL_MASK_DIR = "./data/dataset/masks/valid"
# Output Paths
# SAVE_LAST is overwritten after every epoch; SAVE_BEST only when the
# validation loss improves.
SAVE_LAST = "./checkpoints/deeplab_binary_last.pth"
SAVE_BEST = "./checkpoints/deeplab_binary_best.pth"
# Hyperparameters
BATCH_SIZE = 8        # samples per batch for both data loaders
EPOCHS = 25           # number of full passes over the training set
IMG_SIZE = 512        # images and masks are resized to IMG_SIZE x IMG_SIZE
LEARNING_RATE = 1e-4  # learning rate passed to the AdamW optimizer
# =================================================
class LeafDataset(Dataset):
    """Paired image / binary-mask dataset for leaf lesion segmentation.

    Every image file found in ``img_dir`` is one sample. Its mask is looked
    up in ``mask_dir`` first as ``<basename>_mask.png`` and, failing that,
    under the image's own file name.
    """

    # Lower-case file extensions accepted as images when scanning ``img_dir``.
    # Anything else (Thumbs.db, text files, sub-directories, ...) is skipped so
    # that a stray entry cannot abort training with "Cannot read image".
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".jpe", ".png", ".bmp", ".tif", ".tiff", ".webp",
    )

    def __init__(self, img_dir, mask_dir, transform=None):
        """Index the image files in ``img_dir``.

        Args:
            img_dir: Directory containing the input images.
            mask_dir: Directory containing the segmentation masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping
                with ``"image"`` and ``"mask"`` entries.
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Sorted for a deterministic sample order. Hidden entries (".DS_Store",
        # "._foo.jpg") and non-files are ignored even if the suffix matches.
        self.images = sorted(
            name
            for name in os.listdir(img_dir)
            if not name.startswith(".")
            and name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(img, mask)`` float tensors. ``img`` is laid out as
            (C, H, W) in RGB channel order; ``mask`` is (1, H, W) and is
            built from ``mask > 0`` (lesion = 1, background = 0) before the
            transform is applied.

        Raises:
            RuntimeError: If the image, or both mask candidates, cannot be
                read.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            raise RuntimeError(f"Cannot read image: {img_path}")
        # OpenCV decodes to BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Assumption: mask has same basename + _mask.png
        # Adjust this logic if your naming convention differs
        base = os.path.splitext(img_name)[0]
        mask = cv2.imread(
            os.path.join(self.mask_dir, base + "_mask.png"),
            cv2.IMREAD_GRAYSCALE,
        )
        if mask is None:
            # Fallback: try reading with original name if _mask suffix not used
            mask = cv2.imread(
                os.path.join(self.mask_dir, img_name),
                cv2.IMREAD_GRAYSCALE,
            )
        if mask is None:
            raise RuntimeError(f"Cannot read mask for: {img_name}")

        # Binary mask: lesion = 1, background = 0
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]

        # To Tensor: (H, W, C) -> (C, H, W); the mask gains a channel axis.
        img = torch.tensor(img).permute(2, 0, 1).float()
        mask = torch.tensor(mask).unsqueeze(0).float()
        return img, mask
def get_transforms():
    """Build the albumentations pipelines for training and validation.

    Both pipelines resize to ``IMG_SIZE`` x ``IMG_SIZE`` and then normalize
    with the channel mean/std defined below. The training pipeline also
    applies a horizontal and a vertical flip (each with p=0.5) between the
    resize and the normalization.

    Returns:
        tuple: ``(train_tf, val_tf)``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _pipeline(*augmentations):
        # Resize first, optional augmentations in the middle, normalize last.
        steps = [A.Resize(IMG_SIZE, IMG_SIZE)]
        steps.extend(augmentations)
        steps.append(A.Normalize(mean=channel_mean, std=channel_std))
        return A.Compose(steps, is_check_shapes=False)

    train_tf = _pipeline(A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5))
    val_tf = _pipeline()
    return train_tf, val_tf
def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(imgs, masks)`` batches.
        criterion: Callable ``criterion(logits, masks)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        float: Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(loader, desc=f"Train Epoch {epoch}")
    for batch_imgs, batch_masks in progress:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_masks)
        batch_loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the total and the display.
        loss_value = batch_loss.item()
        loss_sum += loss_value
        progress.set_postfix(loss=loss_value)
    return loss_sum / len(loader)
def val_epoch(model, loader, criterion, device, epoch):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(imgs, masks)`` batches.
        criterion: Callable ``criterion(logits, masks)`` returning the loss.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        float: Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        progress = tqdm(loader, desc=f"Val Epoch {epoch}")
        for batch_imgs, batch_masks in progress:
            batch_imgs = batch_imgs.to(device)
            batch_masks = batch_masks.to(device)

            batch_loss = criterion(model(batch_imgs), batch_masks)

            # Read the scalar once and reuse it for the total and the display.
            loss_value = batch_loss.item()
            loss_sum += loss_value
            progress.set_postfix(loss=loss_value)
    return loss_sum / len(loader)
def main():
    """Train the binary DeepLabV3+ model and write checkpoints.

    Reads all settings from the module-level configuration constants.
    After every epoch the weights are saved to ``SAVE_LAST``; whenever the
    validation loss improves they are also saved to ``SAVE_BEST``.

    Prints an error and returns early (without raising) when a dataset
    directory is missing or a split contains no images.
    """
    print("\n=== Binary DeepLabV3+ Training ===")

    # Check Directories: validate every input directory up front so a typo
    # surfaces here instead of as a traceback from os.listdir / cv2 later on.
    required_dirs = (
        ("Training directory", TRAIN_IMG_DIR),
        ("Training mask directory", TRAIN_MASK_DIR),
        ("Validation directory", VAL_IMG_DIR),
        ("Validation mask directory", VAL_MASK_DIR),
    )
    for label, path in required_dirs:
        if not os.path.isdir(path):
            print(f"Error: {label} not found: {path}")
            return

    # Both checkpoint paths may point to different directories; a bare file
    # name has an empty dirname, which os.makedirs would reject.
    for ckpt_path in (SAVE_BEST, SAVE_LAST):
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    # Transforms & Datasets
    train_tf, val_tf = get_transforms()
    train_ds = LeafDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_tf)
    val_ds = LeafDataset(VAL_IMG_DIR, VAL_MASK_DIR, val_tf)
    print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")

    # An empty split would make the per-epoch average divide by zero.
    if len(train_ds) == 0 or len(val_ds) == 0:
        print("Error: Training and validation sets must each contain at least one image.")
        return

    # A trailing training batch of a single sample makes BatchNorm fail in
    # training mode on the 1x1 pooled feature map ("Expected more than 1
    # value per channel when training"). Drop only that leftover batch.
    drop_single_tail = (
        len(train_ds) > BATCH_SIZE and len(train_ds) % BATCH_SIZE == 1
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        drop_last=drop_single_tail,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    # Model
    model = smp.DeepLabV3Plus(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
    ).to(device)

    # Loss & Optimizer
    # Combine Dice Loss and BCE for robust segmentation training
    dice_loss = smp.losses.DiceLoss(mode="binary")
    bce_loss = torch.nn.BCEWithLogitsLoss()

    def criterion(logits, targets):
        return dice_loss(logits, targets) + bce_loss(logits, targets)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training Loop
    best_val_loss = float("inf")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        val_loss = val_epoch(model, val_loader, criterion, device, epoch)
        print(f"[Epoch {epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Keep a separate copy of the weights with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), SAVE_BEST)
            print(f"✔ Best model saved to: {SAVE_BEST}")

        # Always refresh the "last" checkpoint.
        torch.save(model.state_dict(), SAVE_LAST)

    print("Training finished.")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()