File size: 5,340 Bytes
7bf5a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import json
import os
import torch

# Load sample points.
# NOTE(review): assumed schema is {image_filename: [[x, y], ...]} — the loop
# below treats keys as filenames with extension and values as 2-D point
# lists fed to SAM as (x, y) coords; confirm against the file that wrote it.
with open('datasets/train_seg_points.json', 'r') as f:
    points = json.load(f)

# SAM ViT-H weights; checkpoint file is expected in the working directory.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# CPU inference (ViT-H on CPU is slow; presumably chosen for portability).
device = torch.device("cpu")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Single reusable predictor; each image is bound via predictor.set_image().
predictor = SamPredictor(sam)

# All generated masks are written here as <name>.jpg.
output_dir = 'prediction/sam_result/sam_val'
os.makedirs(output_dir, exist_ok=True)

# Verify output directory — fail-fast diagnostics before the long loop runs.
abs_output_dir = os.path.abspath(output_dir)
print(f"[INFO] Output directory: {abs_output_dir}")
print(f"[INFO] Directory exists: {os.path.exists(abs_output_dir)}")
print(f"[INFO] Directory writable: {os.access(abs_output_dir, os.W_OK)}")

# Per-image loop: look up annotated points, run SAM with positive (hair) and
# border-sampled negative (scalp) prompts, clean the resulting mask, and save
# it as a binary JPEG.
for full_name in tqdm(points.keys()):
    name, ext = os.path.splitext(full_name)
    # The JSON may key the same image under different extensions; fall back
    # through the common ones. `or` also skips empty lists, which is handled
    # by the "no points" branch below.
    sample_points = points.get(full_name) or points.get(f'{name}.png') or points.get(f'{name}.jpg') or points.get(f'{name}.jpeg') or []

    # Find input image (same basename, any of the three common extensions).
    possible_paths = [
        os.path.join('datasets', 'data', f'{name}.jpeg'),
        os.path.join('datasets', 'data', f'{name}.jpg'),
        os.path.join('datasets', 'data', f'{name}.png'),
    ]
    image = None
    for p in possible_paths:
        if os.path.isfile(p):
            image = cv2.imread(p)
            print(f"Using image: {p}")
            break
    if image is None or image.size == 0:
        print(f"[WARN] No valid image found for {name}, skipping.")
        continue

    # Prepare image for SAM: cv2.imread yields BGR, SAM expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(np.ascontiguousarray(image))

    # If no points, save original image.
    if len(sample_points) == 0:
        print(f"[INFO] No points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Keep only strictly-positive coordinates (drops padding/sentinel points).
    tmp = np.array(sample_points)
    tmp = tmp[tmp.min(axis=1) > 0]

    # FIX: the filter above can leave zero valid points, and
    # np.random.choice(0, ...) raises ValueError. Treat this like the
    # "no points" case instead of crashing mid-run.
    if len(tmp) == 0:
        print(f"[WARN] No valid (positive) points for {name}, saving original image.")
        cv2.imwrite(os.path.join(output_dir, f"{name}.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        continue

    # Use ~50% of hair points randomly (FIX: at least one, so a single-point
    # annotation is not silently reduced to zero prompts by integer division).
    rand_idx = np.random.choice(len(tmp), max(1, len(tmp) // 2), replace=False)
    input_point = tmp[rand_idx]  # 50% hair points

    # Smart negative points sampled from a band along the image border.
    img_height, img_width = image.shape[:2]
    neg_list = []
    # FIX: clamp the band to the image size so tiny images cannot produce
    # negative or out-of-range coordinates (randint needs low < high).
    border_width = min(50, img_height, img_width)

    # FIX: O(1) membership probes against a set of tuples instead of
    # re-materialising and linearly scanning tmp.tolist() every iteration.
    hair_set = {tuple(p) for p in tmp.tolist()}
    # FIX: cap attempts so a degenerate image (whole border annotated as
    # hair) cannot hang the script in an infinite loop.
    attempts = 0
    while len(neg_list) < 10 and attempts < 1000:
        attempts += 1
        side = np.random.choice(['top', 'bottom', 'left', 'right'])
        if side == 'top':
            xy = [np.random.randint(img_width), np.random.randint(0, border_width)]
        elif side == 'bottom':
            xy = [np.random.randint(img_width), np.random.randint(img_height - border_width, img_height)]
        elif side == 'left':
            xy = [np.random.randint(0, border_width), np.random.randint(img_height)]
        else:  # right
            xy = [np.random.randint(img_width - border_width, img_width), np.random.randint(img_height)]

        if tuple(xy) not in hair_set:
            neg_list.append(xy)

    # reshape(-1, 2) keeps the shape valid even if neg_list ended up empty.
    neg_arr = np.array(neg_list).reshape(-1, 2)  # scalp points (from borders)

    # Combine points (FIX: explicit row-wise concatenation instead of the
    # flatten-then-reshape dance of np.append(...).reshape(-1, 2)).
    final_point = np.concatenate([input_point.reshape(-1, 2), neg_arr], axis=0)

    # LOGIC CODE 2: Label assignment (0=hair, 1=scalp) - KEPT AS ORIGINAL
    # NOTE(review): SAM's convention is 1=foreground / 0=background; this
    # script deliberately inverts it and compensates when building the
    # binary map below. Do not "fix" one side without the other.
    input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))

    print(f"[INFO] Using {len(input_point)} hair points (label=0) and {len(neg_arr)} scalp points (label=1)")
    print(f"[INFO] Using ORIGINAL label logic (0=hair, 1=scalp)")
    print(f"[INFO] Negative points: smart border sampling")

    # Predict mask (three candidates plus quality scores).
    masks, scores, logits = predictor.predict(
        point_coords=final_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Keep the highest-scoring candidate mask.
    sam_mask = masks[np.argmax(scores)]

    # Ensure 2D (multimask output can carry a leading singleton axis).
    if sam_mask.ndim > 2:
        sam_mask = sam_mask.squeeze()

    # Resize back to the source resolution if needed (cv2 takes (w, h)).
    if sam_mask.shape != (img_height, img_width):
        sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height))

    # LOGIC CODE 2: Binary map INVERTED (hair=0/black, bg=255/white) - KEPT AS ORIGINAL
    binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
    print(f"[INFO] Using ORIGINAL binary map logic (hair=0/black, bg=255/white)")

    # Get rid of noises (small white spots that are not hair).
    nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_map, None, None, None, 8, cv2.CV_32S
    )

    # Keep only components of >= 400 px (FIX: vectorised np.isin over the
    # label image instead of one full-image boolean scan per component).
    areas = stats[1:, cv2.CC_STAT_AREA]
    keep_labels = np.flatnonzero(areas >= 400) + 1  # +1: stats[0] is background
    result = np.where(np.isin(labels, keep_labels), 255, 0).astype(np.uint8)
    kept_components = len(keep_labels)

    print(f"[INFO] Found {nlabels-1} components, kept {kept_components} (>= 400 pixels)")

    # Save result.
    save_path = os.path.join(output_dir, f"{name}.jpg")
    success = cv2.imwrite(save_path, result)

    if success:
        file_size = os.path.getsize(save_path)
        print(f"[INFO] Saved mask for {name} at {save_path} ({file_size} bytes)")
    else:
        print(f"[ERROR] cv2.imwrite failed for {name}")
    print("="*80)