| | import os |
| | import random |
| | import pandas as pd |
| | import cv2 |
| | import torch |
| | import torch.nn.utils |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import matplotlib.colors as mcolors |
| | from sklearn.model_selection import train_test_split |
| | |
| | from sam2.build_sam import build_sam2 |
| | from sam2.sam2_image_predictor import SAM2ImagePredictor |
| |
|
def set_seeds():
    """Seed every RNG used in this script (random, NumPy, PyTorch, CUDA)
    so runs are reproducible."""
    SEED_VALUE = 42
    random.seed(SEED_VALUE)
    np.random.seed(SEED_VALUE)
    torch.manual_seed(SEED_VALUE)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED_VALUE)
        torch.cuda.manual_seed_all(SEED_VALUE)
        torch.backends.cudnn.deterministic = True
        # benchmark mode autotunes kernels per input shape and breaks
        # run-to-run determinism — it must be OFF when deterministic=True.
        torch.backends.cudnn.benchmark = False

set_seeds()
| |
|
# Dataset layout: ./sam2-data/{images,masks} plus a train.csv listing pairs.
data_dir = "./sam2-data"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")

train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

# Hold out 10% of the rows for validation, with a fixed split seed.
train_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42)


def _to_records(frame):
    # One dict per row, pairing the full image path with its mask path.
    return [
        {
            "image": os.path.join(images_dir, row["imageid"]),
            "annotation": os.path.join(masks_dir, row["maskid"]),
        }
        for _, row in frame.iterrows()
    ]


train_data = _to_records(train_df)
test_data = _to_records(test_df)
| |
|
def read_batch(data, visualize_data=True):
    """Pick one random (image, mask) pair, resize it to fit in 1024px,
    and sample one point prompt per mask id.

    Returns (image_rgb, binary_mask (1,H,W), points (N,1,2) in (x,y) order,
    number of mask ids), or (None, None, None, 0) if a file cannot be read.
    """
    ent = data[np.random.randint(len(data))]
    # cv2.imread returns None on failure — the None check must come BEFORE
    # any indexing, otherwise a missing file raises TypeError instead of
    # taking the error path below.
    Img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)

    if Img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    Img = Img[..., ::-1]  # BGR -> RGB

    # Uniform scale so the longer side is at most 1024; nearest-neighbor on
    # the mask so integer label ids are not interpolated.
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    binary_mask = np.zeros_like(ann_map, dtype=np.uint8)
    points = []
    # Drop the smallest unique value — presumably the background id 0
    # (TODO confirm masks always contain a background label).
    inds = np.unique(ann_map)[1:]
    for ind in inds:
        mask = (ann_map == ind).astype(np.uint8)
        binary_mask = np.maximum(binary_mask, mask)

    # Erode so sampled prompt points land in the interior, not on edges;
    # one random point per mask id, stored as (x, y).
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)
    if len(coords) > 0:
        for _ in inds:
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([yx[1], yx[0]])
    points = np.array(points)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    binary_mask = np.expand_dims(binary_mask, axis=-1)  # (H, W, 1)
    binary_mask = binary_mask.transpose((2, 0, 1))      # (1, H, W)
    points = np.expand_dims(points, axis=1)             # (N, 1, 2)
    return Img, binary_mask, points, len(inds)
| | |
| | |
| | def _to_hydra_name(x): |
| | if not x: |
| | return None |
| | s = str(x).replace("\\", "/") |
| | if s.endswith(".yaml"): |
| | s = s[:-5] |
| | |
| | |
| | |
| | if "/sam2/configs/" in s: |
| | return s.split("/sam2/")[1] |
| | if s.startswith("sam2/configs/"): |
| | return s[len("sam2/"):] |
| | return s |
| |
|
# --- Model, predictor, and optimization globals (used by train/validate) ---
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "./sam2/configs/sam2.1/sam2.1_hiera_l.yaml"

# Convert the filesystem path into the relative config name Hydra expects.
model_cfg = _to_hydra_name(model_cfg)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Only the mask decoder and prompt encoder are put in train mode; the image
# encoder stays as built (NOTE(review): its parameters are still passed to
# the optimizer below — confirm whether it was meant to be frozen).
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# AMP loss scaler for mixed-precision training.
scaler = torch.amp.GradScaler()
NO_OF_STEPS = 1200
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"

optimizer = torch.optim.AdamW(params=predictor.model.parameters(),
                              lr=0.00005,
                              weight_decay=1e-4)

# Decay LR by 0.6 every 1000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.6)
# Number of micro-batches accumulated per optimizer step.
accumulation_steps = 8
| |
|
def train(predictor, train_data, step, mean_iou):
    """Run one training micro-step on a random sample from train_data.

    Uses the module-level optimizer/scaler/scheduler and gradient
    accumulation (every `accumulation_steps` calls does an optimizer step).
    Returns the updated exponential moving average of IoU.
    """
    # Reset a missing or NaN running IoU (NaN != NaN).
    if mean_iou is None or (isinstance(mean_iou, float) and (mean_iou != mean_iou)):
        mean_iou = 0.0

    eps = 1e-6

    predictor.model.train()
    # Forward pass and loss under autocast; backward/step happen OUTSIDE
    # the autocast region, as recommended for AMP.
    with torch.amp.autocast(device_type='cuda'):
        image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)

        # Skip unreadable samples or samples with no annotated masks.
        if image is None or mask is None or num_masks == 0:
            return mean_iou

        # All prompt points are positive (foreground) clicks.
        input_label = np.ones((num_masks, 1), dtype=np.int64)

        if not isinstance(input_point, np.ndarray) or not isinstance(input_label, np.ndarray):
            return mean_iou
        if input_point.size == 0 or input_label.size == 0:
            return mean_iou

        predictor.set_image(image)
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )
        if (
            unnorm_coords is None or labels is None or
            unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
        ):
            return mean_iou

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        prd_masks = predictor._transforms.postprocess_masks(
            low_res_masks, predictor._orig_hw[-1]
        )

        gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
        prd_mask = torch.sigmoid(prd_masks[:, 0])

        # Per-pixel binary cross-entropy (eps keeps log() finite).
        seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                    - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

        # IoU of the thresholded prediction, used to supervise the score head.
        pred_bin = (prd_mask > 0.5).float()
        inter = (gt_mask * pred_bin).sum(dim=(1, 2))
        denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
        iou = inter / (denom + eps)

        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + 0.05 * score_loss

        # Scale down so accumulated gradients average over the micro-batches.
        loss = loss / accumulation_steps

    scaler.scale(loss).backward()

    did_optimizer_step = False
    if step % accumulation_steps == 0:
        # Gradients must be unscaled BEFORE clipping, otherwise max_norm is
        # applied to the AMP-scaled gradients and the clip is meaningless.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        did_optimizer_step = True

    # Advance the LR schedule only on real optimizer steps.
    if did_optimizer_step:
        scheduler.step()

    # Update the EMA, sanitizing any NaN/inf IoU values first.
    iou_np = iou.detach().float().cpu().numpy()
    iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
    mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

    return mean_iou
| |
|
def validate(predictor, test_data, step, mean_iou):
    """Evaluate one random sample from test_data (no gradients).

    Every 100 steps the model state_dict is checkpointed. Returns the
    updated exponential moving average of validation IoU.
    """
    # Reset a missing or NaN running IoU (NaN != NaN).
    if mean_iou is None or (isinstance(mean_iou, float) and (mean_iou != mean_iou)):
        mean_iou = 0.0

    predictor.model.eval()
    with torch.amp.autocast(device_type='cuda'):
        with torch.no_grad():
            image, mask, input_point, num_masks = read_batch(test_data, visualize_data=False)

            # Skip unreadable samples or samples with no annotated masks.
            if image is None or mask is None or num_masks == 0:
                return mean_iou

            # All prompt points are positive (foreground) clicks.
            input_label = np.ones((num_masks, 1), dtype=np.int64)

            if not isinstance(input_point, np.ndarray) or not isinstance(input_label, np.ndarray):
                return mean_iou
            if input_point.size == 0 or input_label.size == 0:
                return mean_iou

            predictor.set_image(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )
            if (
                unnorm_coords is None or labels is None or
                unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
            ):
                return mean_iou

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None
            )

            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )

            prd_masks = predictor._transforms.postprocess_masks(
                low_res_masks, predictor._orig_hw[-1]
            )

            gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
            prd_mask = torch.sigmoid(prd_masks[:, 0])

            eps = 1e-6
            # Per-pixel binary cross-entropy, computed for monitoring only.
            seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                        - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

            # IoU of the thresholded prediction against ground truth.
            pred_bin = (prd_mask > 0.5).float()
            inter = (gt_mask * pred_bin).sum(dim=(1, 2))
            denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
            iou = inter / (denom + eps)

            if step % 100 == 0:
                # torch.save fails if the target directory does not exist —
                # create it on demand.
                os.makedirs("./checkpoints-ft", exist_ok=True)
                torch.save(predictor.model.state_dict(),
                           f"./checkpoints-ft/{FINE_TUNED_MODEL_NAME}_{step}.pt")

            # Update the EMA, sanitizing any NaN/inf IoU values first.
            iou_np = iou.detach().float().cpu().numpy()
            iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
            mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

            if step % 100 == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"Step {step}: LR={current_lr:.6f} Valid_IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

    return mean_iou
| |
|
# Running exponential-moving-average IoU trackers for train() / validate().
train_mean_iou = 0
valid_mean_iou = 0
# NOTE(review): no training loop is visible in this chunk — train()/validate()
# are defined above but never called here; presumably a loop over NO_OF_STEPS
# exists elsewhere or was removed. Confirm, since the inference code below
# loads a checkpoint that only training would have produced.
|
def read_image(image_path, mask_path):
    """Load an RGB image and its grayscale mask, scaled to fit in 1024px.

    Raises FileNotFoundError with the offending path if either file cannot
    be read (cv2.imread returns None on failure rather than raising).
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    img = img[..., ::-1]  # BGR -> RGB
    # Uniform scale so the longer side is at most 1024; nearest-neighbor on
    # the mask so label values are not interpolated.
    r = min(1024 / img.shape[1], 1024 / img.shape[0])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),
                      interpolation=cv2.INTER_NEAREST)
    return img, mask
| | |
def get_points(mask, num_points):
    """Sample num_points random foreground points from a binary mask.

    Returns an int array of shape (num_points, 1, 2) in (x, y) order.
    Raises ValueError with a clear message when the mask has no foreground
    pixels (previously np.random.randint(0) raised a cryptic error).
    """
    coords = np.argwhere(mask > 0)  # (row, col) pairs
    if len(coords) == 0:
        raise ValueError("get_points: mask has no foreground pixels to sample from")
    idx = np.random.randint(len(coords), size=num_points)
    # (row, col) -> (x, y), then add the per-point prompt axis.
    return coords[idx][:, ::-1].reshape(num_points, 1, 2)
| |
|
# ---- Inference on a few random test samples with the fine-tuned weights ----
# Build the model and load the checkpoint ONCE: rebuilding the full SAM2
# model inside the per-sample loop (as before) was pure overhead.
FINE_TUNED_MODEL_WEIGHTS = "./checkpoints-ft/fine_tuned_sam2_1200.pt"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS))

for n in range(3):
    selected_entry = random.choice(test_data)
    print(selected_entry)
    image_path = selected_entry['image']
    mask_path = selected_entry['annotation']
    print(mask_path, 'mask path')

    image, target_mask = read_image(image_path, mask_path)

    # Number of point prompts sampled from the ground-truth mask.
    num_samples = 30
    input_points = get_points(target_mask, num_samples)

    with torch.no_grad():
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1])
        )

    # Keep the first mask per prompt and paint them into one segmentation
    # map greedily, highest score first.
    np_masks = np.array(masks[:, 0])
    np_scores = scores[:, 0]
    sorted_masks = np_masks[np.argsort(np_scores)][::-1]

    seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
    occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

    for i in range(sorted_masks.shape[0]):
        mask = sorted_masks[i]
        mask_area = mask.sum()
        # Skip empty masks (avoids division by zero) and masks overlapping
        # already-claimed pixels by more than 15%.
        if mask_area == 0 or (mask * occupancy_mask).sum() / mask_area > 0.15:
            continue

        mask_bool = mask.astype(bool)
        mask_bool[occupancy_mask] = False  # claim only unoccupied pixels
        seg_map[mask_bool] = i + 1
        occupancy_mask[mask_bool] = True

    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.title('Test Image')
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('Original Mask')
    plt.imshow(target_mask, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Final Segmentation')
    plt.imshow(seg_map, cmap='jet')
    plt.axis('off')

    plt.tight_layout()
    plt.show()