|
|
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from tqdm import tqdm |
|
|
import json |
|
|
import os |
|
|
import torch |
|
|
|
|
|
|
|
|
# Load the per-image segmentation point prompts, keyed by image filename.
with open('datasets/train_seg_points.json', 'r') as fp:
    points = json.load(fp)
|
|
|
|
|
# ---- SAM model setup (CPU inference) ----
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = torch.device("cpu")

# Build the ViT-H SAM model from its checkpoint, move it to the target
# device, and wrap it in a point-prompt predictor used by the main loop.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
|
|
|
|
|
# Ensure the mask output directory exists, then log where it resolved to
# and whether it is actually usable (exists + writable) before the long run.
output_dir = 'prediction/sam_result/sam_val'
os.makedirs(output_dir, exist_ok=True)

abs_out = os.path.abspath(output_dir)
print(f"[INFO] Output directory: {abs_out}")
print(f"[INFO] Directory exists: {os.path.exists(abs_out)}")
print(f"[INFO] Directory writable: {os.access(abs_out, os.W_OK)}")
|
|
|
|
|
# Main loop: for every image listed in the point-prompt JSON, prompt SAM with
# a random half of the annotated (hair) points plus 10 sampled border (scalp)
# points, keep the best-scoring mask, clean it with connected-component
# filtering, and write the result as a binary JPEG mask.
for full_name in tqdm(points.keys()):
    name = os.path.splitext(full_name)[0]
    # The annotations may be keyed under any common extension; try them all.
    sample_points = points.get(full_name) or points.get(f'{name}.png') or points.get(f'{name}.jpg') or points.get(f'{name}.jpeg') or []

    # The source image may likewise exist under any extension; take the first hit.
    possible_paths = [
        os.path.join('datasets', 'data', f'{name}.jpeg'),
        os.path.join('datasets', 'data', f'{name}.jpg'),
        os.path.join('datasets', 'data', f'{name}.png'),
    ]
    image = None
    for p in possible_paths:
        if os.path.isfile(p):
            image = cv2.imread(p)
            print(f"Using image: {p}")
            break
    if image is None or image.size == 0:
        print(f"[WARN] No valid image found for {name}, skipping.")
        continue

    # OpenCV loads BGR; SAM's predictor expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(np.ascontiguousarray(image))

    if len(sample_points) == 0:
        print(f"[INFO] No points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Drop annotated points with any non-positive coordinate (invalid entries).
    tmp = np.array(sample_points)
    tmp = tmp[tmp.min(axis=1) > 0]

    # Robustness fix: if every annotated point was invalid there is nothing
    # to prompt SAM with — fall back to saving the original image, mirroring
    # the "no points" path above instead of predicting with zero positives.
    if len(tmp) == 0:
        print(f"[WARN] No valid points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Subsample half of the valid points (without replacement) as prompts.
    rand_idx = np.random.choice(len(tmp), len(tmp) // 2, replace=False)
    input_point = tmp[rand_idx]

    img_height, img_width = image.shape[:2]
    border_width = 50  # negative points are drawn from this pixel margin

    # Sample 10 "scalp" points along the four image borders, rejecting any
    # location that coincides with an annotated point. The membership test
    # is hoisted to a set built once (the original rebuilt tmp.tolist() and
    # did an O(n) list scan on every draw).
    annotated = {tuple(pt) for pt in tmp.tolist()}
    neg_list = []
    while len(neg_list) < 10:
        side = np.random.choice(['top', 'bottom', 'left', 'right'])
        if side == 'top':
            xy = [np.random.randint(img_width), np.random.randint(0, border_width)]
        elif side == 'bottom':
            xy = [np.random.randint(img_width), np.random.randint(img_height-border_width, img_height)]
        elif side == 'left':
            xy = [np.random.randint(0, border_width), np.random.randint(img_height)]
        else:
            xy = [np.random.randint(img_width-border_width, img_width), np.random.randint(img_height)]

        if tuple(xy) not in annotated:
            neg_list.append(xy)

    neg_arr = np.array(neg_list)

    # Stack prompts: annotated (hair) points first, border (scalp) points after.
    final_point = np.concatenate([input_point.reshape(-1, 2), neg_arr.reshape(-1, 2)], axis=0)

    # NOTE: labels are deliberately inverted w.r.t. SAM's usual convention
    # (SAM treats 1 as foreground). Here 0 = hair points, 1 = scalp/border
    # points — this matches the script's own "[INFO] ORIGINAL label logic"
    # messages and is kept intentionally.
    input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))

    print(f"[INFO] Using {len(input_point)} hair points (label=0) and {len(neg_arr)} scalp points (label=1)")
    print(f"[INFO] Using ORIGINAL label logic (0=hair, 1=scalp)")
    print(f"[INFO] Negative points: smart border sampling")

    masks, scores, logits = predictor.predict(
        point_coords=final_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Keep only the highest-scoring of the candidate masks.
    sam_mask = masks[np.argmax(scores)]

    if sam_mask.ndim > 2:
        sam_mask = sam_mask.squeeze()

    # Fix: nearest-neighbour interpolation keeps the mask strictly binary;
    # the previous default (bilinear) could introduce intermediate values
    # along mask edges before the threshold below.
    if sam_mask.shape != (img_height, img_width):
        sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height),
                              interpolation=cv2.INTER_NEAREST)

    # Binary map: masked (hair) pixels -> 0/black, background -> 255/white.
    binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
    print(f"[INFO] Using ORIGINAL binary map logic (hair=0/black, bg=255/white)")

    # Connected-component cleanup: keep only blobs of >= 400 pixels.
    nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_map, None, None, None, 8, cv2.CV_32S
    )

    areas = stats[1:, cv2.CC_STAT_AREA]  # skip label 0 (the background)
    result = np.zeros((labels.shape), np.uint8)
    kept_components = 0

    for i in range(0, nlabels - 1):
        if areas[i] >= 400:
            result[labels == i + 1] = 255
            kept_components += 1

    print(f"[INFO] Found {nlabels-1} components, kept {kept_components} (>= 400 pixels)")

    save_path = os.path.join(output_dir, f"{name}.jpg")
    success = cv2.imwrite(save_path, result)

    if success:
        file_size = os.path.getsize(save_path)
        print(f"[INFO] Saved mask for {name} at {save_path} ({file_size} bytes)")
    else:
        print(f"[ERROR] cv2.imwrite failed for {name}")

    # Per-image log separator (placement inferred — original indentation was
    # mangled; TODO confirm this was not intended as a one-off final line).
    print("="*80)