AdamHines committed on
Commit
57cfe02
·
verified ·
1 Parent(s): 5efac04

Upload finetune-sam2.py

Browse files
Files changed (1) hide show
  1. finetune-sam2.py +426 -0
finetune-sam2.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ import cv2
5
+ import torch
6
+ import torch.nn.utils
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.colors as mcolors
11
+ from sklearn.model_selection import train_test_split
12
+
13
+ from sam2.build_sam import build_sam2
14
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
15
+
16
def set_seeds(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus every CUDA device when CUDA is available), then configures
    cuDNN for repeatable runs.

    Args:
        seed: Integer applied to all generators. Defaults to 42, the value
            that used to be hard-coded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # These are plain flags and harmless without a GPU, so set them always.
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select a different convolution algorithm on
    # every run; it has to be off for the deterministic flag to mean anything.
    torch.backends.cudnn.benchmark = False
26
+
27
set_seeds()

# Dataset layout: <data_dir>/images, <data_dir>/masks and a train.csv index
# whose rows name an image file ("imageid") and its mask file ("maskid").
data_dir = "./sam2-data"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")

train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

# Hold out 10% of the rows; the fixed random_state keeps the split repeatable.
train_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42)


def _build_entries(df):
    """Return one ``{"image", "annotation"}`` path dict per row of *df*.

    The ``imageid`` and ``maskid`` columns are joined onto ``images_dir`` and
    ``masks_dir`` respectively; row order is preserved.
    """
    return [
        {
            "image": os.path.join(images_dir, image_name),
            "annotation": os.path.join(masks_dir, mask_name),
        }
        for image_name, mask_name in zip(df["imageid"], df["maskid"])
    ]


train_data = _build_entries(train_df)
test_data = _build_entries(test_df)
55
+
56
def read_batch(data, visualize_data=True):
    """Load one random sample and build its binary mask and point prompts.

    A random entry of *data* is read from disk, image and annotation map are
    scaled by the same factor so the longer image side becomes 1024 px, all
    labels except the smallest one are merged into a single binary mask, and
    one prompt point per label is drawn from the eroded mask.

    Args:
        data: Sequence of dicts with ``"image"`` and ``"annotation"`` file
            paths.
        visualize_data: When true, show the image, the binary mask and the
            sampled points in a matplotlib figure.

    Returns:
        Tuple ``(image, binary_mask, points, num_masks)``:
        ``image`` is the resized, channel-reversed (BGR -> RGB) array;
        ``binary_mask`` is ``uint8`` with a leading axis, shape ``(1, H, W)``;
        ``points`` holds ``[x, y]`` pairs with shape ``(num_masks, 1, 2)``, or
        is an empty array when the eroded mask has no pixel left;
        ``num_masks`` is the number of labels found.
        If either file cannot be read, ``(None, None, None, 0)`` is returned.
    """
    ent = data[np.random.randint(len(data))]
    raw_img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)

    # cv2.imread signals failure by returning None, so this has to be checked
    # before the result is indexed.
    if raw_img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    Img = raw_img[..., ::-1]  # BGR -> RGB

    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    # Nearest-neighbour keeps label values intact while resizing.
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    binary_mask = np.zeros_like(ann_map, dtype=np.uint8)
    points = []
    # NOTE(review): dropping the first (smallest) unique value presumably
    # removes background label 0 -- confirm every mask contains background.
    inds = np.unique(ann_map)[1:]
    for ind in inds:
        mask = (ann_map == ind).astype(np.uint8)
        binary_mask = np.maximum(binary_mask, mask)

    # Erode so sampled points do not sit on the mask border.
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)
    if len(coords) > 0:
        for _ in inds:
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([yx[1], yx[0]])  # argwhere yields (row, col); store (x, y)
    points = np.array(points)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    binary_mask = np.expand_dims(binary_mask, axis=-1)
    binary_mask = binary_mask.transpose((2, 0, 1))
    points = np.expand_dims(points, axis=1)
    return Img, binary_mask, points, len(inds)
112
+
113
+ # Img1, masks1, points1, num_masks = read_batch(train_data, visualize_data=True)
114
+ def _to_hydra_name(x):
115
+ if not x:
116
+ return None
117
+ s = str(x).replace("\\", "/")
118
+ if s.endswith(".yaml"):
119
+ s = s[:-5]
120
+ # Normalize absolute/relative repo paths to hydra names:
121
+ # /.../sam2/sam2/configs/sam2.1/sam2.1_hiera_s -> configs/sam2.1/sam2.1_hiera_s
122
+ # ./sam2/configs/sam2.1/sam2.1_hiera_s -> configs/sam2.1/sam2.1_hiera_s
123
+ if "/sam2/configs/" in s:
124
+ return s.split("/sam2/")[1] # keep from 'configs/...'
125
+ if s.startswith("sam2/configs/"):
126
+ return s[len("sam2/"):] # strip leading 'sam2/'
127
+ return s
128
+
129
# Run-wide training constants.
NO_OF_STEPS = 1200
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"
accumulation_steps = 8

# Base checkpoint and its config, the latter converted to a Hydra name.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = _to_hydra_name("./sam2/configs/sam2.1/sam2.1_hiera_l.yaml")

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Switch the mask decoder and the prompt encoder to training mode.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# Loss scaler for mixed-precision training.
scaler = torch.amp.GradScaler()

optimizer = torch.optim.AdamW(
    params=predictor.model.parameters(),
    lr=5e-5,
    weight_decay=1e-4,
)

# Multiply the learning rate by 0.6 after every 1000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.6)
149
+
150
def train(predictor, train_data, step, mean_iou):
    """Run one training micro-step on a randomly drawn sample.

    Draws a sample with ``read_batch``, runs the prompt encoder and mask
    decoder on it, and back-propagates a binary cross-entropy segmentation
    loss plus 0.05 times the absolute error between the predicted score and
    the measured IoU. Gradients accumulate over ``accumulation_steps`` calls;
    on every ``accumulation_steps``-th *step* they are unscaled, clipped to a
    norm of 1.0 and applied, and the LR scheduler advances.

    Uses the module-level ``scaler``, ``optimizer``, ``scheduler`` and
    ``accumulation_steps``.

    Args:
        predictor: Predictor wrapping the model being fine-tuned.
        train_data: Sequence of sample dicts accepted by ``read_batch``.
        step: 1-based step counter; drives gradient accumulation and the
            progress print every 100 steps.
        mean_iou: Rolling mean IoU so far; ``None`` or NaN is treated as 0.0.

    Returns:
        The updated rolling mean IoU (``0.99 * old + 0.01 * batch mean``), or
        the incoming value when the sample could not be used.
    """
    # Ensure rolling mean is numeric (NaN is the only value unequal to itself)
    if mean_iou is None or (isinstance(mean_iou, float) and mean_iou != mean_iou):
        mean_iou = 0.0

    eps = 1e-6

    predictor.model.train()
    with torch.amp.autocast(device_type='cuda'):
        image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)

        # If this batch is unusable, keep the rolling mean unchanged
        if image is None or mask is None or num_masks == 0:
            return mean_iou

        # Every sampled point is a positive (foreground) prompt.
        input_label = np.ones((num_masks, 1), dtype=np.int64)

        if not isinstance(input_point, np.ndarray) or input_point.size == 0:
            return mean_iou

        predictor.set_image(image)
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )
        if (
            unnorm_coords is None or labels is None or
            unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
        ):
            return mean_iou

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        # More than one prompt: the single image embedding is repeated.
        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        prd_masks = predictor._transforms.postprocess_masks(
            low_res_masks, predictor._orig_hw[-1]
        )

        gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
        prd_mask = torch.sigmoid(prd_masks[:, 0])

        # BCE-style seg loss (eps keeps log() away from zero)
        seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                    - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

        # IoU with safeties
        pred_bin = (prd_mask > 0.5).float()
        inter = (gt_mask * pred_bin).sum(dim=(1, 2))
        denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
        iou = inter / (denom + eps)

        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + 0.05 * score_loss

    # Backward runs outside autocast; dividing by accumulation_steps makes the
    # accumulated gradient an average over the micro-steps.
    scaler.scale(loss / accumulation_steps).backward()

    if step % accumulation_steps == 0:
        # Gradients still carry the GradScaler factor here, so they must be
        # unscaled before clipping or max_norm is compared against inflated
        # values. Clip once per optimizer step, on the fully accumulated
        # gradient, not on every micro-step.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        # Step the LR scheduler only when the optimizer actually stepped
        scheduler.step()

    # Update rolling mean IoU (robust to NaN/inf)
    iou_np = iou.detach().float().cpu().numpy()
    iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
    mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

    return mean_iou
247
+
248
def validate(predictor, test_data, step, mean_iou):
    """Evaluate the model on one random held-out sample.

    On every 100th *step* the model's ``state_dict`` is first written to
    ``./checkpoints-ft/<FINE_TUNED_MODEL_NAME>_<step>.pt``. A sample is then
    drawn with ``read_batch`` and pushed through the prompt encoder and mask
    decoder in eval mode without gradients, and the rolling mean IoU is
    updated.

    Uses the module-level ``FINE_TUNED_MODEL_NAME`` and ``optimizer`` (the
    latter only to report the current learning rate).

    Args:
        predictor: Predictor wrapping the model being fine-tuned.
        test_data: Sequence of sample dicts accepted by ``read_batch``.
        step: 1-based step counter; checkpoints and progress lines are
            produced when it is a multiple of 100.
        mean_iou: Rolling mean IoU so far; ``None`` or NaN is treated as 0.0.

    Returns:
        The updated rolling mean IoU (``0.99 * old + 0.01 * batch mean``), or
        the incoming value when the sample could not be used.
    """
    # Always have a numeric baseline (NaN is the only value unequal to itself)
    if mean_iou is None or (isinstance(mean_iou, float) and mean_iou != mean_iou):
        mean_iou = 0.0

    # Checkpoint before drawing a sample: the weights do not depend on the
    # validation result, and an unreadable sample must not silently skip it.
    if step % 100 == 0:
        os.makedirs("./checkpoints-ft", exist_ok=True)
        torch.save(predictor.model.state_dict(), f"./checkpoints-ft/{FINE_TUNED_MODEL_NAME}_{step}.pt")

    predictor.model.eval()
    with torch.amp.autocast(device_type='cuda'), torch.no_grad():
        image, mask, input_point, num_masks = read_batch(test_data, visualize_data=False)

        # If this batch is unusable, keep the rolling mean unchanged
        if image is None or mask is None or num_masks == 0:
            return mean_iou

        # Every sampled point is a positive (foreground) prompt.
        input_label = np.ones((num_masks, 1), dtype=np.int64)

        if not isinstance(input_point, np.ndarray) or input_point.size == 0:
            return mean_iou

        predictor.set_image(image)
        _, unnorm_coords, labels, _ = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )

        if (
            unnorm_coords is None or labels is None or
            unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
        ):
            return mean_iou

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        # More than one prompt: the single image embedding is repeated.
        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
        low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        prd_masks = predictor._transforms.postprocess_masks(
            low_res_masks, predictor._orig_hw[-1]
        )

        gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
        prd_mask = torch.sigmoid(prd_masks[:, 0])

        # BCE-style seg loss, reported only (nothing is back-propagated here)
        eps = 1e-6
        seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                    - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

        # IoU with numerical safety
        pred_bin = (prd_mask > 0.5).float()
        inter = (gt_mask * pred_bin).sum(dim=(1, 2))
        denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
        iou = inter / (denom + eps)  # avoid 0/0

    iou_np = iou.detach().float().cpu().numpy()
    iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
    mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} Valid_IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

    return mean_iou
331
+
332
# Rolling IoU averages for training and validation; both start at zero.
train_mean_iou = valid_mean_iou = 0
334
+
335
+ # for step in range(1, NO_OF_STEPS + 1):
336
+ # train_mean_iou = train(predictor, train_data, step, train_mean_iou)
337
+ # valid_mean_iou = validate(predictor, test_data, step, valid_mean_iou)
338
+
339
def read_image(image_path, mask_path):  # read and resize image and mask
    """Read an image and its mask and scale both to at most 1024 px per side.

    The image's channel order is reversed (BGR -> RGB) and the mask is read
    as a single-channel image. Both are resized by the factor that brings the
    image's longer side to 1024 px; the mask uses nearest-neighbour
    interpolation so its label values are not blended.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the mask file.

    Returns:
        Tuple ``(img, mask)`` of resized arrays.

    Raises:
        FileNotFoundError: If either file cannot be read by OpenCV.
    """
    img = cv2.imread(image_path)
    mask = cv2.imread(mask_path, 0)
    # cv2.imread returns None instead of raising; fail with the offending
    # path rather than with a TypeError from indexing None below.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img = img[..., ::-1]  # Convert BGR to RGB
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask
346
+
347
def get_points(mask, num_points):  # Sample points inside the input mask
    """Draw random ``[x, y]`` points from the positive pixels of *mask*.

    Points are drawn independently with NumPy's global generator, so the same
    pixel may be returned more than once.

    Args:
        mask: Array whose entries greater than zero mark eligible pixels.
        num_points: Number of points to draw.

    Returns:
        Array of shape ``(num_points, 1, 2)`` with ``[x, y]`` (column, row)
        pairs; an empty array when *num_points* is 0.

    Raises:
        ValueError: If points are requested but *mask* has no positive pixel.
    """
    coords = np.argwhere(mask > 0)
    if num_points > 0 and len(coords) == 0:
        raise ValueError("mask has no positive pixels to sample points from")

    points = []
    for _ in range(num_points):
        yx = coords[np.random.randint(len(coords))]
        points.append([[yx[1], yx[0]]])  # argwhere yields (row, col); store (x, y)
    return np.array(points)
354
+
355
# Load the fine-tuned model once: building the network and reading the
# weights does not depend on the sample, so it does not belong in the loop.
FINE_TUNED_MODEL_WEIGHTS = "./checkpoints-ft/fine_tuned_sam2_1200.pt"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Build net and load weights
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS))

# Segment three randomly chosen test samples and show each result.
for n in range(3):
    selected_entry = random.choice(test_data)
    print(selected_entry)
    image_path = selected_entry['image']
    mask_path = selected_entry['annotation']
    print(mask_path,'mask path')

    # Load the selected image and mask
    image, target_mask = read_image(image_path, mask_path)

    # Generate random points for the input
    num_samples = 30  # Number of points per segment to sample
    input_points = get_points(target_mask, num_samples)

    # Perform inference and predict masks
    with torch.no_grad():
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1])
        )

    # Process the predicted masks and sort by scores, best first
    np_masks = np.array(masks[:, 0])
    np_scores = scores[:, 0]
    sorted_masks = np_masks[np.argsort(np_scores)][::-1]

    # Initialize segmentation map and occupancy mask
    seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
    occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

    # Combine masks to create the final segmentation map
    for label, mask in enumerate(sorted_masks, start=1):
        area = mask.sum()
        # An empty mask adds nothing; skipping it also avoids a 0/0 below.
        if area == 0:
            continue
        # Skip masks that overlap already-claimed pixels by more than 15%.
        if (mask * occupancy_mask).sum() / area > 0.15:
            continue

        mask_bool = mask.astype(bool)
        mask_bool[occupancy_mask] = False  # Set overlapping areas to False in the mask
        seg_map[mask_bool] = label  # Use boolean mask to index seg_map
        occupancy_mask[mask_bool] = True  # Update occupancy_mask

    # Visualization: Show the original image, mask, and final segmentation side by side
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.title('Test Image')
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('Original Mask')
    plt.imshow(target_mask, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Final Segmentation')
    plt.imshow(seg_map, cmap='jet')
    plt.axis('off')

    plt.tight_layout()
    plt.show()