import glob
import os

import cv2
import numpy as np
from tqdm import tqdm

# File extensions searched when collecting mask images from a directory.
_IMAGE_EXTENSIONS = ('png', 'jpg', 'jpeg')


def _list_images(directory, extensions=_IMAGE_EXTENSIONS):
    """Return the sorted, de-duplicated image paths found directly in *directory*."""
    found = set()
    for ext in extensions:
        found.update(glob.glob(os.path.join(directory, f'*.{ext}')))
    return sorted(found)


def _index_by_stem(paths):
    """Map each file's basename without extension to its path.

    When several paths share a stem, the last one in iteration order wins.
    """
    return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}


def _remove_small_components(binary_map, min_area):
    """Return a uint8 mask keeping only 8-connected components of at least *min_area* pixels.

    Non-zero pixels of *binary_map* are treated as foreground. Kept pixels are
    set to 255, everything else to 0.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        binary_map, connectivity=8, ltype=cv2.CV_32S
    )
    # Build a per-label lookup table (label -> output value) and apply it in
    # one indexing pass instead of rescanning the image once per component.
    lut = np.where(stats[:, cv2.CC_STAT_AREA] >= min_area, 255, 0).astype(np.uint8)
    lut[0] = 0  # label 0 is the background and is never kept
    return lut[labels]


def make_final_mask(seg_path, sam_path, result_path, min_area=400):
    """Combine segmentation masks with SAM masks and write the cleaned result.

    Images in *seg_path* and *sam_path* (png/jpg/jpeg) are paired by filename
    without extension. For each pair the two images are bitwise-ANDed,
    converted to grayscale, and connected components smaller than *min_area*
    pixels are removed. The result is written to *result_path* under the SAM
    image's filename.

    Args:
        seg_path: Directory containing the segmentation masks.
        sam_path: Directory containing the SAM masks.
        result_path: Output directory; created if it does not exist.
        min_area: Minimum component area, in pixels, to keep (default 400).

    Returns:
        None. Progress and problems are reported via ``print``; the function
        returns early when either directory has no images or nothing matches.
    """
    seg_full_path = _list_images(seg_path)
    sam_full_path = _list_images(sam_path)

    # DEBUG
    print(f"[DEBUG] seg_path: {seg_path}")
    print(f"[DEBUG] Found {len(seg_full_path)} seg images")
    if seg_full_path:
        print(f"[DEBUG] First 3 seg files: {seg_full_path[:3]}")

    print(f"[DEBUG] sam_path: {sam_path}")
    print(f"[DEBUG] Found {len(sam_full_path)} sam images")
    if sam_full_path:
        print(f"[DEBUG] First 3 sam files: {sam_full_path[:3]}")

    if not seg_full_path:
        print(f"[ERROR] No seg images found in {seg_path}")
        return

    if not sam_full_path:
        print(f"[ERROR] No sam images found in {sam_path}")
        return

    # Match by filename (without extension)
    seg_dict = _index_by_stem(seg_full_path)
    sam_dict = _index_by_stem(sam_full_path)

    matched_pairs = [
        (seg_file, sam_dict[name])
        for name, seg_file in seg_dict.items()
        if name in sam_dict
    ]

    print(f"[INFO] Found {len(matched_pairs)} matching pairs")

    if not matched_pairs:
        print("[ERROR] No matching pairs found!")
        print(f"[DEBUG] Seg basenames: {list(seg_dict.keys())[:5]}")
        print(f"[DEBUG] Sam basenames: {list(sam_dict.keys())[:5]}")
        return

    # cv2.imwrite does not create missing directories, so make sure the
    # output directory exists even when the caller has not prepared it.
    os.makedirs(result_path, exist_ok=True)

    for seg, sam in tqdm(matched_pairs):
        seg_img = cv2.imread(seg)
        sam_img = cv2.imread(sam)

        if seg_img is None:
            print(f"[WARN] Failed to read seg image: {seg}")
            continue
        if sam_img is None:
            print(f"[WARN] Failed to read sam image: {sam}")
            continue

        # Resize if shapes don't match
        # NOTE(review): default (bilinear) interpolation yields intermediate
        # gray values along mask edges; cv2.INTER_NEAREST may be preferable
        # for masks - confirm before changing, as it alters the output.
        if seg_img.shape != sam_img.shape:
            print(f"[INFO] Resizing sam {sam_img.shape} to match seg {seg_img.shape}")
            sam_img = cv2.resize(sam_img, (seg_img.shape[1], seg_img.shape[0]))

        added_img = cv2.bitwise_and(seg_img, sam_img)
        binary_map = cv2.cvtColor(added_img, cv2.COLOR_BGR2GRAY)
        result = _remove_small_components(binary_map, min_area)

        # NOTE(review): the output keeps the SAM file's extension, so a .jpg
        # input produces a lossy JPEG mask - confirm this is intended.
        output_path = os.path.join(result_path, os.path.basename(sam))
        # imwrite reports failure through its return value; do not claim success
        # when nothing was written.
        if not cv2.imwrite(output_path, result):
            print(f"[WARN] Failed to write: {output_path}")
            continue
        print(f"[INFO] Saved: {output_path}")


if __name__ == '__main__':
    # NOTE(review): seg_path points at a "train" split while sam_path and
    # result_path point at "val" - confirm this pairing is intended.
    seg_path = '/Users/Admin/ScalpVision/datasets/seg_train'  # original masks
    sam_path = '/Users/Admin/ScalpVision/prediction/sam_result/sam_val'  # SAM masks
    result_path = 'prediction/ensemble_result/ensemble_val'  # merged output masks
    os.makedirs(result_path, exist_ok=True)
    make_final_mask(seg_path, sam_path, result_path)