| """
|
| SAM2 Training Script for Coffee Severity Segmentation
|
| =====================================================
|
|
|
| This script fine-tunes the Segment Anything Model 2 (SAM2) for semantic segmentation
|
| of coffee leaf rust severity. It loads a dataset of images and masks, and trains
|
| the mask decoder of SAM2.
|
|
|
| Methodology:
|
| 1. Load dataset (images + binary masks).
|
| 2. Resize and preprocess images.
|
| 3. Generate point prompts from the ground truth masks (simulating user clicks or random points).
|
| 4. Train the SAM2 mask decoder using a combination of IoU loss and Cross Entropy loss.
|
| 5. Evaluate semantic segmentation performance on a test set.
|
|
|
| Requirements:
|
| - torch
|
| - opencv-python (cv2)
|
| - pandas
|
| - scikit-learn
|
| - sam2 (https://github.com/facebookresearch/sam2)
|
| """
|
|
|
| import os
|
| import random
|
| import logging
|
| import argparse
|
| import numpy as np
|
| import torch
|
| import cv2
|
| import pandas as pd
|
| from torch.nn.utils import clip_grad_norm_
|
| from torch.optim import AdamW
|
| from torch.cuda.amp import GradScaler, autocast
|
| from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
# SAM2 is an external dependency (https://github.com/facebookresearch/sam2);
# fail fast with an actionable message when it is not importable.
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except ImportError:
    print("Error: 'sam2' module not found. Please install SAM2 or check your python path.")
    # BUGFIX: `exit()` is an interactive helper injected by the `site` module
    # and may be absent under `python -S` or in frozen builds; raising
    # SystemExit is the portable equivalent.
    raise SystemExit(1)
|
|
|
|
|
|
|
# Default locations; each can be overridden via a CLI flag (see parse_args).
DEFAULT_DATA_DIR = "./data/coffee_severity"               # dataset root: images/, masks/, train.csv
DEFAULT_CFG_PATH = "./configs/sam2.1_hiera_t.yaml"        # SAM2 model config (hiera-tiny)
DEFAULT_CKPT_PATH = "./checkpoints/sam2.1_hiera_tiny.pt"  # pretrained SAM2 weights
DEFAULT_LOG_FILE = "training.log"                         # default training log path

SEED = 42           # global seed used for the train/test split and RNGs
TARGET_SIZE = 1024  # images are rescaled so the longer side fits this size
|
|
|
def setup_logging(log_path: str):
    """Configure the root logger and return it.

    Attaches a console handler (INFO) and a file handler (DEBUG) writing to
    ``log_path``, both using an asctime/levelname/message format.

    BUGFIX: the original appended two fresh handlers on *every* call, so a
    second invocation duplicated each log line. Pre-existing handlers are now
    removed first, making the function idempotent.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop handlers from any earlier call (or a stray basicConfig) so that
    # repeated setup does not multiply output.
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Console: INFO and above.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # File: DEBUG and above (more verbose than the console).
    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger
|
|
|
def set_seeds(seed=42):
    """Seed Python, NumPy, and torch RNGs for reproducible runs.

    BUGFIX: ``cudnn.benchmark`` is now disabled. Benchmark mode autotunes
    convolution kernels non-deterministically, which contradicted the
    ``cudnn.deterministic = True`` flag set on the previous line and defeated
    the purpose of this function.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # was True: broke determinism
|
|
|
def load_dataset(data_dir: str, test_ratio=0.2):
    """Read ``train.csv`` from *data_dir* and split rows into train/test lists.

    Each returned entry is ``{"image": <path>, "annotation": <path>}``. Rows
    whose image or mask file is missing on disk are skipped with a warning.

    Raises:
        FileNotFoundError: if ``train.csv`` does not exist under *data_dir*.
    """
    images_dir = os.path.join(data_dir, "images")
    masks_dir = os.path.join(data_dir, "masks")
    csv_path = os.path.join(data_dir, "train.csv")

    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    frame = pd.read_csv(csv_path)
    train_df, test_df = train_test_split(frame, test_size=test_ratio, random_state=SEED)

    def collect(part):
        # Keep only rows whose files actually exist.
        entries = []
        for _, row in part.iterrows():
            img_path = os.path.join(images_dir, row["imageid"])
            mask_path = os.path.join(masks_dir, row["maskid"])
            if not (os.path.exists(img_path) and os.path.exists(mask_path)):
                logging.warning(f"Missing file: {img_path} or {mask_path}")
                continue
            entries.append({"image": img_path, "annotation": mask_path})
        return entries

    return collect(train_df), collect(test_df)
|
|
|
def read_batch(data_list, target_size=TARGET_SIZE):
    """Sample one random (image, mask) pair and build point prompts from it.

    Returns a tuple ``(image_rgb, binary_mask, points, num_masks)`` where
    ``binary_mask`` has shape (1, H, W), ``points`` has shape (K, 1, 2) in
    (x, y) order, and ``num_masks`` is 1. Returns ``(None, None, None, 0)``
    when the list is empty or a file fails to load.
    """
    if not data_list:
        return None, None, None, 0

    entry = random.choice(data_list)
    image = cv2.imread(entry["image"])
    annotation = cv2.imread(entry["annotation"], cv2.IMREAD_GRAYSCALE)
    if image is None or annotation is None:
        return None, None, None, 0

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Rescale so the longer side fits target_size, preserving aspect ratio.
    h, w = image.shape[0], image.shape[1]
    scale = min(target_size / w, target_size / h)
    new_size = (int(w * scale), int(h * scale))
    image = cv2.resize(image, new_size)
    annotation = cv2.resize(annotation, new_size, interpolation=cv2.INTER_NEAREST)

    mask = (annotation > 0).astype(np.uint8)

    # Erode the mask so sampled prompt points sit well inside the object.
    shrunk = cv2.erode(mask, np.ones((5, 5), np.uint8), iterations=1)
    interior = np.argwhere(shrunk > 0)

    prompt_points = []
    if interior.shape[0] > 0:
        chosen = np.random.choice(
            interior.shape[0], size=min(5, interior.shape[0]), replace=False
        )
        # argwhere yields (y, x); prompts are stored as (x, y).
        prompt_points = [[interior[i][1], interior[i][0]] for i in chosen]
    else:
        # Mask too thin to survive erosion: fall back to any foreground pixel.
        fallback = np.argwhere(mask > 0)
        if fallback.shape[0] > 0:
            y, x = fallback[0]
            prompt_points = [[x, y]]

    prompt_points = np.array(prompt_points)

    mask = np.expand_dims(mask, axis=0)
    if len(prompt_points) > 0:
        prompt_points = prompt_points.reshape((-1, 1, 2))
    else:
        prompt_points = np.zeros((0, 1, 2))

    return image, mask, prompt_points, 1
|
|
|
def build_model(cfg_path: str, checkpoint_path: str, device: str = "cuda"):
    """Load a SAM2 checkpoint and return a predictor ready for fine-tuning.

    The heavy image encoder is frozen; only the mask decoder and prompt
    encoder are put into training mode.
    """
    print(f"Loading SAM2 from {checkpoint_path}...")
    sam_model = build_sam2(cfg_path, checkpoint_path, device=device)
    predictor = SAM2ImagePredictor(sam_model)

    # Freeze the backbone: we only fine-tune the lightweight heads.
    for p in predictor.model.image_encoder.parameters():
        p.requires_grad = False

    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)

    return predictor
|
|
|
def train_one_step(predictor, optimizer, scheduler, scaler,
                   train_data, step, accumulation_steps, logger, mean_iou):
    """Run one training micro-step on a single random sample.

    The forward pass and loss run under autocast (mixed precision); gradients
    accumulate over ``accumulation_steps`` micro-steps before an optimizer
    update. Returns the updated exponential moving average of the IoU metric
    (unchanged if the sampled batch was unusable).
    """
    predictor.model.train()

    # Sample one (image, mask, prompt-points) triple; skip the step on bad or
    # empty data. This is plain numpy work, so it runs outside autocast.
    img, mask_np, points, num_masks = read_batch(train_data)
    if img is None or num_masks == 0 or points.shape[0] == 0:
        return mean_iou

    with autocast():
        predictor.set_image(img)

        # One positive "click" label per prompt point.
        input_label = np.ones((points.shape[0], 1))
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            points, input_label, box=None, mask_logits=None, normalize_coords=True
        )

        sparse_emb, dense_emb = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        batched = unnorm_coords.shape[0] > 1
        high_feats = [feat[-1].unsqueeze(0) for feat in predictor._features["high_res_feats"]]

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
            repeat_image=batched,
            high_res_features=high_feats
        )

        # Upsample low-res logits back to the original image resolution.
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
        gt = torch.tensor(mask_np.astype(np.float32)).cuda()
        prd = torch.sigmoid(prd_masks[:, 0])

        # Manual binary cross entropy, with eps for numerical stability.
        seg_loss = (-gt * torch.log(prd + 1e-6) - (1 - gt) * torch.log(1 - prd + 1e-6)).mean()

        # Hard IoU at a 0.5 threshold: both the reported metric and the
        # regression target for the score head.
        inter = (gt * (prd > 0.5)).sum()
        union = gt.sum() + (prd > 0.5).sum() - inter
        iou = (inter / (union + 1e-6)).item()

        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = (seg_loss + 0.05 * score_loss) / accumulation_steps

    # Backward runs outside autocast, per the recommended AMP recipe.
    scaler.scale(loss).backward()

    if step % accumulation_steps == 0:
        # BUGFIX: gradients must be unscaled before clipping. The original
        # clipped the *scaled* gradients on every micro-step, which made the
        # max_norm=1.0 threshold meaningless under AMP.
        scaler.unscale_(optimizer)
        clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

    # Exponential moving average of the IoU metric (0.99 decay).
    mean_iou = mean_iou * 0.99 + 0.01 * iou

    if step % 100 == 0:
        lr = optimizer.param_groups[0]["lr"]
        logger.info(f"Step {step}: LR={lr:.6e}, IoU={mean_iou:.4f}, Loss={seg_loss.item():.4f}")

    return mean_iou
|
|
|
def run_training(args):
    """Top-level driver: logging, data, model, optimizer, loop, checkpoints."""
    logger = setup_logging(args.log_file)
    logger.info(f"Initializing training on device: {args.device}")
    set_seeds(args.seed)

    # Dataset loading can fail (missing CSV, bad paths); abort cleanly.
    try:
        train_data, test_data = load_dataset(args.data_dir, test_ratio=args.test_ratio)
        logger.info(f"Loaded {len(train_data)} training samples and {len(test_data)} test samples.")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        return

    sam_predictor = build_model(args.cfg_path, args.ckpt_path, device=args.device)

    opt = AdamW(sam_predictor.model.parameters(),
                lr=args.lr, weight_decay=args.weight_decay)
    lr_sched = torch.optim.lr_scheduler.StepLR(opt,
                                               step_size=args.step_size,
                                               gamma=args.gamma)
    grad_scaler = GradScaler()

    ema_iou = 0.0
    logger.info("Starting training loop...")

    for step in range(1, args.max_steps + 1):
        ema_iou = train_one_step(sam_predictor, opt, lr_sched, grad_scaler,
                                 train_data, step, args.accum_steps,
                                 logger, ema_iou)

        # Periodic checkpointing of the full model state.
        if step % 500 == 0:
            fname = f"checkpoint_step_{step}.pt"
            torch.save(sam_predictor.model.state_dict(), fname)
            logger.info(f"Saved checkpoint: {fname}")

    logger.info("Training finished.")
|
|
|
def parse_args():
    """Build and run the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Train SAM2 for Coffee Leaf Rust Severity"
    )
    # Paths.
    parser.add_argument("--data-dir", type=str, default=DEFAULT_DATA_DIR,
                        help="Path to dataset folder")
    parser.add_argument("--cfg-path", type=str, default=DEFAULT_CFG_PATH,
                        help="Path to SAM2 config YAML")
    parser.add_argument("--ckpt-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Path to pretrained SAM2 checkpoint")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--log-file", type=str, default=DEFAULT_LOG_FILE,
                        help="Path to log file")
    # Optimization hyperparameters.
    parser.add_argument("--max-steps", type=int, default=6000,
                        help="Total training steps")
    parser.add_argument("--accum-steps", type=int, default=8,
                        help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=5e-5,
                        help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-4,
                        help="Weight decay")
    parser.add_argument("--step-size", type=int, default=2000,
                        help="Scheduler step size")
    parser.add_argument("--gamma", type=float, default=0.6,
                        help="Scheduler gamma")
    # Data split / reproducibility.
    parser.add_argument("--test-ratio", type=float, default=0.2,
                        help="Ratio of test set")
    parser.add_argument("--seed", type=int, default=SEED,
                        help="Random seed")
    return parser.parse_args()
|
|
|
# Script entry point: parse command-line arguments and start training.
if __name__ == "__main__":
    args = parse_args()
    run_training(args)
|
|
|