File size: 4,151 Bytes
7bf5a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import cv2
import glob
from tqdm import tqdm
import numpy as np
import os

def _list_images(folder, exts=('.png', '.jpg', '.jpeg')):
    """Return the sorted, de-duplicated image paths directly inside *folder*.

    Extensions are compared case-insensitively, so ``IMG.PNG`` is found too.
    A missing folder yields an empty list (``glob`` does not raise).
    """
    # glob.escape keeps folder names containing '[', '*' or '?' literal.
    candidates = glob.glob(os.path.join(glob.escape(folder), '*'))
    return sorted({p for p in candidates if os.path.splitext(p)[1].lower() in exts})


def _index_by_stem(paths):
    """Map each file's name without extension to its path.

    *paths* is expected to be sorted; when two files share a stem
    (e.g. ``a.jpg`` and ``a.png``) the one sorting last wins.
    """
    return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}


def _remove_small_components(binary_map, min_area):
    """Return a 0/255 uint8 mask keeping only 8-connected blobs of >= *min_area* pixels."""
    # connectedComponentsWithStats treats every non-zero pixel as foreground.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        binary_map, connectivity=8, ltype=cv2.CV_32S
    )
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background and is never kept
    # One lookup per pixel instead of one full-image comparison per label.
    lut = np.where(keep, 255, 0).astype(np.uint8)
    return lut[labels]


def make_final_mask(seg_path, sam_path, result_path, min_area=400):
    """Intersect matching seg/SAM masks and keep only large connected regions.

    Files in *seg_path* and *sam_path* (``.png``/``.jpg``/``.jpeg``, any case)
    are paired by filename stem, ignoring the extension. For each pair:

    1. The two images are AND-ed pixel-wise.
    2. The result is converted to grayscale.
    3. Every 8-connected non-zero component smaller than *min_area* pixels
       is dropped.

    The resulting 0/255 mask is written to *result_path* under the SAM
    file's name.

    Args:
        seg_path: folder containing the base segmentation masks.
        sam_path: folder containing the SAM masks.
        result_path: output folder; created if it does not exist.
        min_area: minimum component size, in pixels, to keep (default 400).

    Returns:
        None. Problems are reported via printed ``[ERROR]``/``[WARN]`` lines.
    """
    seg_full_path = _list_images(seg_path)
    sam_full_path = _list_images(sam_path)

    # DEBUG
    print(f"[DEBUG] seg_path: {seg_path}")
    print(f"[DEBUG] Found {len(seg_full_path)} seg images")
    if seg_full_path:
        print(f"[DEBUG] First 3 seg files: {seg_full_path[:3]}")

    print(f"[DEBUG] sam_path: {sam_path}")
    print(f"[DEBUG] Found {len(sam_full_path)} sam images")
    if sam_full_path:
        print(f"[DEBUG] First 3 sam files: {sam_full_path[:3]}")

    if not seg_full_path:
        print(f"[ERROR] No seg images found in {seg_path}")
        return

    if not sam_full_path:
        print(f"[ERROR] No sam images found in {sam_path}")
        return

    # Match by filename (without extension)
    seg_dict = _index_by_stem(seg_full_path)
    sam_dict = _index_by_stem(sam_full_path)
    matched_pairs = [
        (seg_dict[name], sam_dict[name]) for name in seg_dict if name in sam_dict
    ]

    print(f"[INFO] Found {len(matched_pairs)} matching pairs")

    if not matched_pairs:
        print("[ERROR] No matching pairs found!")
        print(f"[DEBUG] Seg basenames: {list(seg_dict.keys())[:5]}")
        print(f"[DEBUG] Sam basenames: {list(sam_dict.keys())[:5]}")
        return

    # Without this, cv2.imwrite fails for a missing folder.
    os.makedirs(result_path, exist_ok=True)

    for seg, sam in tqdm(matched_pairs):
        seg_img = cv2.imread(seg)
        sam_img = cv2.imread(sam)

        if seg_img is None:
            print(f"[WARN] Failed to read seg image: {seg}")
            continue
        if sam_img is None:
            print(f"[WARN] Failed to read sam image: {sam}")
            continue

        # Resize if shapes don't match
        if seg_img.shape != sam_img.shape:
            print(f"[INFO] Resizing sam {sam_img.shape} to match seg {seg_img.shape}")
            # Nearest-neighbour keeps the mask's original pixel values.
            # Bilinear would create grey edge pixels that count as foreground below.
            sam_img = cv2.resize(
                sam_img,
                (seg_img.shape[1], seg_img.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        added_img = cv2.bitwise_and(seg_img, sam_img)
        binary_map = cv2.cvtColor(added_img, cv2.COLOR_BGR2GRAY)
        result = _remove_small_components(binary_map, min_area)

        output_path = os.path.join(result_path, os.path.basename(sam))
        # NOTE(review): if the SAM file is .jpg, the mask is saved with lossy
        # JPEG compression. PNG would keep it strictly 0/255.
        if cv2.imwrite(output_path, result):
            print(f"[INFO] Saved: {output_path}")
        else:
            print(f"[WARN] Failed to write: {output_path}")

if __name__ == '__main__':
    import argparse

    # Defaults reproduce the previously hard-coded paths, so running the
    # script with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Intersect segmentation and SAM masks and drop small regions.'
    )
    parser.add_argument(
        '--seg-path',
        default='/Users/Admin/ScalpVision/datasets/seg_train',
        help='folder with the original segmentation masks',
    )
    # NOTE(review): the seg default points at the *train* split, while the SAM
    # and output defaults point at *val*. Pairs only form where filenames
    # coincide. Confirm this is intended.
    parser.add_argument(
        '--sam-path',
        default='/Users/Admin/ScalpVision/prediction/sam_result/sam_val',
        help='folder with the SAM masks',
    )
    parser.add_argument(
        '--result-path',
        default='prediction/ensemble_result/ensemble_val',
        help='output folder for the merged masks',
    )
    args = parser.parse_args()

    os.makedirs(args.result_path, exist_ok=True)
    make_final_mask(args.seg_path, args.sam_path, args.result_path)