#!/usr/bin/env python3
"""
Complete Hair Counting Pipeline:
- Steps 1-7: BSR, Preprocess, Binarize, Thinning, MSLD, PLB, Merge
- Step 8: Relaxation Labeling for clustering line segments
- Step 9: Count hairs
- Visualization
"""
import os
import cv2
import numpy as np
from skimage.morphology import skeletonize
import math
from tqdm import tqdm
import glob
from collections import defaultdict


# ----------------------------- Utilities -----------------------------------
def ensure_dir(p):
    """Create directory *p* (including parents); do nothing if it exists."""
    os.makedirs(p, exist_ok=True)


# ----------------------------- BSR module ----------------------------------
def bsr_lab_opening(rgb, se_radius=6):
    """
    Bright Spot Removal via morphological opening on L channel in LAB color-space.

    Args:
        rgb: 3-channel 8-bit image. Despite the parameter name, the
            conversions used are COLOR_BGR2LAB / COLOR_LAB2BGR, so the
            channel order is expected to be BGR (what cv2.imread returns).
        se_radius: radius of the elliptical structuring element; the kernel
            is (2*se_radius+1) x (2*se_radius+1).

    Returns:
        Image in the same channel order as the input, with the L channel
        replaced by its min-max normalised opening.
    """
    lab = cv2.cvtColor(rgb, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)

    side = 2 * se_radius + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    opened = cv2.morphologyEx(lightness, cv2.MORPH_OPEN, kernel)

    # The previous expression `L - (L - L_open)//1` is exactly `L_open`:
    # an opening never exceeds its input, so the uint8 subtraction cannot
    # wrap, and `//1` is a no-op.
    stretched = cv2.normalize(opened, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    return cv2.cvtColor(cv2.merge([stretched, chan_a, chan_b]), cv2.COLOR_LAB2BGR)


# ----------------------------- Preprocessing --------------------------------
def preprocess(rgb, bilateral_d=9, bilateral_sigmaColor=75, bilateral_sigmaSpace=75):
    """Return *rgb* smoothed with cv2.bilateralFilter using the given settings."""
    return cv2.bilateralFilter(rgb, bilateral_d, bilateral_sigmaColor, bilateral_sigmaSpace)


# ----------------------------- Binarization ---------------------------------
def binarize(img_gray, morph_radius=3):
    """
    Binarize a grayscale image: 5x5 Gaussian blur, Otsu threshold
    (THRESH_BINARY, so pixels above the threshold become 255), then a
    morphological close followed by an open with an elliptical kernel of
    radius *morph_radius*.

    Returns:
        uint8 mask with values 0 / 255.
    """
    blurred = cv2.GaussianBlur(img_gray, (5, 5), 0)
    _, mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    side = 2 * morph_radius + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    return mask


# ----------------------------- Thinning ------------------------------------
def thin_mask(binary_u8):
    """Skeletonize the non-zero pixels of *binary_u8*; returns a 0/255 uint8 image."""
    skeleton = skeletonize(binary_u8 > 0)
    return skeleton.astype(np.uint8) * 255
# ----------------------------- MSLD (multi-scale Hough) --------------------
def deduplicate_lines(lines, tol=6):
    """
    Drop near-duplicate segments.

    A segment is a duplicate of an already-kept one when all four
    coordinates differ by at most *tol*, comparing the endpoints in either
    order. The first occurrence is kept; input order is preserved.
    """
    def close(a, b):
        return all(abs(p - q) <= tol for p, q in zip(a, b))

    unique = []
    for seg in lines:
        x1, y1, x2, y2 = seg
        # Swap the endpoints as pairs. Reversing the 4-tuple would yield
        # (y2, x2, y1, x1), which mixes x with y and misses real duplicates.
        swapped = (x2, y2, x1, y1)
        if not any(close(seg, kept) or close(swapped, kept) for kept in unique):
            unique.append(seg)
    return unique


def multi_scale_hough(edges, scale_factors=(1.0, 0.75, 0.5), hough_params=None):
    """
    Run Probabilistic HoughLinesP at multiple scales.

    Args:
        edges: single-channel 8-bit edge map.
        scale_factors: iterable of resize factors; 1.0 means no resize.
        hough_params: optional dict overriding any of 'rho', 'theta',
            'threshold', 'minLineLength', 'maxLineGap'. The two length
            parameters are multiplied by the scale factor (floored at 8
            and 1 respectively).

    Returns:
        De-duplicated list of (x1, y1, x2, y2) tuples of Python ints in the
        coordinate frame of *edges*.
    """
    settings = {'rho': 1, 'theta': np.pi / 180, 'threshold': 30,
                'minLineLength': 20, 'maxLineGap': 20}
    if hough_params is not None:
        settings.update(hough_params)

    h, w = edges.shape
    lines_all = []
    for s in scale_factors:
        if s != 1.0:
            new_w, new_h = int(w * s), int(h * s)
            if new_w < 1 or new_h < 1:
                # Scale collapses the image (or is non-positive): nothing to detect.
                continue
            small = cv2.resize(edges, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        else:
            small = edges

        found = cv2.HoughLinesP(
            small, settings['rho'], settings['theta'], settings['threshold'],
            minLineLength=max(8, int(settings['minLineLength'] * s)),
            maxLineGap=max(1, int(settings['maxLineGap'] * s)))
        if found is None:
            continue

        for entry in found:
            # Convert numpy integers to plain ints so later arithmetic and
            # drawing calls never see fixed-width values.
            x1, y1, x2, y2 = (int(v) for v in entry[0])
            if s != 1.0:
                x1 = int(round(x1 / s)); y1 = int(round(y1 / s))
                x2 = int(round(x2 / s)); y2 = int(round(y2 / s))
            lines_all.append((x1, y1, x2, y2))

    return deduplicate_lines(lines_all, tol=6)


# ----------------------------- PLB: Parallel Line Bundling -------------------
def _orientation_delta(ang_from, ang_to):
    """
    Signed smallest rotation (radians, in [-pi/2, pi/2]) that takes the
    undirected orientation *ang_from* onto *ang_to*. Angles that differ by
    pi describe the same orientation and give 0.
    """
    delta = (ang_to - ang_from) % math.pi
    if delta > math.pi / 2:
        delta -= math.pi
    return delta


def line_to_abcline(line):
    """
    Return the unit-normal line equation (a, b, c) with a*x + b*y + c = 0
    for the infinite line through the segment, or None for a zero-length
    segment.
    """
    x1, y1, x2, y2 = line
    dx = x2 - x1
    dy = y2 - y1
    if dx == 0 and dy == 0:
        return None
    norm = math.hypot(dx, dy)
    a = dy / norm
    b = -dx / norm
    c = -(a * x1 + b * y1)
    return (a, b, c)


def line_angle(line):
    """Directed angle of the segment (atan2 of end minus start), in (-pi, pi]."""
    x1, y1, x2, y2 = line
    return math.atan2(y2 - y1, x2 - x1)


def distance_between_parallel_lines(l1_abc, l2_abc):
    """
    Offset between two (near-)parallel lines given in unit-normal form.

    When the normals point in opposite directions (segments listed with
    opposite endpoint order) the second equation is negated first;
    otherwise |c1 - c2| would measure the sum of the origin distances
    instead of the gap.
    """
    a1, b1, c1 = l1_abc
    a2, b2, c2 = l2_abc
    if a1 * a2 + b1 * b2 < 0:
        c2 = -c2
    return abs(c1 - c2)


def seg_projection_on_line(seg, line_dir):
    """Project both endpoints onto direction *line_dir* (radians); return (min, max)."""
    x1, y1, x2, y2 = seg
    vx = math.cos(line_dir)
    vy = math.sin(line_dir)
    p1 = x1 * vx + y1 * vy
    p2 = x2 * vx + y2 * vy
    return min(p1, p2), max(p1, p2)


def overlap_segment_length(a1, b1, a2, b2):
    """Length of the overlap of intervals [a1, b1] and [a2, b2]; 0.0 if disjoint."""
    return max(0.0, min(b1, b2) - max(a1, a2))


def _point_on_seg_by_proj(seg, proj_val, dir_ang):
    """
    Point of *seg* (rounded to ints) whose projection on direction
    *dir_ang* equals *proj_val*, clamped to the segment's extent.
    """
    x1, y1, x2, y2 = seg
    vx = math.cos(dir_ang)
    vy = math.sin(dir_ang)
    p1 = x1 * vx + y1 * vy
    p2 = x2 * vx + y2 * vy
    if p2 == p1:
        alpha = 0.0
    else:
        alpha = (proj_val - p1) / (p2 - p1)
    alpha = max(0.0, min(1.0, alpha))
    return (int(round(x1 + alpha * (x2 - x1))),
            int(round(y1 + alpha * (y2 - y1))))


def plb_restore(lines, avg_gap=None, gap_thresh_factor=1.15, angle_tol_deg=6, min_overlap_px=10):
    """
    Parallel Line Bundling to restore concealed hairs

    For every pair of near-parallel segments whose offset lies between
    0.7x and 2.5x of avg_gap * gap_thresh_factor and whose projections
    overlap by at least *min_overlap_px*, a new segment running midway
    between them (over the overlapping span) is appended.

    Args:
        lines: sequence of (x1, y1, x2, y2).
        avg_gap: typical gap between parallel lines. If None it is the
            median offset (> 0.5) over all near-parallel pairs, or 8.0
            when there are no such pairs.
        gap_thresh_factor: scale applied to avg_gap before thresholding.
        angle_tol_deg: maximum orientation difference for "parallel".
            Endpoint order is ignored: antiparallel segments count as
            parallel.
        min_overlap_px: minimum overlap along the shared direction.

    Returns:
        New list: the input lines followed by the restored mid-lines.
    """
    out_lines = list(lines)
    count = len(lines)
    if count < 2:
        return out_lines

    angle_tol = math.radians(angle_tol_deg)
    abc_list = [line_to_abcline(seg) for seg in lines]
    angles = [line_angle(seg) for seg in lines]

    # One pass collects every near-parallel pair, so the gap estimate and
    # the restoration step use exactly the same notion of "parallel".
    pairs = []
    for i in range(count):
        if abc_list[i] is None:
            continue
        for j in range(i + 1, count):
            if abc_list[j] is None:
                continue
            delta = _orientation_delta(angles[i], angles[j])
            if abs(delta) > angle_tol:
                continue
            offset = distance_between_parallel_lines(abc_list[i], abc_list[j])
            pairs.append((i, j, delta, offset))

    gaps = [offset for (_, _, _, offset) in pairs if offset > 0.5]
    if gaps:
        if avg_gap is None:
            avg_gap = float(np.median(gaps))
    else:
        avg_gap = avg_gap or 8.0

    lower = avg_gap * gap_thresh_factor * 0.7
    upper = avg_gap * gap_thresh_factor * 2.5

    # Pairwise restore
    for i, j, delta, offset in pairs:
        if offset < lower or offset > upper:
            continue

        # Shared direction: rotate line i half-way towards line j. Averaging
        # the raw angles would be wrong across the +/-pi wrap and for
        # antiparallel segments.
        dir_ang = angles[i] + 0.5 * delta
        a1, b1 = seg_projection_on_line(lines[i], dir_ang)
        a2, b2 = seg_projection_on_line(lines[j], dir_ang)
        if overlap_segment_length(a1, b1, a2, b2) < min_overlap_px:
            continue

        span_start = max(a1, a2)
        span_end = min(b1, b2)
        start_i = _point_on_seg_by_proj(lines[i], span_start, dir_ang)
        start_j = _point_on_seg_by_proj(lines[j], span_start, dir_ang)
        end_i = _point_on_seg_by_proj(lines[i], span_end, dir_ang)
        end_j = _point_on_seg_by_proj(lines[j], span_end, dir_ang)

        out_lines.append((
            int(round(0.5 * (start_i[0] + start_j[0]))),
            int(round(0.5 * (start_i[1] + start_j[1]))),
            int(round(0.5 * (end_i[0] + end_j[0]))),
            int(round(0.5 * (end_i[1] + end_j[1]))),
        ))
    return out_lines


# ----------------------------- lines -> mask --------------------------------
def rasterize_lines_to_mask(lines, shape, thickness=3):
    """
    Draw *lines* in white on a black uint8 mask of size shape[:2], then
    dilate once with a 3x3 elliptical kernel.
    """
    mask = np.zeros(shape[:2], dtype=np.uint8)
    for (x1, y1, x2, y2) in lines:
        cv2.line(mask, (int(x1), int(y1)), (int(x2), int(y2)),
                 color=255, thickness=thickness)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.dilate(mask, kernel, iterations=1)


# ----------------------------- Relaxation Labeling (IMPROVED) ---------------
def line_to_polar(line):
    """
    Convert line segment to polar coordinates (rho, theta)

    theta is the undirected orientation in [0, pi): a segment and its
    reverse give the same value (a result of exactly pi is mapped to 0).
    rho is the distance from the origin to the infinite line through the
    segment; for a zero-length segment it is the distance from the origin
    to the point.
    """
    x1, y1, x2, y2 = line
    dx = x2 - x1
    dy = y2 - y1

    theta = math.atan2(dy, dx)
    if theta < 0:
        theta += math.pi
    if theta >= math.pi:
        theta -= math.pi

    if abs(dx) < 1e-6 and abs(dy) < 1e-6:
        # Degenerate case
        rho = math.hypot(x1, y1)
    else:
        rho = abs(x2 * y1 - y2 * x1) / math.hypot(dx, dy)
    return rho, theta


def get_line_midpoint(line):
    """Get midpoint of line segment"""
    x1, y1, x2, y2 = line
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)


def get_line_length(line):
    """Get length of line segment"""
    x1, y1, x2, y2 = line
    return math.hypot(x2 - x1, y2 - y1)


def distance_between_points(p1, p2):
    """Euclidean distance between two points"""
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


def point_to_line_distance(point, line):
    """
    Minimum distance from a point to a line segment
    """
    x1, y1, x2, y2 = line
    px, py = point
    dx = x2 - x1
    dy = y2 - y1
    if dx == 0 and dy == 0:
        # Degenerate line segment
        return math.hypot(px - x1, py - y1)

    # Projection parameter, clamped so the closest point stays on the segment.
    t = ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy)
    t = max(0, min(1, t))
    return math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))


def _segments_cross(seg_a, seg_b):
    """True when the two segments properly cross (interiors intersect)."""
    ax1, ay1, ax2, ay2 = seg_a
    bx1, by1, bx2, by2 = seg_b

    def side(ox, oy, ex, ey, px, py):
        # Sign tells on which side of the line o->e the point p lies.
        return (ex - ox) * (py - oy) - (ey - oy) * (px - ox)

    a_start = side(bx1, by1, bx2, by2, ax1, ay1)
    a_end = side(bx1, by1, bx2, by2, ax2, ay2)
    b_start = side(ax1, ay1, ax2, ay2, bx1, by1)
    b_end = side(ax1, ay1, ax2, ay2, bx2, by2)
    return a_start * a_end < 0 and b_start * b_end < 0


def line_to_line_distance(line1, line2):
    """
    Minimum distance between two line segments

    Crossing segments are at distance 0. Otherwise the minimum is attained
    at an endpoint of one of the segments, so the four endpoint-to-segment
    distances are enough (touching and collinear-overlap cases fall out of
    those as 0 as well).
    """
    if _segments_cross(line1, line2):
        return 0.0

    ax1, ay1, ax2, ay2 = line1
    bx1, by1, bx2, by2 = line2
    return min(
        point_to_line_distance((ax1, ay1), line2),
        point_to_line_distance((ax2, ay2), line2),
        point_to_line_distance((bx1, by1), line1),
        point_to_line_distance((bx2, by2), line1),
    )


def find_neighbors_improved(lines, max_distance=100, max_angle_diff_deg=30):
    """
    Find neighboring line segments with improved criteria:
    - Close in space (line-to-line distance)
    - Similar orientation

    Orientation is undirected (opposite directions count as equal).

    Returns:
        List of sets; entry i holds the indices of the neighbours of line i.
    """
    n = len(lines)
    neighbors = [set() for _ in range(n)]
    max_angle_diff = math.radians(max_angle_diff_deg)
    angles = [line_angle(line) for line in lines]

    for i in range(n):
        for j in range(i + 1, n):
            if abs(_orientation_delta(angles[i], angles[j])) > max_angle_diff:
                continue
            if line_to_line_distance(lines[i], lines[j]) <= max_distance:
                neighbors[i].add(j)
                neighbors[j].add(i)
    return neighbors


def agglomerative_clustering(lines, max_cluster_distance=80, max_angle_diff_deg=25):
    """
    Simple agglomerative clustering based on:
    - Lines that are close in space
    - Lines that have similar orientation

    Repeatedly merges the pair of clusters whose midpoint centroids are
    closest, as long as that distance is below *max_cluster_distance* and
    the clusters' mean orientations differ by at most *max_angle_diff_deg*.

    Returns:
        (labels, num_clusters) where labels[i] is the cluster id of lines[i].
    """
    n = len(lines)
    if n == 0:
        return [], 0

    max_angle_diff = math.radians(max_angle_diff_deg)

    # Initialize: each line is its own cluster
    clusters = [[i] for i in range(n)]

    # Running sums per cluster: [sum_mid_x, sum_mid_y, sum_cos(2a), sum_sin(2a)].
    # Angles are doubled so that a segment and its reverse contribute the
    # same vector; a plain circular mean cancels out for antiparallel
    # members and produces an arbitrary orientation.
    sums = []
    for line in lines:
        mx, my = get_line_midpoint(line)
        doubled = 2.0 * line_angle(line)
        sums.append([mx, my, math.cos(doubled), math.sin(doubled)])

    def summarize(idx):
        size = len(clusters[idx])
        sx, sy, c2, s2 = sums[idx]
        return (sx / size, sy / size), 0.5 * math.atan2(s2, c2)

    # Cached (centroid, mean orientation) per cluster; refreshed only on merge.
    summaries = [summarize(i) for i in range(n)]

    def cluster_distance(idx_a, idx_b):
        center_a, orient_a = summaries[idx_a]
        center_b, orient_b = summaries[idx_b]
        if abs(_orientation_delta(orient_a, orient_b)) > max_angle_diff:
            return float('inf')  # Don't merge if angles too different
        return distance_between_points(center_a, center_b)

    # Agglomerative merging
    changed = True
    while changed and len(clusters) > 1:
        changed = False
        best_merge = None
        best_dist = max_cluster_distance

        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                dist = cluster_distance(i, j)
                if dist < best_dist:
                    best_dist = dist
                    best_merge = (i, j)
                    changed = True

        if best_merge:
            i, j = best_merge
            clusters[i].extend(clusters[j])
            sums[i] = [u + v for u, v in zip(sums[i], sums[j])]
            del clusters[j]
            del sums[j]
            del summaries[j]
            summaries[i] = summarize(i)

    # Assign labels
    labels = [-1] * n
    for cluster_id, members in enumerate(clusters):
        for line_idx in members:
            labels[line_idx] = cluster_id
    return labels, len(clusters)


def relaxation_labeling_improved(lines, max_iterations=30, epsilon=0.7, max_neighbor_dist=100,
                                 max_angle_diff_deg=25, convergence_threshold=0.001):
    """
    Improved Relaxation Labeling with better parameters and compatibility function

    Labels are initialised from the connected components of the neighbour
    graph, then refined with the update p <- normalise(p * (1 + q)), where
    q[i] is the compatibility-weighted sum of the neighbours' label
    probabilities. Compatibility is
    epsilon * cos(orientation diff) + (1 - epsilon) * exp(-midpoint dist / max_neighbor_dist).

    NOTE(review): the initial probabilities are one-hot and the update is
    multiplicative, so zero entries stay zero; as written the result equals
    the connected-component labelling.

    Returns:
        (labels, num_labels) with labels renumbered to 0..num_labels-1.
    """
    n = len(lines)
    if n == 0:
        return [], 0

    neighbors = find_neighbors_improved(lines, max_distance=max_neighbor_dist,
                                        max_angle_diff_deg=max_angle_diff_deg)
    angles = [line_angle(line) for line in lines]
    midpoints = [get_line_midpoint(line) for line in lines]

    # Initial labels: one per connected component. Traversal order does not
    # matter because every member receives the component's label.
    initial_labels = [-1] * n
    num_labels = 0
    for start in range(n):
        if initial_labels[start] != -1:
            continue
        initial_labels[start] = num_labels
        pending = [start]
        while pending:
            node = pending.pop()
            for other in neighbors[node]:
                if initial_labels[other] == -1:
                    initial_labels[other] = num_labels
                    pending.append(other)
        num_labels += 1

    p = np.zeros((n, num_labels), dtype=np.float64)
    p[np.arange(n), initial_labels] = 1.0

    # Compatibility depends only on the pair (i, j), not on the label, so
    # it is computed once instead of once per label per iteration.
    nbr_index = []
    nbr_weight = []
    for i in range(n):
        members = sorted(neighbors[i])
        weights = []
        for j in members:
            # Orientation difference folded into [0, pi/2]: cosine is 1 for
            # equal orientation and 0 for perpendicular, never negative.
            angle_sim = math.cos(abs(_orientation_delta(angles[i], angles[j])))
            dist = distance_between_points(midpoints[i], midpoints[j])
            if max_neighbor_dist > 0:
                dist_sim = math.exp(-dist / max_neighbor_dist)
            else:
                dist_sim = 1.0 if dist == 0 else 0.0
            weights.append(epsilon * angle_sim + (1 - epsilon) * dist_sim)
        nbr_index.append(np.array(members, dtype=np.intp))
        nbr_weight.append(np.array(weights, dtype=np.float64))

    # Iterative relaxation
    for _ in range(max_iterations):
        q = np.zeros_like(p)
        for i in range(n):
            if nbr_index[i].size:
                q[i] = nbr_weight[i] @ p[nbr_index[i]]

        p_new = p * (1.0 + q)
        row_sum = p_new.sum(axis=1, keepdims=True)
        valid = row_sum > 1e-10
        # Rows that cannot be normalised keep their previous distribution.
        p_new = np.where(valid, p_new / np.where(valid, row_sum, 1.0), p)

        diff = np.abs(p_new - p).max()
        p = p_new
        if diff < convergence_threshold:
            break

    # Assign final labels and renumber to consecutive ids
    winners = np.argmax(p, axis=1)
    unique_labels, renumbered = np.unique(winners, return_inverse=True)
    return renumbered.tolist(), len(unique_labels)


# ----------------------------- Update main pipeline to use improved version ---
def visualize_labeled_lines(rgb, lines, labels, title="Labeled Hairs"):
    """
    Visualize line segments colored by their labels (hair clusters)

    Returns a copy of *rgb* with each line drawn in its label's colour and
    *title* written in the top-left corner. Colours come from a private
    generator seeded with 42, so they are reproducible and the global
    NumPy random state is left untouched.
    """
    vis = rgb.copy()

    unique_labels = sorted(set(labels))
    rng = np.random.RandomState(42)
    label_to_color = {}
    for label in unique_labels:
        label_to_color[label] = (
            int(rng.randint(50, 255)),
            int(rng.randint(50, 255)),
            int(rng.randint(50, 255)),
        )

    for line, label in zip(lines, labels):
        x1, y1, x2, y2 = line
        cv2.line(vis, (int(x1), int(y1)), (int(x2), int(y2)), label_to_color[label], 2)

    # White outline first, black text on top, for legibility on any background.
    cv2.putText(vis, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 3)
    cv2.putText(vis, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0), 1)
    return vis


def _write_image(path, image):
    """Write *image* to *path*; raise RuntimeError if OpenCV reports failure."""
    if not cv2.imwrite(path, image):
        raise RuntimeError("Cannot write image: " + path)


def run_complete_pipeline(image_path, out_dir, params, verbose=True):
    """
    Run the full hair-counting pipeline on one image and save visualisations.

    Args:
        image_path: path of the image to read with cv2.imread.
        out_dir: directory for the output images (created if missing).
        params: dict of settings; every key is optional and read with
            params.get(key, default).
        verbose: print progress messages when True.

    Returns:
        dict with keys 'image', 'original', 'lines', 'restored_lines',
        'labels', 'num_hairs', 'binary', 'merged_foreground', 'vis_path',
        'labeled_path'.

    Raises:
        RuntimeError: if the image cannot be read or an output cannot be written.
    """
    imname = os.path.basename(image_path)
    rgb = cv2.imread(image_path)
    if rgb is None:
        raise RuntimeError("Cannot open image: " + image_path)

    # cv2.imwrite does not create directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)

    def log(message):
        if verbose:
            print(message)

    log(f"\nProcessing: {imname}")
    log("="*60)
    log(" Step 1: BSR (Bright Spot Removal)...")
    bsr = bsr_lab_opening(rgb, se_radius=params.get('bsr_se', 6))

    log(" Step 2: Preprocessing (Bilateral Filter)...")
    prep = preprocess(bsr,
                      bilateral_d=params.get('bilateral_d', 9),
                      bilateral_sigmaColor=params.get('bilateral_sigmaColor', 75),
                      bilateral_sigmaSpace=params.get('bilateral_sigmaSpace', 75))

    log(" Step 3: Binarization (Otsu + Morphology)...")
    gray = cv2.cvtColor(prep, cv2.COLOR_BGR2GRAY)
    binary = binarize(gray, morph_radius=params.get('morph_radius', 3))

    log(" Step 4: Thinning (Skeletonize)...")
    skel = thin_mask(binary)

    log(" Step 5: MSLD (Multi-Scale Line Detection)...")
    # Line detection runs on Canny edges of the grayscale image; the
    # skeleton above is only used for the visualisation grid.
    edges = cv2.Canny(gray, 50, 150)
    lines = multi_scale_hough(edges,
                              scale_factors=params.get('scales', [1.0, 0.75, 0.5]),
                              hough_params=params.get('hough_params', None))
    log(f" - Detected {len(lines)} lines")

    log(" Step 6: PLB (Parallel Line Bundling)...")
    restored_lines = plb_restore(lines,
                                 avg_gap=params.get('avg_gap', None),
                                 gap_thresh_factor=params.get('gap_factor', 1.25),
                                 angle_tol_deg=params.get('angle_tol_deg', 6),
                                 min_overlap_px=params.get('min_overlap_px', 12))
    log(f" - Restored to {len(restored_lines)} lines")
    log(f" - Concealed hairs recovered: {len(restored_lines)-len(lines)}")

    log(" Step 7: Merge lines mask with binary...")
    lines_mask = rasterize_lines_to_mask(restored_lines, rgb.shape,
                                         thickness=params.get('line_thickness', 3))
    merged_foreground = cv2.bitwise_or(binary, lines_mask)

    log(" Step 8: Clustering line segments into hairs...")
    # 'agglomerative' selects agglomerative clustering; any other value
    # falls through to relaxation labeling.
    clustering_method = params.get('clustering_method', 'agglomerative')
    if clustering_method == 'agglomerative':
        log(" - Using Agglomerative Clustering...")
        shown_dist = params.get('cluster_max_dist', 80)
        shown_angle = params.get('cluster_angle_diff', 25)
        labels, num_hairs = agglomerative_clustering(
            restored_lines,
            max_cluster_distance=shown_dist,
            max_angle_diff_deg=shown_angle,
        )
    else:
        log(" - Using Relaxation Labeling...")
        shown_dist = params.get('rl_neighbor_dist', 100)
        shown_angle = params.get('rl_angle_diff', 25)
        labels, num_hairs = relaxation_labeling_improved(
            restored_lines,
            max_iterations=params.get('rl_max_iter', 30),
            epsilon=params.get('rl_epsilon', 0.7),
            max_neighbor_dist=shown_dist,
            max_angle_diff_deg=shown_angle,
            convergence_threshold=params.get('rl_conv_threshold', 0.001),
        )
    log(f" - Clustered into {num_hairs} hairs")
    log("="*60)

    # ========== VISUALIZATION ==========
    log(" Creating visualizations...")
    lines_vis = rgb.copy()
    for (x1, y1, x2, y2) in lines:
        cv2.line(lines_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)

    restored_vis = rgb.copy()
    for (x1, y1, x2, y2) in restored_lines:
        cv2.line(restored_vis, (x1, y1), (x2, y2), (0, 255, 255), 2)

    labeled_vis = visualize_labeled_lines(rgb, restored_lines, labels, f"Hairs: {num_hairs}")

    target_size = (512, 512)

    def resize_and_label(img, text):
        resized = cv2.resize(img, target_size)
        cv2.putText(resized, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        cv2.putText(resized, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1)
        return resized

    def as_bgr(single_channel):
        return cv2.cvtColor(single_channel, cv2.COLOR_GRAY2BGR)

    tiles = [
        resize_and_label(rgb, "1. Original"),
        resize_and_label(bsr, "2. BSR"),
        resize_and_label(prep, "3. Preprocessed"),
        resize_and_label(as_bgr(binary), "4. Binary"),
        resize_and_label(as_bgr(skel), "5. Skeleton"),
        resize_and_label(as_bgr(edges), "6. Edges"),
        resize_and_label(lines_vis, f"7. Lines ({len(lines)})"),
        resize_and_label(restored_vis, f"8. PLB ({len(restored_lines)})"),
        resize_and_label(as_bgr(lines_mask), "9. Lines Mask"),
        resize_and_label(as_bgr(merged_foreground), "10. Merged"),
        resize_and_label(labeled_vis, f"11. Labeled ({num_hairs})"),
    ]

    # Twelfth tile: text summary on a dark background.
    count_img = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
    count_img[:] = (40, 40, 40)
    method_name = "Agglomerative" if clustering_method == 'agglomerative' else "Relaxation"
    info_text = [
        f"Image: {imname}",
        "",
        f"Lines detected: {len(lines)}",
        f"After PLB: {len(restored_lines)}",
        f"Recovered: +{len(restored_lines)-len(lines)}",
        "",
        f"Hair Count: {num_hairs}",
        "",
        f"Method: {method_name}",
        # Show the parameters of the method that actually ran.
        f"Max distance: {shown_dist}",
        # Hershey fonts have no degree sign (it would render as '?').
        f"Max angle diff: {shown_angle} deg",
    ]
    y_offset = 50
    for text in info_text:
        cv2.putText(count_img, text, (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        y_offset += 35
    tiles.append(count_img)

    combined = np.vstack([np.hstack(tiles[row:row + 4]) for row in range(0, 12, 4)])

    out_vis_path = os.path.join(out_dir, "complete_pipeline_" + imname)
    labeled_path = os.path.join(out_dir, "labeled_" + imname)
    _write_image(out_vis_path, combined)
    _write_image(labeled_path, labeled_vis)
    _write_image(os.path.join(out_dir, "binary_" + imname), binary)
    _write_image(os.path.join(out_dir, "lines_mask_" + imname), lines_mask)
    _write_image(os.path.join(out_dir, "merged_foreground_" + imname), merged_foreground)

    log(f" ✓ Visualization saved: {out_vis_path}")
    log(f" ✓ Labeled image saved: {labeled_path}")

    return {
        'image': image_path,
        'original': rgb,
        'lines': lines,
        'restored_lines': restored_lines,
        'labels': labels,
        'num_hairs': num_hairs,
        'binary': binary,
        'merged_foreground': merged_foreground,
        'vis_path': out_vis_path,
        'labeled_path': labeled_path,
    }
# ----------------------------- Batch processing --------------------------------
def process_folder(input_folder, output_folder, params):
    """
    Process all images in a folder

    Every file in *input_folder* whose extension is jpg, jpeg, png, bmp,
    tiff or tif (matched case-insensitively) is run through
    run_complete_pipeline with verbose=False. A summary is printed and,
    when at least one image succeeds, written to
    <output_folder>/hair_count_summary.csv.

    Returns:
        (results, failed): results is the list of dicts returned by
        run_complete_pipeline; failed is a list of (path, error message).

    NOTE: each result dict keeps the full-size images ('original',
    'binary', 'merged_foreground'), so memory use grows with the number of
    images processed.
    """
    ensure_dir(output_folder)

    # Match on the lower-cased extension so mixed-case names such as
    # "scan.Jpg" are found too. The folder name is escaped so characters
    # like '[' in it are not treated as glob patterns.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
    candidates = glob.glob(os.path.join(glob.escape(input_folder), '*'))
    image_files = sorted(
        path for path in candidates
        if os.path.isfile(path)
        and os.path.splitext(path)[1].lower() in image_extensions
    )

    if not image_files:
        print(f"No images found in {input_folder}")
        return [], []

    print(f"Found {len(image_files)} images in {input_folder}")
    print(f"Output will be saved to: {output_folder}\n")

    results = []
    failed = []

    # Process each image with progress bar. tqdm.write is used instead of
    # print so the per-image lines do not break the bar.
    for img_path in tqdm(image_files, desc="Processing images"):
        try:
            result = run_complete_pipeline(img_path, output_folder, params, verbose=False)
        except Exception as e:
            # Batch boundary: record the failure and carry on with the
            # remaining images; failures are listed in the summary.
            failed.append((img_path, str(e)))
            tqdm.write(f"✗ {os.path.basename(img_path)}: ERROR - {str(e)}")
        else:
            results.append(result)
            tqdm.write(f"✓ {os.path.basename(img_path)}: "
                       f"{len(result['lines'])} → {len(result['restored_lines'])} lines → "
                       f"{result['num_hairs']} hairs")

    # Summary statistics
    print("\n" + "="*80)
    print("SUMMARY:")
    print(f" Total images: {len(image_files)}")
    print(f" Successfully processed: {len(results)}")
    print(f" Failed: {len(failed)}")

    if results:
        total_original = sum(len(r['lines']) for r in results)
        total_restored = sum(len(r['restored_lines']) for r in results)
        hair_counts = [r['num_hairs'] for r in results]
        total_hairs = sum(hair_counts)

        print(f"\n Total lines detected: {total_original}")
        print(f" Total after PLB restoration: {total_restored}")
        print(f" Total concealed hairs recovered: {total_restored - total_original}")
        print(f"\n TOTAL HAIR COUNT: {total_hairs}")
        print(f" Average hairs per image: {total_hairs / len(results):.1f}")

        # Hair count distribution
        print("\n Hair count statistics:")
        print(f" Min: {min(hair_counts)}")
        print(f" Max: {max(hair_counts)}")
        print(f" Mean: {np.mean(hair_counts):.1f}")
        print(f" Median: {np.median(hair_counts):.1f}")
        print(f" Std: {np.std(hair_counts):.1f}")

    if failed:
        print("\n Failed images:")
        for path, error in failed:
            print(f" - {os.path.basename(path)}: {error}")
    print("="*80)

    # Save summary to CSV
    if results:
        import csv
        csv_path = os.path.join(output_folder, "hair_count_summary.csv")
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Image', 'Lines_Detected', 'Lines_After_PLB',
                             'Lines_Recovered', 'Hair_Count'])
            for r in results:
                writer.writerow([
                    os.path.basename(r['image']),
                    len(r['lines']),
                    len(r['restored_lines']),
                    len(r['restored_lines']) - len(r['lines']),
                    r['num_hairs'],
                ])
        print(f"\n✓ Summary saved to: {csv_path}")

    return results, failed


def main():
    """Command-line entry point: process --image if given, otherwise every image in --folder."""
    import argparse

    parser = argparse.ArgumentParser(description='Complete Hair Counting Pipeline')
    parser.add_argument('--image', type=str, required=False)
    # NOTE(review): the default folder is a machine-specific absolute path.
    parser.add_argument('--folder', type=str, default="/Users/Admin/ScalpVision/datasets/data")
    parser.add_argument('--out', type=str, default="./complete_pipeline_out")

    # Clustering method
    parser.add_argument('--method', type=str, default='agglomerative',
                        choices=['agglomerative', 'relaxation'],
                        help="Clustering method: agglomerative (recommended) or relaxation")

    # Clustering parameters
    # NOTE(review): the default of 200 lies outside the range the help text recommends.
    parser.add_argument('--cluster-dist', type=float, default=200,
                        help="Max distance for clustering (recommended: 60-100)")
    parser.add_argument('--cluster-angle', type=float, default=30,
                        help="Max angle difference in degrees (recommended: 20-30)")

    args = parser.parse_args()
    ensure_dir(args.out)

    params = {
        'bsr_se': 5,
        'bilateral_d': 9,
        'bilateral_sigmaColor': 75,
        'bilateral_sigmaSpace': 75,
        'morph_radius': 3,
        'scales': [1.0, 0.75, 0.5],
        'hough_params': {
            'rho': 1,
            'theta': np.pi/180,
            'threshold': 33,
            'minLineLength': 30,
            'maxLineGap': 20
        },
        'avg_gap': None,
        'gap_factor': 1.25,
        'angle_tol_deg': 6,
        'min_overlap_px': 12,
        'line_thickness': 3,
        # Clustering parameters
        'clustering_method': args.method,
        'cluster_max_dist': args.cluster_dist,
        'cluster_angle_diff': args.cluster_angle,
        # Relaxation Labeling (if used)
        'rl_max_iter': 30,
        'rl_epsilon': 0.7,
        'rl_neighbor_dist': 100,
        'rl_angle_diff': 25,
        'rl_conv_threshold': 0.001
    }

    if args.image:
        result = run_complete_pipeline(args.image, args.out, params, verbose=True)
        print(f"\n{'='*60}")
        print("RESULTS:")
        print(f" Lines: {len(result['lines'])} → {len(result['restored_lines'])}")
        print(f"\n ★ HAIR COUNT: {result['num_hairs']} ★")
        print(f"{'='*60}")
    else:
        process_folder(args.folder, args.out, params)


if __name__ == "__main__":
    main()