# MaryPazRB's picture
# Upload 23 files
# d72a22b verified
"""
SAM2 Training Script for Coffee Severity Segmentation
=====================================================
This script fine-tunes the Segment Anything Model 2 (SAM2) for semantic segmentation
of coffee leaf rust severity. It loads a dataset of images and masks and trains
the SAM2 mask decoder and prompt encoder while the image encoder stays frozen.
Methodology:
1. Load dataset (images + binary masks).
2. Resize and preprocess images.
3. Generate point prompts from the ground truth masks (simulating user clicks or random points).
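4. Train SAM2 with a binary cross-entropy loss on the predicted mask plus a small L1 loss on the predicted IoU score.
5. Save a checkpoint every 500 steps (the held-out test split is loaded, but evaluation on it is not implemented in this script).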
Requirements:
- torch
- opencv-python (cv2)
- pandas
- scikit-learn
- sam2 (https://github.com/facebookresearch/sam2)
"""
import os
import random
import logging
import argparse
import numpy as np
import torch
import cv2
import pandas as pd
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from sklearn.model_selection import train_test_split
# --- Import SAM2 ---
# Ensure 'sam2' is installed or in your PYTHONPATH
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except ImportError:
    # The bare `exit()` helper is injected by the `site` module and is missing
    # under `python -S` or in frozen builds. Raising SystemExit with a string
    # works everywhere: the message goes to stderr and the exit status is 1.
    raise SystemExit(
        "Error: 'sam2' module not found. Please install SAM2 or check your python path."
    )
# ================= Configuration =================
# Default paths (Can be overridden by command line arguments)
# Dataset root; load_dataset() looks for 'images/', 'masks/' and 'train.csv' inside it.
DEFAULT_DATA_DIR = "./data/coffee_severity"
# SAM2 model config and pretrained checkpoint handed to build_sam2().
DEFAULT_CFG_PATH = "./configs/sam2.1_hiera_t.yaml"
DEFAULT_CKPT_PATH = "./checkpoints/sam2.1_hiera_tiny.pt"
# File that setup_logging() writes to in addition to the console.
DEFAULT_LOG_FILE = "training.log"
# Hyperparameters
# Default for --seed; also used directly as random_state of the train/test split.
SEED = 42
# Default target_size of read_batch(): images are scaled (aspect ratio kept)
# so that their longer side equals this many pixels.
TARGET_SIZE = 1024
# =================================================
def setup_logging(log_path: str):
    """Configure the root logger to write to the console and to a file.

    The function is safe to call repeatedly: handlers installed by an earlier
    call are removed and closed before the new ones are attached, so messages
    are not emitted twice and the previous log file handle is not leaked.
    Handlers installed by other code are left untouched.

    Args:
        log_path: Path of the log file.

    Returns:
        The root logger, with its level set to INFO.

    Raises:
        OSError: If the log file cannot be opened.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Build both handlers before touching the logger, so a bad log path leaves
    # the existing handler configuration intact.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_path)
    # The handler accepts DEBUG, but the logger level above is INFO, so in
    # practice only INFO and higher reach the file.
    file_handler.setLevel(logging.DEBUG)

    # Remove what a previous setup_logging() call installed.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_installed_by_setup_logging", False):
            logger.removeHandler(old_handler)
            old_handler.close()

    for handler in (console_handler, file_handler):
        handler.setFormatter(formatter)
        handler._installed_by_setup_logging = True
        logger.addHandler(handler)

    return logger
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random` and `torch`
            (including all CUDA devices when CUDA is available).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick algorithms per run, which
    # can differ between runs; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False
def load_dataset(data_dir: str, test_ratio=0.2):
    """
    Build the train/test sample lists for a dataset folder.

    The folder must contain 'train.csv' (columns 'imageid' and 'maskid'),
    plus the 'images' and 'masks' sub-folders those columns point into.
    The CSV rows are split with scikit-learn's train_test_split using the
    module-level SEED; rows whose image or mask file is missing on disk are
    skipped with a warning.

    Returns a (train, test) pair of lists of
    {"image": <path>, "annotation": <path>} dictionaries.
    Raises FileNotFoundError when 'train.csv' does not exist.
    """
    csv_path = os.path.join(data_dir, "train.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    images_dir = os.path.join(data_dir, "images")
    masks_dir = os.path.join(data_dir, "masks")

    frame = pd.read_csv(csv_path)
    train_part, test_part = train_test_split(frame, test_size=test_ratio, random_state=SEED)

    def collect(part):
        entries = []
        for _, record in part.iterrows():
            image_file = os.path.join(images_dir, record["imageid"])
            mask_file = os.path.join(masks_dir, record["maskid"])
            if not (os.path.exists(image_file) and os.path.exists(mask_file)):
                logging.warning(f"Missing file: {image_file} or {mask_file}")
                continue
            entries.append({"image": image_file, "annotation": mask_file})
        return entries

    train_entries = collect(train_part)
    test_entries = collect(test_part)
    return train_entries, test_entries
def _sample_foreground_points(binary_mask, max_points=5):
    """Return up to *max_points* (x, y) pixel coordinates lying inside the mask."""
    # Erode first so sampled points are not on the mask boundary.
    eroded = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    interior = np.argwhere(eroded > 0)  # rows are (y, x)
    if interior.shape[0] > 0:
        count = min(max_points, interior.shape[0])
        chosen = np.random.choice(interior.shape[0], size=count, replace=False)
        # Fancy indexing swaps (y, x) -> (x, y) and yields a fresh contiguous array.
        return interior[chosen][:, [1, 0]]

    # Erosion removed everything (very small mask): fall back to the first
    # foreground pixel, if there is one.
    foreground = np.argwhere(binary_mask > 0)
    if foreground.shape[0] > 0:
        return foreground[:1][:, [1, 0]]
    return np.zeros((0, 2))


def read_batch(data_list, target_size=TARGET_SIZE):
    """
    Read one random image/mask pair from *data_list* and preprocess it.

    The image is converted to RGB and both image and mask are scaled, keeping
    the aspect ratio, so that the longer side equals *target_size*. The mask
    is binarized (any non-zero pixel is foreground) and up to five foreground
    points are sampled from it as prompts.

    Args:
        data_list: List of {"image": path, "annotation": path} dictionaries.
        target_size: Length in pixels of the longer side after resizing.

    Returns:
        A tuple (img, binary_mask, points, num_masks):
        img is the resized RGB image, binary_mask has shape (1, H, W),
        points has shape (N, 1, 2) holding (x, y) coordinates (N may be 0
        when the mask is empty) and num_masks is 1. When *data_list* is empty
        or a file cannot be read, (None, None, None, 0) is returned.
    """
    if not data_list:
        return None, None, None, 0

    entry = random.choice(data_list)
    img = cv2.imread(entry["image"])
    ann = cv2.imread(entry["annotation"], cv2.IMREAD_GRAYSCALE)
    if img is None or ann is None:
        return None, None, None, 0

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB

    # Resize while maintaining aspect ratio.
    height, width = img.shape[:2]
    scale = min(target_size / width, target_size / height)
    # Clamp to one pixel: for extreme aspect ratios the short side would
    # otherwise truncate to 0 and cv2.resize would raise.
    w_new = max(1, int(width * scale))
    h_new = max(1, int(height * scale))
    img = cv2.resize(img, (w_new, h_new))
    # Nearest-neighbour keeps the mask labels intact (no blended values).
    ann = cv2.resize(ann, (w_new, h_new), interpolation=cv2.INTER_NEAREST)

    # Binarize: >0 covers all non-background classes.
    binary_mask = (ann > 0).astype(np.uint8)

    points = _sample_foreground_points(binary_mask)

    # Format for SAM2: one mask, and every point as its own single-point prompt.
    binary_mask = np.expand_dims(binary_mask, axis=0)  # shape (1, H, W)
    points = points.reshape((-1, 1, 2))
    return img, binary_mask, points, 1
def build_model(cfg_path: str, checkpoint_path: str, device: str = "cuda"):
    """Create a SAM2 image predictor with a frozen image encoder.

    The model is built from *cfg_path* / *checkpoint_path* on *device*,
    gradients are disabled for every image-encoder parameter, and the mask
    decoder and prompt encoder are switched to training mode.
    Returns the SAM2ImagePredictor wrapping the model.
    """
    print(f"Loading SAM2 from {checkpoint_path}...")
    predictor = SAM2ImagePredictor(build_sam2(cfg_path, checkpoint_path, device=device))
    sam = predictor.model

    # The image encoder stays fixed to save memory and time.
    for weight in sam.image_encoder.parameters():
        weight.requires_grad = False

    # Only these two sub-modules are meant to be trained.
    for trainable_module in (sam.sam_mask_decoder, sam.sam_prompt_encoder):
        trainable_module.train(True)

    return predictor
def _segmentation_losses(prd_masks, prd_scores, gt):
    """Return (seg_loss, score_loss, per-prompt IoU) for the first predicted mask."""
    # Work in float32: the 1e-6 epsilons are not reliable in half precision.
    prd = torch.sigmoid(prd_masks[:, 0].float())  # Take first mask of each prompt
    # Binary cross-entropy between the soft prediction and the ground truth.
    seg_loss = (-gt * torch.log(prd + 1e-6) - (1 - gt) * torch.log(1 - prd + 1e-6)).mean()

    # IoU per prompt: reduce over the two spatial axes only, so several
    # prompts sharing one ground-truth mask each get a value in [0, 1].
    hard = (prd > 0.5).float()
    inter = (gt * hard).sum(dim=(-2, -1))
    union = gt.sum(dim=(-2, -1)) + hard.sum(dim=(-2, -1)) - inter
    iou = (inter / (union + 1e-6)).detach()

    # Teach the predicted score to match the actual IoU of each mask.
    score_loss = torch.abs(prd_scores[:, 0].float() - iou).mean()
    return seg_loss, score_loss, iou


def train_one_step(predictor, optimizer, scheduler, scaler,
                   train_data, step, accumulation_steps, logger, mean_iou):
    """Run one training step on a randomly drawn sample.

    Args:
        predictor: SAM2ImagePredictor whose model is being fine-tuned.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per processed sample.
        scaler: torch.cuda.amp.GradScaler used for mixed-precision training.
        train_data: Sample list as accepted by read_batch().
        step: 1-based index of the current step.
        accumulation_steps: Gradients are accumulated over this many steps
            before the optimizer is stepped.
        logger: Logger used for the progress line every 100 steps.
        mean_iou: Running (exponential moving average) IoU so far.

    Returns:
        The updated running IoU. It is returned unchanged when the sample
        could not be read or produced no point prompts.
    """
    predictor.model.train()

    img, mask_np, points, num_masks = read_batch(train_data)
    # Skip if bad data or no points generated
    if img is None or num_masks == 0 or points.shape[0] == 0:
        return mean_iou

    # Only the forward pass and the loss run under autocast; backward and the
    # optimizer step are performed outside of it.
    with autocast():
        predictor.set_image(img)

        # Prepare inputs
        input_label = np.ones((points.shape[0], 1))  # All points are foreground (1)
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            points, input_label, box=None, mask_logits=None, normalize_coords=True
        )

        # Forward Pass
        sparse_emb, dense_emb = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )
        # With several prompts the single image embedding has to be repeated.
        batched = unnorm_coords.shape[0] > 1
        high_feats = [feat[-1].unsqueeze(0) for feat in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
            repeat_image=batched,
            high_res_features=high_feats
        )

        # Post-process and Calculate Loss
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
        # Put the ground truth on the prediction's device rather than
        # assuming CUDA, so the step also works with --device cpu.
        gt = torch.as_tensor(mask_np.astype(np.float32), device=prd_masks.device)

        seg_loss, score_loss, iou = _segmentation_losses(prd_masks, prd_scores, gt)
        loss = (seg_loss + 0.05 * score_loss) / accumulation_steps

    # Backward
    scaler.scale(loss).backward()

    if step % accumulation_steps == 0:
        # Gradients are still multiplied by the loss scale here; unscale them
        # first so max_norm applies to the true gradient magnitudes, and clip
        # once on the fully accumulated gradient.
        scaler.unscale_(optimizer)
        clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    scheduler.step()

    # Moving average IoU
    mean_iou = mean_iou * 0.99 + 0.01 * iou.mean().item()
    if step % 100 == 0:
        lr = optimizer.param_groups[0]["lr"]
        logger.info(f"Step {step}: LR={lr:.6e}, IoU={mean_iou:.4f}, Loss={seg_loss.item():.4f}")
    return mean_iou
def run_training(args):
    """Run the full fine-tuning loop described by the parsed CLI arguments.

    Sets up logging and seeds, loads the dataset, builds the model, optimizer,
    scheduler and gradient scaler, then calls train_one_step() for
    ``args.max_steps`` steps. The model state dict is written to
    ``checkpoint_step_<step>.pt`` in the working directory every 500 steps and
    once more after the last step when that step is not a multiple of 500.

    The function returns early (after logging an error) when the dataset
    cannot be loaded or contains no usable training samples.

    Args:
        args: Namespace produced by parse_args().
    """
    checkpoint_every = 500

    logger = setup_logging(args.log_file)
    logger.info(f"Initializing training on device: {args.device}")
    set_seeds(args.seed)

    # Load Data
    try:
        train_data, test_data = load_dataset(args.data_dir, test_ratio=args.test_ratio)
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        return
    logger.info(f"Loaded {len(train_data)} training samples and {len(test_data)} test samples.")

    # Without samples every step would be skipped silently and the saved
    # checkpoints would just be copies of the pretrained weights.
    if not train_data:
        logger.error("No usable training samples found; aborting training.")
        return

    # Build Model
    predictor = build_model(args.cfg_path, args.ckpt_path, device=args.device)

    # Optimizer
    optimizer = AdamW(predictor.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scaler = GradScaler()

    def save_checkpoint(at_step):
        fname = f"checkpoint_step_{at_step}.pt"
        torch.save(predictor.model.state_dict(), fname)
        logger.info(f"Saved checkpoint: {fname}")

    mean_iou = 0.0
    logger.info("Starting training loop...")
    for step in range(1, args.max_steps + 1):
        mean_iou = train_one_step(predictor, optimizer, scheduler, scaler,
                                  train_data, step, args.accum_steps, logger, mean_iou)
        # Periodic checkpoint (no evaluation on the test split is performed here).
        if step % checkpoint_every == 0:
            save_checkpoint(step)

    # Keep the final weights when the run did not end on a checkpoint boundary.
    if args.max_steps > 0 and args.max_steps % checkpoint_every != 0:
        save_checkpoint(args.max_steps)

    logger.info("Training finished.")
def parse_args():
    """Parse the command-line options of the training script.

    Every option has a default, so the script can be started without
    arguments. Returns the argparse Namespace.
    """
    parser = argparse.ArgumentParser(description="Train SAM2 for Coffee Leaf Rust Severity")

    # (flag, type, default, help) -- registered in this order.
    option_table = (
        ("--data-dir", str, DEFAULT_DATA_DIR, "Path to dataset folder"),
        ("--cfg-path", str, DEFAULT_CFG_PATH, "Path to SAM2 config YAML"),
        ("--ckpt-path", str, DEFAULT_CKPT_PATH, "Path to pretrained SAM2 checkpoint"),
        ("--device", str, "cuda", "Device to use (cuda/cpu)"),
        ("--log-file", str, DEFAULT_LOG_FILE, "Path to log file"),
        ("--max-steps", int, 6000, "Total training steps"),
        ("--accum-steps", int, 8, "Gradient accumulation steps"),
        ("--lr", float, 5e-5, "Learning rate"),
        ("--weight-decay", float, 1e-4, "Weight decay"),
        ("--step-size", int, 2000, "Scheduler step size"),
        ("--gamma", float, 0.6, "Scheduler gamma"),
        ("--test-ratio", float, 0.2, "Ratio of test set"),
        ("--seed", int, SEED, "Random seed"),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI options and start training.
    run_training(parse_args())