"""Fine-tune the SAM2 prompt encoder / mask decoder on a binary-mask dataset,
then visualise predictions of the fine-tuned checkpoint on held-out images."""
import os
import random
import pandas as pd
import cv2
import torch
import torch.nn.utils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing, which can vary run to
    # run and contradicts the deterministic flag above.
    torch.backends.cudnn.benchmark = False


set_seeds()

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
data_dir = "./sam2-data"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")

train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))
# Hold out 10% of the labelled rows for validation and the inference demo.
train_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42)


def _to_records(df):
    """Map rows with 'imageid'/'maskid' columns to {"image", "annotation"} path dicts."""
    return [
        {
            "image": os.path.join(images_dir, image_name),
            "annotation": os.path.join(masks_dir, mask_name),
        }
        for image_name, mask_name in zip(df["imageid"], df["maskid"])
    ]


train_data = _to_records(train_df)
test_data = _to_records(test_df)


def _resize_by(array, scale, interpolation=cv2.INTER_LINEAR):
    """Resize `array` by `scale` on both axes (cv2.resize takes (width, height))."""
    height, width = array.shape[:2]
    return cv2.resize(array, (int(width * scale), int(height * scale)), interpolation=interpolation)


def _load_resized_pair(image_path, mask_path):
    """Load an RGB image and a grayscale mask, rescaled so the image's longer side is 1024 px.

    The mask is rescaled by the same factor using nearest-neighbour interpolation
    so label ids are preserved. Returns (None, None) if either file cannot be read.
    """
    image_bgr = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # Check before slicing: cv2.imread returns None on failure.
    if image_bgr is None or mask is None:
        return None, None
    image = image_bgr[..., ::-1]  # BGR -> RGB
    scale = min(1024 / image.shape[1], 1024 / image.shape[0])
    image = _resize_by(image, scale)
    mask = _resize_by(mask, scale, interpolation=cv2.INTER_NEAREST)
    return image, mask


def read_batch(data, visualize_data=True):
    """Sample one random entry of `data` and build a single training example.

    All non-zero labels in the annotation are merged into one binary mask. One
    prompt point per label is drawn uniformly from the eroded merged mask.

    Returns:
        (image, binary_mask, points, num_labels):
        - image: HxWx3 RGB array.
        - binary_mask: (1, H, W) uint8 in {0, 1}.
        - points: (K, 1, 2) in (x, y) order. K == num_labels, or 0 when erosion
          leaves no foreground.
        - num_labels: the number of distinct non-zero labels.
        Returns (None, None, None, 0) if the image or mask cannot be read.
    """
    ent = data[np.random.randint(len(data))]
    image, ann_map = _load_resized_pair(ent["image"], ent["annotation"])
    if image is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    labels = np.unique(ann_map)
    labels = labels[labels != 0]  # drop background without assuming 0 is present
    binary_mask = (ann_map != 0).astype(np.uint8)  # union of all labelled regions

    # Erode so sampled prompts do not sit on object boundaries.
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)  # (row, col) pairs
    points = []
    if len(coords) > 0:
        for _ in labels:  # one point per label
            row, col = coords[np.random.randint(len(coords))]
            points.append([col, row])
    points = np.array(points).reshape(-1, 2)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(image)
        plt.axis('off')
        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')
        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    binary_mask = binary_mask[np.newaxis, ...]  # (1, H, W)
    points = points[:, np.newaxis, :]  # (K, 1, 2): each point is its own prompt
    return image, binary_mask, points, len(labels)


# ---------------------------------------------------------------------------
# Model / optimiser
# ---------------------------------------------------------------------------
def _to_hydra_name(x):
    """Convert a SAM2 config path into the Hydra config name build_sam2 expects.

    Examples:
        /.../sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml -> configs/sam2.1/sam2.1_hiera_s
        ./sam2/configs/sam2.1/sam2.1_hiera_s.yaml         -> configs/sam2.1/sam2.1_hiera_s
    Falsy input returns None. Anything else is returned with '.yaml' stripped.
    """
    if not x:
        return None
    s = str(x).replace("\\", "/")
    if s.endswith(".yaml"):
        s = s[:-5]
    marker = "/sam2/configs/"
    if marker in s:
        # Anchor on the LAST '/sam2/configs/' so earlier '/sam2/' path
        # components (e.g. the repo checkout dir) are not mistaken for it.
        return "configs/" + s.rsplit(marker, 1)[1]
    if s.startswith("sam2/configs/"):
        return s[len("sam2/"):]
    return s


sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = _to_hydra_name("./sam2/configs/sam2.1/sam2.1_hiera_l.yaml")
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

scaler = torch.amp.GradScaler()
NO_OF_STEPS = 1200
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"
CHECKPOINT_DIR = "./checkpoints-ft"
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=0.00005, weight_decay=1e-4)
# NOTE(review): scheduler.step() runs once per optimizer step, i.e. every
# `accumulation_steps` training steps. With 1200 / 8 = 150 optimizer steps,
# step_size=1000 never decays the LR. Confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.6)
accumulation_steps = 8


def _checkpoint_path(step):
    """Path of the fine-tuned state_dict saved at `step`."""
    return os.path.join(CHECKPOINT_DIR, f"{FINE_TUNED_MODEL_NAME}_{step}.pt")


def _finite_or_zero(value):
    """Coerce a running metric to a finite float (None/NaN/inf become 0.0)."""
    if value is None:
        return 0.0
    value = float(value)
    return value if np.isfinite(value) else 0.0


def _forward_batch(predictor, data):
    """Sample an example from `data` and run SAM2's prompt encoder and mask decoder.

    Returns (mask_logits, iou_scores, gt_mask) for the first multimask output,
    or None when the sampled example is unusable (unreadable file, no labels,
    or no prompt points).
    """
    image, mask, input_point, num_masks = read_batch(data, visualize_data=False)
    if image is None or num_masks == 0 or input_point.size == 0:
        return None
    input_label = np.ones((num_masks, 1), dtype=np.int64)  # all points are positive

    predictor.set_image(image)
    _, unnorm_coords, labels, _ = predictor._prep_prompts(
        input_point, input_label, box=None, mask_logits=None, normalize_coords=True
    )
    if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
        return None

    sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
        points=(unnorm_coords, labels), boxes=None, masks=None
    )
    high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
    low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
        repeat_image=unnorm_coords.shape[0] > 1,  # several prompts share one image
        high_res_features=high_res_features,
    )
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
    gt_mask = torch.from_numpy(mask.astype(np.float32)).to("cuda")
    return prd_masks[:, 0], prd_scores[:, 0], gt_mask


def _compute_losses(mask_logits, iou_scores, gt_mask, eps=1e-6):
    """Return (seg_loss, score_loss, iou) for per-prompt predictions vs. one GT mask.

    The single (1, H, W) ground truth is broadcast against every predicted
    (N, H, W) mask. score_loss is the mean L1 error between the decoder's IoU
    prediction and the measured IoU.
    """
    gt = gt_mask.expand_as(mask_logits)
    # Logit-space BCE is numerically stable and autocast-safe (no log(p + eps)).
    seg_loss = F.binary_cross_entropy_with_logits(mask_logits, gt)
    pred_bin = (torch.sigmoid(mask_logits) > 0.5).float()
    inter = (gt * pred_bin).sum(dim=(1, 2))
    union = gt.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
    iou = inter / (union + eps)  # eps avoids 0/0 when both masks are empty
    score_loss = torch.abs(iou_scores - iou).mean()
    return seg_loss, score_loss, iou


def _update_rolling_iou(mean_iou, iou):
    """Exponential moving average (0.99 decay) of the batch-mean IoU, robust to NaN/inf."""
    iou_np = np.nan_to_num(iou.detach().float().cpu().numpy(), nan=0.0, posinf=1.0, neginf=0.0)
    return float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))


def train(predictor, train_data, step, mean_iou):
    """Run one training micro-step with gradient accumulation.

    The optimizer and LR scheduler step when `step` is a multiple of
    `accumulation_steps`. Returns the updated rolling IoU, or the unchanged value
    when the sampled example is unusable.
    """
    mean_iou = _finite_or_zero(mean_iou)
    predictor.model.train()

    with torch.amp.autocast(device_type='cuda'):
        outputs = _forward_batch(predictor, train_data)
        if outputs is None:
            return mean_iou
        mask_logits, iou_scores, gt_mask = outputs
        seg_loss, score_loss, iou = _compute_losses(mask_logits, iou_scores, gt_mask)
        loss = (seg_loss + 0.05 * score_loss) / accumulation_steps

    # Backward runs outside autocast, as recommended by the AMP docs.
    scaler.scale(loss).backward()

    # NOTE(review): if a step that is a multiple of accumulation_steps hits an
    # unusable example, the update slips to the next multiple.
    if step % accumulation_steps == 0:
        # Gradients must be unscaled before clipping, otherwise max_norm is
        # applied to the loss-scaled values.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()  # after optimizer.step(), once per real update

    mean_iou = _update_rolling_iou(mean_iou, iou)
    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")
    return mean_iou


def validate(predictor, test_data, step, mean_iou):
    """Evaluate one random held-out example and save a checkpoint every 100 steps.

    The checkpoint is written before sampling, so it does not depend on the
    random example being usable. Returns the updated rolling validation IoU.
    """
    mean_iou = _finite_or_zero(mean_iou)
    if step % 100 == 0:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        torch.save(predictor.model.state_dict(), _checkpoint_path(step))

    predictor.model.eval()
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        outputs = _forward_batch(predictor, test_data)
        if outputs is None:
            return mean_iou
        mask_logits, iou_scores, gt_mask = outputs
        seg_loss, _, iou = _compute_losses(mask_logits, iou_scores, gt_mask)

    mean_iou = _update_rolling_iou(mean_iou, iou)
    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} Valid_IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")
    return mean_iou


VISUALIZE_SAMPLE_BATCH = False  # set True to plot one training example
if VISUALIZE_SAMPLE_BATCH:
    read_batch(train_data, visualize_data=True)

RUN_TRAINING = False  # set True to fine-tune; the demo below loads the final checkpoint
train_mean_iou = 0
valid_mean_iou = 0
if RUN_TRAINING:
    for step in range(1, NO_OF_STEPS + 1):
        train_mean_iou = train(predictor, train_data, step, train_mean_iou)
        valid_mean_iou = validate(predictor, test_data, step, valid_mean_iou)


# ---------------------------------------------------------------------------
# Inference demo
# ---------------------------------------------------------------------------
def read_image(image_path, mask_path):
    """Load an RGB image and grayscale mask, resized so the image's longer side is 1024 px.

    Raises:
        FileNotFoundError: if either file cannot be read by OpenCV.
    """
    img, mask = _load_resized_pair(image_path, mask_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image or mask from {image_path} or {mask_path}")
    return img, mask


def get_points(mask, num_points):
    """Sample `num_points` random (x, y) points inside `mask` as a (num_points, 1, 2) array.

    Returns an empty (0, 1, 2) array when the mask has no foreground pixels.
    """
    coords = np.argwhere(mask > 0)  # (row, col) pairs
    if len(coords) == 0:
        return np.zeros((0, 1, 2), dtype=np.int64)
    points = []
    for _ in range(num_points):
        row, col = coords[np.random.randint(len(coords))]
        points.append([[col, row]])
    return np.array(points)


def _combine_masks(masks, scores, max_overlap=0.15):
    """Paint masks into a label map in descending score order.

    A mask is skipped if it is empty or if more than `max_overlap` of its area is
    already claimed by higher-scoring masks. Its remaining pixels get label
    (rank + 1), where rank is its position in the score order.
    """
    sorted_masks = np.asarray(masks)[np.argsort(scores)[::-1]]
    dtype = np.uint8 if len(sorted_masks) < 256 else np.uint16  # avoid label overflow
    seg_map = np.zeros(sorted_masks.shape[1:], dtype=dtype)
    occupancy = np.zeros(sorted_masks.shape[1:], dtype=bool)
    for rank, pred_mask in enumerate(sorted_masks):
        mask_bool = pred_mask.astype(bool)
        area = mask_bool.sum()
        if area == 0 or (mask_bool & occupancy).sum() / area > max_overlap:
            continue
        mask_bool &= ~occupancy  # keep only pixels not already claimed
        seg_map[mask_bool] = rank + 1
        occupancy |= mask_bool
    return seg_map


FINE_TUNED_MODEL_WEIGHTS = _checkpoint_path(NO_OF_STEPS)
# Build the fine-tuned model once, not once per demo image.
ft_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
ft_predictor = SAM2ImagePredictor(ft_model)
ft_predictor.model.load_state_dict(
    torch.load(FINE_TUNED_MODEL_WEIGHTS, map_location="cuda", weights_only=True)
)

NUM_DEMO_IMAGES = 3
NUM_PROMPT_POINTS = 30  # number of point prompts sampled per image
for _ in range(NUM_DEMO_IMAGES):
    selected_entry = random.choice(test_data)
    print(selected_entry)
    image_path = selected_entry['image']
    mask_path = selected_entry['annotation']
    print(mask_path, 'mask path')

    image, target_mask = read_image(image_path, mask_path)
    input_points = get_points(target_mask, NUM_PROMPT_POINTS)
    if input_points.size == 0:
        print(f"Skipping {image_path}: mask has no foreground pixels")
        continue

    with torch.no_grad():
        ft_predictor.set_image(image)
        masks, scores, _ = ft_predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1]),
        )

    # Keep the first multimask output per prompt and merge by score.
    seg_map = _combine_masks(masks[:, 0], scores[:, 0])

    plt.figure(figsize=(18, 6))
    plt.subplot(1, 3, 1)
    plt.title('Test Image')
    plt.imshow(image)
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.title('Original Mask')
    plt.imshow(target_mask, cmap='gray')
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.title('Final Segmentation')
    plt.imshow(seg_map, cmap='jet')
    plt.axis('off')
    plt.tight_layout()
    plt.show()