File size: 5,340 Bytes
7bf5a8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import json
import os
import torch
# Load sample points
# Maps image filename (with extension) -> list of [x, y] prompt points.
# assumes coordinates are (x, y) pixel positions — TODO confirm against the
# producer of train_seg_points.json.
with open('datasets/train_seg_points.json', 'r') as f:
    points = json.load(f)
# SAM configuration: ViT-H backbone weights, forced to CPU inference.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = torch.device("cpu")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Point-prompt predictor (as opposed to the automatic mask generator also
# exported by segment_anything).
predictor = SamPredictor(sam)
# Destination for the predicted binary masks; created if missing.
output_dir = 'prediction/sam_result/sam_val'
os.makedirs(output_dir, exist_ok=True)
# Verify output directory
# Early diagnostics so a path/permission problem surfaces before the
# (slow) per-image loop starts.
abs_output_dir = os.path.abspath(output_dir)
print(f"[INFO] Output directory: {abs_output_dir}")
print(f"[INFO] Directory exists: {os.path.exists(abs_output_dir)}")
print(f"[INFO] Directory writable: {os.access(abs_output_dir, os.W_OK)}")
for full_name in tqdm(points.keys()):
name, ext = os.path.splitext(full_name)
sample_points = points.get(full_name) or points.get(f'{name}.png') or points.get(f'{name}.jpg') or points.get(f'{name}.jpeg') or []
# Find input image
possible_paths = [
os.path.join('datasets', 'data', f'{name}.jpeg'),
os.path.join('datasets', 'data', f'{name}.jpg'),
os.path.join('datasets', 'data', f'{name}.png'),
]
image = None
for p in possible_paths:
if os.path.isfile(p):
image = cv2.imread(p)
print(f"Using image: {p}")
break
if image is None or image.size == 0:
print(f"[WARN] No valid image found for {name}, skipping.")
continue
# Prepare image for SAM
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(np.ascontiguousarray(image))
# If no points, save original image
if len(sample_points) == 0:
print(f"[INFO] No points for {name}, saving original image.")
cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
continue
# Filter valid points
tmp = np.array(sample_points)
tmp = tmp[tmp.min(axis=1) > 0]
# Use 50% of hair points randomly
rand_idx = np.random.choice(len(tmp), len(tmp)//2, replace=False)
input_point = tmp[rand_idx] # 50% hair points
# ✅ UPDATED: Smart negative points from border regions
img_height, img_width = image.shape[:2]
neg_list = []
border_width = 50
while len(neg_list) < 10:
side = np.random.choice(['top', 'bottom', 'left', 'right'])
if side == 'top':
xy = [np.random.randint(img_width), np.random.randint(0, border_width)]
elif side == 'bottom':
xy = [np.random.randint(img_width), np.random.randint(img_height-border_width, img_height)]
elif side == 'left':
xy = [np.random.randint(0, border_width), np.random.randint(img_height)]
else: # right
xy = [np.random.randint(img_width-border_width, img_width), np.random.randint(img_height)]
if xy not in tmp.tolist():
neg_list.append(xy)
neg_arr = np.array(neg_list) # scalp points (from borders)
# Combine points
final_point = np.append(input_point, neg_arr).reshape(-1, 2)
# LOGIC CODE 2: Label assignment (0=hair, 1=scalp) - KEPT AS ORIGINAL
input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))
print(f"[INFO] Using {len(input_point)} hair points (label=0) and {len(neg_arr)} scalp points (label=1)")
print(f"[INFO] Using ORIGINAL label logic (0=hair, 1=scalp)")
print(f"[INFO] Negative points: smart border sampling")
# Predict mask
masks, scores, logits = predictor.predict(
point_coords=final_point,
point_labels=input_label,
multimask_output=True,
)
# Get best mask
sam_mask = masks[np.argmax(scores)]
# Ensure 2D
if sam_mask.ndim > 2:
sam_mask = sam_mask.squeeze()
# Resize if needed
if sam_mask.shape != (img_height, img_width):
sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height))
# LOGIC CODE 2: Binary map INVERTED (hair=0/black, bg=255/white) - KEPT AS ORIGINAL
binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
print(f"[INFO] Using ORIGINAL binary map logic (hair=0/black, bg=255/white)")
# Get rid of noises (small white spots that are not hair)
nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
binary_map, None, None, None, 8, cv2.CV_32S
)
# Get CC_STAT_AREA component
areas = stats[1:, cv2.CC_STAT_AREA]
result = np.zeros((labels.shape), np.uint8)
kept_components = 0
for i in range(0, nlabels - 1):
if areas[i] >= 400: # Keep components >= 400 pixels
result[labels == i + 1] = 255
kept_components += 1
print(f"[INFO] Found {nlabels-1} components, kept {kept_components} (>= 400 pixels)")
# Save result
save_path = os.path.join(output_dir, f"{name}.jpg")
success = cv2.imwrite(save_path, result)
if success:
file_size = os.path.getsize(save_path)
print(f"[INFO] Saved mask for {name} at {save_path} ({file_size} bytes)")
else:
print(f"[ERROR] cv2.imwrite failed for {name}")
print("="*80) |