"""Generate hair-segmentation masks with SAM point prompts.

For every image listed in datasets/train_seg_points.json:
  * load the image from datasets/data/ (.jpeg/.jpg/.png tried in that order),
  * prompt SAM with a random 50% subset of the annotated hair points plus up
    to 10 negative points sampled from a 50-px border band,
  * keep the best-scoring mask, invert it (hair=0/black, background=255/white),
  * drop connected components smaller than 400 px and save the result as JPG
    into prediction/sam_result/sam_val/.
"""
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import json
import os
import torch

# Load sample points (maps image filename -> list of [x, y] hair points)
with open('datasets/train_seg_points.json', 'r') as f:
    points = json.load(f)

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = torch.device("cpu")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

output_dir = 'prediction/sam_result/sam_val'
os.makedirs(output_dir, exist_ok=True)

# Verify output directory
abs_output_dir = os.path.abspath(output_dir)
print(f"[INFO] Output directory: {abs_output_dir}")
print(f"[INFO] Directory exists: {os.path.exists(abs_output_dir)}")
print(f"[INFO] Directory writable: {os.access(abs_output_dir, os.W_OK)}")

for full_name in tqdm(points):
    name, ext = os.path.splitext(full_name)
    # The JSON key may carry any of several extensions; try them all.
    sample_points = (points.get(full_name)
                     or points.get(f'{name}.png')
                     or points.get(f'{name}.jpg')
                     or points.get(f'{name}.jpeg')
                     or [])

    # Find input image
    possible_paths = [
        os.path.join('datasets', 'data', f'{name}.jpeg'),
        os.path.join('datasets', 'data', f'{name}.jpg'),
        os.path.join('datasets', 'data', f'{name}.png'),
    ]
    image = None
    for p in possible_paths:
        if os.path.isfile(p):
            image = cv2.imread(p)
            print(f"Using image: {p}")
            break
    if image is None or image.size == 0:
        print(f"[WARN] No valid image found for {name}, skipping.")
        continue

    # Prepare image for SAM (the predictor expects RGB; cv2 loads BGR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(np.ascontiguousarray(image))

    # If no points, save original image
    if len(sample_points) == 0:
        print(f"[INFO] No points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"),
                    cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Filter valid points: keep only rows whose x AND y are both > 0
    tmp = np.array(sample_points)
    tmp = tmp[tmp.min(axis=1) > 0]

    # FIX: with fewer than 2 surviving points, the 50% subsample below is
    # empty and SAM would be prompted with border negatives only, yielding a
    # meaningless mask (np.random.choice(0, ...) also raises on older numpy).
    # Fall back to saving the original image, as in the no-points case.
    if len(tmp) < 2:
        print(f"[WARN] Not enough valid points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"),
                    cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Use 50% of hair points randomly
    rand_idx = np.random.choice(len(tmp), len(tmp) // 2, replace=False)
    input_point = tmp[rand_idx]  # 50% hair points

    # ✅ UPDATED: Smart negative points from border regions
    img_height, img_width = image.shape[:2]
    border_width = 50
    # FIX: build the hair-point set once for O(1) membership tests; the
    # original re-built tmp.tolist() and scanned it on every loop iteration.
    hair_points = {tuple(pt) for pt in tmp.tolist()}
    neg_list = []
    attempts = 0
    # FIX: cap attempts so a pathological image (border band fully covered by
    # hair points) cannot make this loop spin forever.
    while len(neg_list) < 10 and attempts < 1000:
        attempts += 1
        side = np.random.choice(['top', 'bottom', 'left', 'right'])
        if side == 'top':
            xy = [np.random.randint(img_width),
                  np.random.randint(0, border_width)]
        elif side == 'bottom':
            xy = [np.random.randint(img_width),
                  np.random.randint(img_height - border_width, img_height)]
        elif side == 'left':
            xy = [np.random.randint(0, border_width),
                  np.random.randint(img_height)]
        else:  # right
            xy = [np.random.randint(img_width - border_width, img_width),
                  np.random.randint(img_height)]
        if tuple(xy) not in hair_points:
            neg_list.append(xy)
    # Explicit (-1, 2) reshape keeps the array 2-D even if neg_list is empty.
    neg_arr = np.array(neg_list, dtype=int).reshape(-1, 2)  # scalp points (from borders)

    # Combine points. FIX: vstack keeps the (N, 2) shape explicit instead of
    # flattening with np.append and reshaping back.
    final_point = np.vstack([input_point, neg_arr])

    # LOGIC CODE 2: Label assignment (0=hair, 1=scalp) - KEPT AS ORIGINAL
    # NOTE(review): SAM's own convention is 1=foreground, 0=background, so
    # this deliberately asks SAM for the NON-hair region; the binary map is
    # inverted again below to recover hair=black. Preserved intentionally.
    input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))
    print(f"[INFO] Using {len(input_point)} hair points (label=0) and {len(neg_arr)} scalp points (label=1)")
    print(f"[INFO] Using ORIGINAL label logic (0=hair, 1=scalp)")
    print(f"[INFO] Negative points: smart border sampling")

    # Predict mask
    masks, scores, logits = predictor.predict(
        point_coords=final_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Get best mask
    sam_mask = masks[np.argmax(scores)]

    # Ensure 2D
    if sam_mask.ndim > 2:
        sam_mask = sam_mask.squeeze()

    # Resize if needed (cv2.resize takes (width, height))
    if sam_mask.shape != (img_height, img_width):
        sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height))

    # LOGIC CODE 2: Binary map INVERTED (hair=0/black, bg=255/white) - KEPT AS ORIGINAL
    binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
    print(f"[INFO] Using ORIGINAL binary map logic (hair=0/black, bg=255/white)")

    # Get rid of noises (small white spots that are not hair)
    nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_map, None, None, None, 8, cv2.CV_32S
    )

    # Get CC_STAT_AREA component (skip label 0, the background)
    areas = stats[1:, cv2.CC_STAT_AREA]
    result = np.zeros(labels.shape, np.uint8)
    kept_components = 0
    for i in range(nlabels - 1):
        if areas[i] >= 400:  # Keep components >= 400 pixels
            result[labels == i + 1] = 255
            kept_components += 1
    print(f"[INFO] Found {nlabels-1} components, kept {kept_components} (>= 400 pixels)")

    # Save result
    save_path = os.path.join(output_dir, f"{name}.jpg")
    success = cv2.imwrite(save_path, result)
    if success:
        file_size = os.path.getsize(save_path)
        print(f"[INFO] Saved mask for {name} at {save_path} ({file_size} bytes)")
    else:
        print(f"[ERROR] cv2.imwrite failed for {name}")
    print("="*80)