# segmentation/sam_predict.py
# (Uploaded via huggingface_hub, commit 7bf5a8e)
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import json
import os
import torch
# Load sample points
# Maps an image filename -> list of [x, y] prompt points (JSON produced upstream).
with open('datasets/train_seg_points.json', 'r') as f:
    points = json.load(f)

# SAM configuration: ViT-H backbone checkpoint, run on CPU.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = torch.device("cpu")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# Directory where predicted masks are written.
output_dir = 'prediction/sam_result/sam_val'
os.makedirs(output_dir, exist_ok=True)
# Verify output directory
abs_output_dir = os.path.abspath(output_dir)
print(f"[INFO] Output directory: {abs_output_dir}")
print(f"[INFO] Directory exists: {os.path.exists(abs_output_dir)}")
print(f"[INFO] Directory writable: {os.access(abs_output_dir, os.W_OK)}")
# For every image listed in the points JSON: load the image, prompt SAM with
# sampled hair points (positive) and border points (negative), post-process the
# best mask into a binary map, denoise it, and save it to output_dir.
for full_name in tqdm(points.keys()):
    name, ext = os.path.splitext(full_name)
    # Points may be keyed under any of several extensions; fall back through them.
    sample_points = (points.get(full_name)
                     or points.get(f'{name}.png')
                     or points.get(f'{name}.jpg')
                     or points.get(f'{name}.jpeg')
                     or [])

    # Find the input image on disk, trying common extensions.
    possible_paths = [
        os.path.join('datasets', 'data', f'{name}.jpeg'),
        os.path.join('datasets', 'data', f'{name}.jpg'),
        os.path.join('datasets', 'data', f'{name}.png'),
    ]
    image = None
    for p in possible_paths:
        if os.path.isfile(p):
            image = cv2.imread(p)
            print(f"Using image: {p}")
            break
    if image is None or image.size == 0:
        print(f"[WARN] No valid image found for {name}, skipping.")
        continue

    # SAM expects RGB; OpenCV loads BGR.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(np.ascontiguousarray(image))

    # If no points, save original image
    if len(sample_points) == 0:
        print(f"[INFO] No points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"),
                    cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Filter valid points: keep only strictly-positive coordinates.
    tmp = np.array(sample_points)
    tmp = tmp[tmp.min(axis=1) > 0]
    if len(tmp) == 0:
        # FIX: np.random.choice raises on an empty pool; skip instead of crashing.
        print(f"[WARN] No valid points for {name} after filtering, skipping.")
        continue

    # Use ~50% of hair points randomly (at least one).
    # FIX: len(tmp)//2 was 0 when only one valid point existed, leaving SAM
    # with no positive prompt at all.
    n_pos = max(1, len(tmp) // 2)
    rand_idx = np.random.choice(len(tmp), n_pos, replace=False)
    input_point = tmp[rand_idx]  # 50% hair points

    # Smart negative points sampled from the image border regions
    # (presumably scalp/background — border pixels are assumed non-hair).
    img_height, img_width = image.shape[:2]
    # FIX: clamp the border so small images never yield an invalid randint range
    # (e.g. randint(img_height-50, img_height) with img_height < 50).
    border_width = min(50, max(1, img_height // 2), max(1, img_width // 2))
    neg_list = []
    attempts = 0
    # FIX: cap attempts so repeated collisions with `tmp` cannot loop forever.
    while len(neg_list) < 10 and attempts < 1000:
        attempts += 1
        side = np.random.choice(['top', 'bottom', 'left', 'right'])
        if side == 'top':
            xy = [np.random.randint(img_width), np.random.randint(0, border_width)]
        elif side == 'bottom':
            xy = [np.random.randint(img_width),
                  np.random.randint(img_height - border_width, img_height)]
        elif side == 'left':
            xy = [np.random.randint(0, border_width), np.random.randint(img_height)]
        else:  # right
            xy = [np.random.randint(img_width - border_width, img_width),
                  np.random.randint(img_height)]
        if xy not in tmp.tolist():
            neg_list.append(xy)
    # reshape(-1, 2) keeps a valid (0, 2) shape even if neg_list ended up empty.
    neg_arr = np.array(neg_list).reshape(-1, 2)  # scalp points (from borders)

    # Combine positive and negative prompts into one (N, 2) array.
    final_point = np.concatenate([input_point.reshape(-1, 2), neg_arr], axis=0)
    # LOGIC CODE 2: Label assignment (0=hair, 1=scalp) - KEPT AS ORIGINAL
    input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))
    print(f"[INFO] Using {len(input_point)} hair points (label=0) and {len(neg_arr)} scalp points (label=1)")
    print(f"[INFO] Using ORIGINAL label logic (0=hair, 1=scalp)")
    print(f"[INFO] Negative points: smart border sampling")

    # Predict candidate masks and keep the highest-scoring one.
    masks, scores, logits = predictor.predict(
        point_coords=final_point,
        point_labels=input_label,
        multimask_output=True,
    )
    sam_mask = masks[np.argmax(scores)]
    # Ensure 2D
    if sam_mask.ndim > 2:
        sam_mask = sam_mask.squeeze()
    # Resize if needed.
    if sam_mask.shape != (img_height, img_width):
        # FIX: nearest-neighbour interpolation keeps the mask strictly binary;
        # the default bilinear interpolation smears mask edges.
        sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height),
                              interpolation=cv2.INTER_NEAREST)

    # LOGIC CODE 2: Binary map INVERTED (hair=0/black, bg=255/white) - KEPT AS ORIGINAL
    binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
    print(f"[INFO] Using ORIGINAL binary map logic (hair=0/black, bg=255/white)")

    # Get rid of noises (small white spots that are not hair) via
    # connected-component filtering.
    nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_map, None, None, None, 8, cv2.CV_32S
    )
    areas = stats[1:, cv2.CC_STAT_AREA]  # skip label 0 (background component)
    result = np.zeros(labels.shape, np.uint8)
    kept_components = 0
    for i in range(nlabels - 1):
        if areas[i] >= 400:  # Keep components >= 400 pixels
            result[labels == i + 1] = 255
            kept_components += 1
    print(f"[INFO] Found {nlabels-1} components, kept {kept_components} (>= 400 pixels)")

    # Save result
    save_path = os.path.join(output_dir, f"{name}.jpg")
    success = cv2.imwrite(save_path, result)
    if success:
        file_size = os.path.getsize(save_path)
        print(f"[INFO] Saved mask for {name} at {save_path} ({file_size} bytes)")
    else:
        print(f"[ERROR] cv2.imwrite failed for {name}")
    print("="*80)