| """
|
| DeepLabV3+ Training Script for Coffee Leaf Rust
|
| ===============================================
|
|
|
| This script trains a DeepLabV3+ model (ResNet50 encoder) for binary semantic segmentation:
|
| - Class 0: Background / Healthy
|
| - Class 1: Lesion (Rust)
|
|
|
| Requirements:
|
| - torch
|
| - albumentations
|
| - segmentation-models-pytorch (smp)
|
| - opencv-python
|
| """
|
|
|
| import os
|
| import cv2
|
| import torch
|
| import numpy as np
|
| import albumentations as A
|
| import segmentation_models_pytorch as smp
|
| from torch.utils.data import Dataset, DataLoader
|
| from tqdm import tqdm
|
|
|
|
|
|
|
| TRAIN_IMG_DIR = "./data/dataset/images/train"
|
| TRAIN_MASK_DIR = "./data/dataset/masks/train"
|
|
|
| VAL_IMG_DIR = "./data/dataset/images/valid"
|
| VAL_MASK_DIR = "./data/dataset/masks/valid"
|
|
|
|
|
| SAVE_LAST = "./checkpoints/deeplab_binary_last.pth"
|
| SAVE_BEST = "./checkpoints/deeplab_binary_best.pth"
|
|
|
|
|
| BATCH_SIZE = 8
|
| EPOCHS = 25
|
| IMG_SIZE = 512
|
| LEARNING_RATE = 1e-4
|
|
|
|
|
| class LeafDataset(Dataset):
|
| def __init__(self, img_dir, mask_dir, transform=None):
|
| self.img_dir = img_dir
|
| self.mask_dir = mask_dir
|
| self.images = sorted(os.listdir(img_dir))
|
| self.transform = transform
|
|
|
| def __len__(self):
|
| return len(self.images)
|
|
|
| def __getitem__(self, idx):
|
| img_name = self.images[idx]
|
| img_path = os.path.join(self.img_dir, img_name)
|
|
|
|
|
|
|
| base = os.path.splitext(img_name)[0]
|
| mask_name = base + "_mask.png"
|
| mask_path = os.path.join(self.mask_dir, mask_name)
|
|
|
| img = cv2.imread(img_path)
|
| if img is None:
|
| raise RuntimeError(f"Cannot read image: {img_path}")
|
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
| mask = cv2.imread(mask_path, 0)
|
| if mask is None:
|
|
|
| mask_name = img_name
|
| mask_path = os.path.join(self.mask_dir, mask_name)
|
| mask = cv2.imread(mask_path, 0)
|
| if mask is None:
|
| raise RuntimeError(f"Cannot read mask for: {img_name}")
|
|
|
|
|
| mask = (mask > 0).astype(np.float32)
|
|
|
| if self.transform:
|
| augmented = self.transform(image=img, mask=mask)
|
| img = augmented["image"]
|
| mask = augmented["mask"]
|
|
|
|
|
| img = torch.tensor(img).permute(2, 0, 1).float()
|
| mask = torch.tensor(mask).unsqueeze(0).float()
|
|
|
| return img, mask
|
|
|
| def get_transforms():
|
| mean = [0.485, 0.456, 0.406]
|
| std = [0.229, 0.224, 0.225]
|
|
|
| train_tf = A.Compose([
|
| A.Resize(IMG_SIZE, IMG_SIZE),
|
| A.HorizontalFlip(p=0.5),
|
| A.VerticalFlip(p=0.5),
|
| A.Normalize(mean=mean, std=std),
|
| ], is_check_shapes=False)
|
|
|
| val_tf = A.Compose([
|
| A.Resize(IMG_SIZE, IMG_SIZE),
|
| A.Normalize(mean=mean, std=std),
|
| ], is_check_shapes=False)
|
|
|
| return train_tf, val_tf
|
|
|
| def train_epoch(model, loader, criterion, optimizer, device, epoch):
|
| model.train()
|
| running = 0.0
|
|
|
| pbar = tqdm(loader, desc=f"Train Epoch {epoch}")
|
| for imgs, masks in pbar:
|
| imgs, masks = imgs.to(device), masks.to(device)
|
|
|
| optimizer.zero_grad()
|
| logits = model(imgs)
|
|
|
| loss = criterion(logits, masks)
|
| loss.backward()
|
| optimizer.step()
|
|
|
| running += loss.item()
|
| pbar.set_postfix(loss=loss.item())
|
|
|
| return running / len(loader)
|
|
|
| def val_epoch(model, loader, criterion, device, epoch):
|
| model.eval()
|
| running = 0.0
|
|
|
| with torch.no_grad():
|
| pbar = tqdm(loader, desc=f"Val Epoch {epoch}")
|
| for imgs, masks in pbar:
|
| imgs, masks = imgs.to(device), masks.to(device)
|
| logits = model(imgs)
|
| loss = criterion(logits, masks)
|
| running += loss.item()
|
| pbar.set_postfix(loss=loss.item())
|
|
|
| return running / len(loader)
|
|
|
| def main():
|
| print("\n=== Binary DeepLabV3+ Training ===")
|
|
|
|
|
| if not os.path.exists(TRAIN_IMG_DIR):
|
| print(f"Error: Training directory not found: {TRAIN_IMG_DIR}")
|
| return
|
|
|
| os.makedirs(os.path.dirname(SAVE_BEST), exist_ok=True)
|
| device = "cuda" if torch.cuda.is_available() else "cpu"
|
| print("Using device:", device)
|
|
|
|
|
| train_tf, val_tf = get_transforms()
|
| train_ds = LeafDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_tf)
|
| val_ds = LeafDataset(VAL_IMG_DIR, VAL_MASK_DIR, val_tf)
|
| print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")
|
|
|
| train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
|
| val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
|
|
|
|
|
| model = smp.DeepLabV3Plus(
|
| encoder_name="resnet50",
|
| encoder_weights="imagenet",
|
| in_channels=3,
|
| classes=1
|
| ).to(device)
|
|
|
|
|
|
|
| dice_loss = smp.losses.DiceLoss(mode="binary")
|
| bce_loss = torch.nn.BCEWithLogitsLoss()
|
|
|
| def criterion(logits, targets):
|
| return dice_loss(logits, targets) + bce_loss(logits, targets)
|
|
|
| optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
|
|
|
|
| best_val_loss = float("inf")
|
|
|
| for epoch in range(1, EPOCHS + 1):
|
| train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
|
| val_loss = val_epoch(model, val_loader, criterion, device, epoch)
|
|
|
| print(f"[Epoch {epoch}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
|
|
|
| if val_loss < best_val_loss:
|
| best_val_loss = val_loss
|
| torch.save(model.state_dict(), SAVE_BEST)
|
| print(f"✔ Best model saved to: {SAVE_BEST}")
|
|
|
| torch.save(model.state_dict(), SAVE_LAST)
|
| print("Training finished.")
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|