# full / p_algorithm_precision.py
# caubetotbunggg's picture
# Upload folder using huggingface_hub
# 7bf5a8e verified
#!/usr/bin/env python3
"""
Complete Hair Counting Pipeline:
- Steps 1-7: BSR, Preprocess, Binarize, Thinning, MSLD, PLB, Merge
- Step 8: Relaxation Labeling for clustering line segments
- Step 9: Count hairs
- Visualization
"""
import os
import cv2
import numpy as np
from skimage.morphology import skeletonize
import math
from tqdm import tqdm
import glob
from collections import defaultdict
# ----------------------------- Utilities -----------------------------------
def ensure_dir(p):
    """Create directory *p* (including parents); no-op if it already exists."""
    os.makedirs(p, exist_ok=True)
# ----------------------------- BSR module ----------------------------------
def bsr_lab_opening(rgb, se_radius=6):
    """
    Bright Spot Removal via morphological opening on the L channel in LAB space.

    Parameters
    ----------
    rgb : image in BGR channel order (despite the name — cv2.imread output;
        COLOR_BGR2LAB is used).
    se_radius : radius of the elliptical structuring element for the opening.

    Returns
    -------
    BGR image whose lightness channel has been opened (bright specular spots
    suppressed) and re-stretched to the full 0-255 range.
    """
    lab = cv2.cvtColor(rgb, cv2.COLOR_BGR2LAB)
    L, A, B = cv2.split(lab)
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*se_radius+1, 2*se_radius+1))
    L_open = cv2.morphologyEx(L, cv2.MORPH_OPEN, k)
    # The original expression `L - (L - L_open)//1` is an obfuscated identity
    # for L_open (opening never exceeds its input, so the uint8 subtraction
    # cannot wrap); use L_open directly before normalizing.
    L2 = cv2.normalize(L_open, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    lab2 = cv2.merge([L2, A, B])
    out = cv2.cvtColor(lab2, cv2.COLOR_LAB2BGR)
    return out
# ----------------------------- Preprocessing --------------------------------
def preprocess(rgb, bilateral_d=9, bilateral_sigmaColor=75, bilateral_sigmaSpace=75):
    """Edge-preserving smoothing via a bilateral filter (keeps hair edges sharp)."""
    return cv2.bilateralFilter(rgb, bilateral_d, bilateral_sigmaColor, bilateral_sigmaSpace)
# ----------------------------- Binarization ---------------------------------
def binarize(img_gray, morph_radius=3):
    """Otsu threshold after Gaussian blur, cleaned by close-then-open morphology."""
    smoothed = cv2.GaussianBlur(img_gray, (5,5), 0)
    _, mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    size = 2*morph_radius+1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    # Close fills small holes, then open removes isolated specks
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    return mask
# ----------------------------- Thinning ------------------------------------
def thin_mask(binary_u8):
    """Reduce a binary mask to its one-pixel-wide skeleton (returned as 0/255 uint8)."""
    skeleton = skeletonize(binary_u8 > 0)
    return skeleton.astype(np.uint8) * 255
# ----------------------------- MSLD (multi-scale Hough) --------------------
def multi_scale_hough(edges, scale_factors=[1.0, 0.75, 0.5], hough_params=None):
    """
    Run Probabilistic HoughLinesP at multiple scales.

    The edge map is resized by each factor, HoughLinesP is run, and the
    resulting segments are mapped back to full-resolution coordinates.
    Near-duplicate segments (every coordinate within 6 px, in either
    endpoint order) are then removed.

    Parameters
    ----------
    edges : single-channel uint8 edge map (e.g. cv2.Canny output).
    scale_factors : resize factors to probe (read-only; never mutated).
    hough_params : dict with rho/theta/threshold/minLineLength/maxLineGap;
        defaults applied when None.

    Returns
    -------
    list of (x1, y1, x2, y2) tuples in full-resolution coordinates.
    """
    if hough_params is None:
        hough_params = {'rho':1, 'theta':np.pi/180, 'threshold':30, 'minLineLength':20, 'maxLineGap':20}
    lines_all = []
    h, w = edges.shape
    for s in scale_factors:
        if s != 1.0:
            small = cv2.resize(edges, (int(w*s), int(h*s)), interpolation=cv2.INTER_LINEAR)
        else:
            small = edges
        # Length/gap thresholds shrink with the image, with sane floors
        lines = cv2.HoughLinesP(small, hough_params['rho'], hough_params['theta'],
                                hough_params['threshold'],
                                minLineLength=max(8, int(hough_params['minLineLength']*s)),
                                maxLineGap=max(1, int(hough_params['maxLineGap']*s)))
        if lines is None:
            continue
        for l in lines:
            x1,y1,x2,y2 = l[0]
            if s != 1.0:
                # Map back to full-resolution coordinates
                x1 = int(round(x1 / s)); y1 = int(round(y1 / s))
                x2 = int(round(x2 / s)); y2 = int(round(y2 / s))
            lines_all.append((x1,y1,x2,y2))
    # Deduplicate
    unique = []
    def close(a,b, tol=6):
        return abs(a[0]-b[0])<=tol and abs(a[1]-b[1])<=tol and abs(a[2]-b[2])<=tol and abs(a[3]-b[3])<=tol
    for l in lines_all:
        # BUG FIX: the reversed segment of (x1,y1,x2,y2) is (x2,y2,x1,y1);
        # the original compared against u[::-1] == (y2,x2,y1,x1), which swaps
        # x and y coordinates and almost never matched.
        if not any(close(l, u) or close(l, (u[2], u[3], u[0], u[1])) for u in unique):
            unique.append(l)
    return unique
# ----------------------------- PLB: Parallel Line Bundling -------------------
def line_to_abcline(line):
    """
    Convert a segment to normalized implicit-line coefficients (a, b, c)
    with a*x + b*y + c = 0 and a**2 + b**2 == 1.

    Returns None for a degenerate (zero-length) segment.
    """
    x1, y1, x2, y2 = line
    dx, dy = x2 - x1, y2 - y1
    if dx == 0 and dy == 0:
        return None
    length = math.hypot(dy, dx)
    a = dy / length
    b = -dx / length
    c = -(a * x1 + b * y1)
    return (a, b, c)
def line_angle(line):
    """Orientation of the segment in radians, in (-pi, pi] (atan2 convention)."""
    x1, y1, x2, y2 = line
    return math.atan2(y2 - y1, x2 - x1)
def distance_between_parallel_lines(l1_abc, l2_abc):
    """
    Perpendicular distance between two (near-)parallel lines given in
    normalized implicit form (a, b, c): a*x + b*y + c = 0, a**2 + b**2 == 1.

    The normal (a, b) produced by line_to_abcline depends on the direction
    the segment was traced, so two parallel lines may carry opposite
    normals; in that case |c1 - c2| is wrong (it measured the distance to a
    mirrored line).  Negate the second line's offset when the normals point
    opposite ways so the offsets are comparable.
    """
    a1, b1, c1 = l1_abc
    a2, b2, c2 = l2_abc
    if a1 * a2 + b1 * b2 < 0:
        # Normals anti-parallel: flip line 2's orientation
        c2 = -c2
    return abs(c1 - c2)
def seg_projection_on_line(seg, line_dir):
    """
    Project both endpoints of *seg* onto the unit direction
    (cos(line_dir), sin(line_dir)) and return the (low, high) scalar interval.
    """
    x1, y1, x2, y2 = seg
    ux, uy = math.cos(line_dir), math.sin(line_dir)
    proj_a = x1 * ux + y1 * uy
    proj_b = x2 * ux + y2 * uy
    return (proj_a, proj_b) if proj_a <= proj_b else (proj_b, proj_a)
def overlap_segment_length(a1, b1, a2, b2):
    """Length of the overlap of intervals [a1, b1] and [a2, b2]; 0 if disjoint."""
    return max(0.0, min(b1, b2) - max(a1, a2))
def plb_restore(lines, avg_gap=None, gap_thresh_factor=1.15, angle_tol_deg=6, min_overlap_px=10):
    """
    Parallel Line Bundling to restore concealed hairs.

    For every pair of near-parallel, spatially overlapping segments whose
    perpendicular gap is close to the typical inter-line gap, a new segment
    is synthesized along the pair's midline and appended to the output.

    Parameters
    ----------
    lines : list of (x1, y1, x2, y2) segment tuples.
    avg_gap : typical perpendicular gap in pixels; when None it is estimated
        as the median of observed gaps (8.0 px fallback if none observed).
    gap_thresh_factor : scales avg_gap into the accepted gap window.
    angle_tol_deg : max orientation difference for a pair to count as parallel.
    min_overlap_px : minimum projected overlap required between the pair.

    Returns
    -------
    list: the input segments plus any synthesized midline segments.
    """
    out_lines = list(lines)
    if len(lines) < 2:
        return out_lines
    gaps = []
    abc_list = []
    # Implicit-line coefficients per segment (None for degenerate segments)
    for l in lines:
        abc = line_to_abcline(l)
        if abc is None: abc_list.append(None)
        else: abc_list.append(abc)
    # Pass 1: collect perpendicular gaps between near-parallel pairs to
    # estimate the typical hair-to-hair spacing.
    for i in range(len(lines)):
        for j in range(i+1, len(lines)):
            if abc_list[i] is None or abc_list[j] is None: continue
            ang_i = line_angle(lines[i]); ang_j = line_angle(lines[j])
            # NOTE(review): this only folds the case ang_i - ang_j ~= -pi;
            # the symmetric +pi case is not handled — confirm intent.
            if abs((ang_i-ang_j)+math.pi) < 0.001: ang_j += math.pi
            angdiff = abs((ang_i - ang_j))
            angdiff = min(angdiff, abs(2*math.pi - angdiff))
            if angdiff > math.radians(angle_tol_deg):
                continue
            d = distance_between_parallel_lines(abc_list[i], abc_list[j])
            if d <= 0.5:
                # Essentially the same line — not a gap
                continue
            gaps.append(d)
    # Choose avg_gap: a caller-supplied value wins; otherwise the median of
    # observed gaps; 8.0 px is the last-resort default when none were seen.
    if len(gaps)>0:
        if avg_gap is None:
            avg_gap = np.median(gaps)
    else:
        avg_gap = avg_gap or 8.0
    # Pass 2: pairwise restore — synthesize a midline for qualifying pairs.
    for i in range(len(lines)):
        for j in range(i+1, len(lines)):
            if abc_list[i] is None or abc_list[j] is None: continue
            ang_i = line_angle(lines[i]); ang_j = line_angle(lines[j])
            angdiff = abs((ang_i - ang_j))
            angdiff = min(angdiff, abs(2*math.pi - angdiff))
            if angdiff > math.radians(angle_tol_deg):
                continue
            d = distance_between_parallel_lines(abc_list[i], abc_list[j])
            # Accept only gaps within [0.7, 2.5] x (avg_gap * gap_thresh_factor)
            if d < avg_gap * gap_thresh_factor * 0.7 or d > avg_gap * gap_thresh_factor * 2.5:
                continue
            dir_ang = 0.5*(ang_i + ang_j)
            a1,b1 = seg_projection_on_line(lines[i], dir_ang)
            a2,b2 = seg_projection_on_line(lines[j], dir_ang)
            ov = overlap_segment_length(a1,b1,a2,b2)
            if ov < min_overlap_px:
                continue
            # Shared projected span of the two segments
            mid_start = (max(a1,a2))
            mid_end = (min(b1,b2))
            def point_on_seg_by_proj(seg, proj_val, dir_ang):
                # Point on `seg` whose projection onto dir_ang equals
                # proj_val, clamped to the segment's endpoints.
                x1,y1,x2,y2 = seg
                vx = math.cos(dir_ang); vy = math.sin(dir_ang)
                p1 = x1*vx + y1*vy; p2 = x2*vx + y2*vy
                if p2==p1:
                    alpha = 0.0
                else:
                    alpha = (proj_val - p1) / (p2 - p1)
                alpha = max(0.0, min(1.0, alpha))
                return (int(round(x1 + alpha * (x2-x1))), int(round(y1 + alpha * (y2-y1))))
            p_start_i = point_on_seg_by_proj(lines[i], mid_start, dir_ang)
            p_start_j = point_on_seg_by_proj(lines[j], mid_start, dir_ang)
            p_end_i = point_on_seg_by_proj(lines[i], mid_end, dir_ang)
            p_end_j = point_on_seg_by_proj(lines[j], mid_end, dir_ang)
            # Midline endpoints: average the matched points on each segment
            mid_start_pt = (int(round(0.5*(p_start_i[0]+p_start_j[0]))), int(round(0.5*(p_start_i[1]+p_start_j[1]))))
            mid_end_pt = (int(round(0.5*(p_end_i[0]+p_end_j[0]))), int(round(0.5*(p_end_i[1]+p_end_j[1]))))
            out_lines.append((mid_start_pt[0], mid_start_pt[1], mid_end_pt[0], mid_end_pt[1]))
    return out_lines
# ----------------------------- lines -> mask --------------------------------
def rasterize_lines_to_mask(lines, shape, thickness=3):
    """Draw segments into a blank uint8 mask and dilate slightly to thicken them."""
    canvas = np.zeros(shape[:2], dtype=np.uint8)
    for (x1, y1, x2, y2) in lines:
        cv2.line(canvas, (x1, y1), (x2, y2), color=255, thickness=thickness)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.dilate(canvas, kernel, iterations=1)
# ----------------------------- Relaxation Labeling (IMPROVED) ---------------
def line_to_polar(line):
    """
    Convert a line segment to polar form (rho, theta).

    theta is the segment orientation normalized to [0, pi]; rho is the
    distance from the origin to the infinite line through the segment
    (or simply the distance to the point, for a degenerate segment).
    """
    x1, y1, x2, y2 = line
    theta = math.atan2(y2 - y1, x2 - x1)
    if theta < 0:
        theta += math.pi
    dx, dy = x2 - x1, y2 - y1
    if abs(dx) < 1e-6 and abs(dy) < 1e-6:
        # Degenerate segment: fall back to the point's distance from origin
        return math.hypot(x1, y1), theta
    # Hesse normal form distance from (0, 0) to the infinite line
    rho = abs(x2 * y1 - y2 * x1) / math.hypot(dx, dy)
    return rho, theta
def get_line_midpoint(line):
    """Midpoint (x, y) of the segment, as floats."""
    x1, y1, x2, y2 = line
    return (x1 + x2) / 2.0, (y1 + y2) / 2.0
def get_line_length(line):
    """Euclidean length of the segment."""
    x1, y1, x2, y2 = line
    dx = x2 - x1
    dy = y2 - y1
    return math.hypot(dx, dy)
def distance_between_points(p1, p2):
    """Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = p1, p2
    return math.hypot(ax - bx, ay - by)
def point_to_line_distance(point, line):
    """
    Minimum distance from *point* to the line *segment* (x1,y1)-(x2,y2).

    The point is projected onto the segment's supporting line and the
    projection parameter is clamped to [0, 1] so the answer is measured to
    the nearest point that actually lies on the segment.
    """
    x1, y1, x2, y2 = line
    px, py = point
    dx, dy = x2 - x1, y2 - y1
    seg_len_sq = dx * dx + dy * dy
    if seg_len_sq == 0:
        # Degenerate segment: plain point-to-point distance
        return math.hypot(px - x1, py - y1)
    t = ((px - x1) * dx + (py - y1) * dy) / seg_len_sq
    t = min(1, max(0, t))
    return math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))
def line_to_line_distance(line1, line2):
    """
    Minimum distance between two line segments (0.0 when they intersect).

    The original implementation only compared each segment's endpoints to
    the other segment, so two long segments crossing in an 'X' reported a
    large positive distance — breaking neighbor detection.  A standard
    cross-product orientation test now detects a proper crossing and
    returns 0.0; touching/collinear-overlap cases fall through to the
    endpoint distances, which correctly yield 0.0 there.  Self-contained
    (no sibling helpers needed).
    """
    ax1, ay1, ax2, ay2 = line1
    bx1, by1, bx2, by2 = line2

    def _cross(ox, oy, px, py, qx, qy):
        # z-component of (p - o) x (q - o); its sign gives the orientation
        return (px - ox) * (qy - oy) - (py - oy) * (qx - ox)

    d1 = _cross(bx1, by1, bx2, by2, ax1, ay1)
    d2 = _cross(bx1, by1, bx2, by2, ax2, ay2)
    d3 = _cross(ax1, ay1, ax2, ay2, bx1, by1)
    d4 = _cross(ax1, ay1, ax2, ay2, bx2, by2)
    if d1 * d2 < 0 and d3 * d4 < 0:
        return 0.0  # proper interior crossing

    def _pt_seg(px, py, x1, y1, x2, y2):
        # Distance from point (px, py) to segment (x1,y1)-(x2,y2)
        dx, dy = x2 - x1, y2 - y1
        if dx == 0 and dy == 0:
            return math.hypot(px - x1, py - y1)
        t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy)))
        return math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))

    return min(
        _pt_seg(ax1, ay1, bx1, by1, bx2, by2),
        _pt_seg(ax2, ay2, bx1, by1, bx2, by2),
        _pt_seg(bx1, by1, ax1, ay1, ax2, ay2),
        _pt_seg(bx2, by2, ax1, ay1, ax2, ay2),
    )
def find_neighbors_improved(lines, max_distance=100, max_angle_diff_deg=30):
    """
    Build an adjacency list over segments.

    Two segments are neighbors when their orientations differ by at most
    *max_angle_diff_deg* (direction-insensitive: opposite traversal
    directions count as the same orientation) and their segment-to-segment
    distance is at most *max_distance*.

    Returns a list of sets: neighbors[i] holds the indices adjacent to i.
    """
    count = len(lines)
    adjacency = [set() for _ in range(count)]
    orientations = [line_angle(seg) for seg in lines]
    angle_limit = math.radians(max_angle_diff_deg)
    for i in range(count):
        for j in range(i + 1, count):
            delta = abs(orientations[i] - orientations[j])
            # Fold into [0, pi], then make direction-insensitive
            if delta > math.pi:
                delta = 2 * math.pi - delta
            delta = min(delta, math.pi - delta)
            if delta > angle_limit:
                continue
            # Spatial check only after the cheap angle filter passes
            if line_to_line_distance(lines[i], lines[j]) <= max_distance:
                adjacency[i].add(j)
                adjacency[j].add(i)
    return adjacency
def agglomerative_clustering(lines, max_cluster_distance=80, max_angle_diff_deg=25):
    """
    Simple agglomerative clustering based on:
    - Lines that are close in space
    - Lines that have similar orientation
    This is more robust than Relaxation Labeling for this problem.

    Parameters
    ----------
    lines : list of (x1, y1, x2, y2) segments.
    max_cluster_distance : max center-to-center distance (px) for a merge.
    max_angle_diff_deg : max mean-orientation difference (deg) for a merge.

    Returns
    -------
    (labels, num_clusters) : per-line cluster index and the cluster count.
    Empty input yields ([], 0).
    """
    n = len(lines)
    if n == 0:
        return [], 0
    # Initialize: each line is its own cluster
    clusters = [[i] for i in range(n)]
    # Precompute angles
    angles = [line_angle(line) for line in lines]
    def cluster_angle(cluster_indices):
        """Circular mean of the orientations of the lines in a cluster."""
        cluster_angles = [angles[i] for i in cluster_indices]
        # Use circular mean for angles
        x = sum(math.cos(a) for a in cluster_angles)
        y = sum(math.sin(a) for a in cluster_angles)
        return math.atan2(y, x)
    def cluster_center(cluster_indices):
        """Center point of all line midpoints in cluster"""
        midpoints = [get_line_midpoint(lines[i]) for i in cluster_indices]
        cx = sum(p[0] for p in midpoints) / len(midpoints)
        cy = sum(p[1] for p in midpoints) / len(midpoints)
        return (cx, cy)
    def cluster_distance(c1, c2):
        """
        Distance between two clusters: spatial distance between their
        centers, or +inf when the mean orientations differ too much
        (those pairs must never merge).
        """
        center1 = cluster_center(c1)
        center2 = cluster_center(c2)
        spatial_dist = distance_between_points(center1, center2)
        angle1 = cluster_angle(c1)
        angle2 = cluster_angle(c2)
        angle_diff = abs(angle1 - angle2)
        # Fold to a direction-insensitive difference
        angle_diff = min(angle_diff, 2 * math.pi - angle_diff, math.pi - angle_diff)
        # Combined metric: spatial distance + angle penalty
        if angle_diff > math.radians(max_angle_diff_deg):
            return float('inf')  # Don't merge if angles too different
        return spatial_dist
    # Greedy agglomerative merging: repeatedly merge the single closest
    # pair until no pair is closer than max_cluster_distance.
    # NOTE(review): roughly O(n^3) in the number of segments — acceptable
    # for the segment counts produced by this pipeline.
    changed = True
    while changed and len(clusters) > 1:
        changed = False
        best_merge = None
        best_dist = max_cluster_distance
        # Find best pair to merge
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                dist = cluster_distance(clusters[i], clusters[j])
                if dist < best_dist:
                    best_dist = dist
                    best_merge = (i, j)
                    changed = True
        # Merge best pair
        if best_merge:
            i, j = best_merge
            clusters[i].extend(clusters[j])
            del clusters[j]
    # Assign labels
    labels = [-1] * n
    for cluster_id, cluster in enumerate(clusters):
        for line_idx in cluster:
            labels[line_idx] = cluster_id
    return labels, len(clusters)
def relaxation_labeling_improved(lines, max_iterations=30, epsilon=0.7,
                                 max_neighbor_dist=100, max_angle_diff_deg=25,
                                 convergence_threshold=0.001):
    """
    Improved Relaxation Labeling with better parameters and compatibility function.

    Labels are initialized from the connected components of the neighbor
    graph, then per-line label probabilities are iteratively updated using
    a compatibility score that mixes angle similarity (weight *epsilon*)
    and midpoint-distance similarity (weight 1 - epsilon).

    Returns (labels, num_labels): per-line cluster index and cluster count.
    """
    n = len(lines)
    if n == 0:
        return [], 0
    # Use improved neighbor finding
    neighbors = find_neighbors_improved(lines, max_distance=max_neighbor_dist,
                                        max_angle_diff_deg=max_angle_diff_deg)
    # Precompute angles
    angles = [line_angle(line) for line in lines]
    # Precompute midpoints
    midpoints = [get_line_midpoint(line) for line in lines]
    # Initialize: Start with connected components as initial labels
    # This gives a better initialization than one-label-per-line
    visited = [False] * n
    initial_labels = [-1] * n
    current_label = 0
    for start in range(n):
        if visited[start]:
            continue
        # BFS to find connected component
        # NOTE(review): list.pop(0) is O(n); collections.deque would be
        # O(1), but component sizes here are small.
        queue = [start]
        visited[start] = True
        while queue:
            i = queue.pop(0)
            initial_labels[i] = current_label
            for j in neighbors[i]:
                if not visited[j]:
                    visited[j] = True
                    queue.append(j)
        current_label += 1
    num_labels = current_label
    if num_labels == 0:
        # No neighbors found, each line is separate
        # (only reachable for n == 0, which is handled above)
        return list(range(n)), n
    # Initialize probability matrix: each line starts fully committed to
    # its connected-component label
    p = np.zeros((n, num_labels), dtype=np.float64)
    for i in range(n):
        if initial_labels[i] >= 0:
            p[i, initial_labels[i]] = 1.0
        else:
            p[i, :] = 1.0 / num_labels
    # Iterative relaxation
    for iteration in range(max_iterations):
        q = np.zeros((n, num_labels), dtype=np.float64)
        for i in range(n):
            for label_i in range(num_labels):
                support = 0.0
                for j in neighbors[i]:
                    # Compute compatibility based on angle similarity
                    angle_diff = abs(angles[i] - angles[j])
                    angle_diff = min(angle_diff, 2 * math.pi - angle_diff, math.pi - angle_diff)
                    angle_sim = math.cos(angle_diff)  # 1 if same, 0 if perpendicular
                    # Distance similarity
                    dist = distance_between_points(midpoints[i], midpoints[j])
                    dist_sim = math.exp(-dist / max_neighbor_dist)  # Exponential decay
                    # Combined compatibility
                    compatibility = epsilon * angle_sim + (1 - epsilon) * dist_sim
                    # Accumulate support for same label
                    support += compatibility * p[j, label_i]
                q[i, label_i] = support
        # Update probabilities (standard relaxation update), then renormalize
        p_new = np.zeros_like(p)
        for i in range(n):
            for label in range(num_labels):
                p_new[i, label] = p[i, label] * (1 + q[i, label])
            # Normalize
            row_sum = np.sum(p_new[i, :])
            if row_sum > 1e-10:
                p_new[i, :] /= row_sum
            else:
                p_new[i, :] = p[i, :]
        # Check convergence
        diff = np.abs(p_new - p).max()
        p = p_new
        if diff < convergence_threshold:
            break
    # Assign final labels
    final_labels = np.argmax(p, axis=1)
    # Renumber to consecutive
    unique_labels = np.unique(final_labels)
    label_mapping = {old: new for new, old in enumerate(unique_labels)}
    final_labels = np.array([label_mapping[label] for label in final_labels])
    return final_labels.tolist(), len(unique_labels)
# ----------------------------- Update main pipeline to use improved version ---
def visualize_labeled_lines(rgb, lines, labels, title="Labeled Hairs"):
    """
    Draw each segment in a color determined by its cluster label, plus a
    title overlay.  Colors are reproducible (fixed RNG seed).
    """
    canvas = rgb.copy()
    ordered_labels = sorted(set(labels))
    # One random BGR color per label; the randint call order matches the
    # seeded sequence so colors are deterministic
    np.random.seed(42)
    palette = {}
    for lab in ordered_labels:
        palette[lab] = (
            np.random.randint(50, 255),
            np.random.randint(50, 255),
            np.random.randint(50, 255)
        )
    for seg, lab in zip(lines, labels):
        x1, y1, x2, y2 = seg
        cv2.line(canvas, (x1, y1), (x2, y2), palette[lab], 2)
    # White outline under a dark stroke keeps the title legible everywhere
    cv2.putText(canvas, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1.0, (255, 255, 255), 3)
    cv2.putText(canvas, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1.0, (0, 0, 0), 1)
    return canvas
def run_complete_pipeline(image_path, out_dir, params, verbose=True):
    """
    Run the full hair-counting pipeline (steps 1-9) on one image.

    Steps: BSR -> bilateral preprocess -> Otsu binarization -> skeleton ->
    multi-scale Hough -> parallel-line bundling -> mask merge -> clustering
    -> visualization.  All intermediate images are written into *out_dir*.

    Parameters
    ----------
    image_path : path of the image to process (read BGR via cv2.imread).
    out_dir : existing directory that receives the output images.
    params : dict of tuning parameters; every key optional (defaults below).
    verbose : when True, print per-step progress.

    Returns
    -------
    dict with the detected lines, restored lines, labels, hair count,
    masks, and the paths of the saved visualization images.

    Raises
    ------
    RuntimeError : when the image cannot be opened.
    """
    imname = os.path.basename(image_path)
    rgb = cv2.imread(image_path)
    if rgb is None:
        raise RuntimeError("Cannot open image: " + image_path)
    if verbose:
        print(f"\nProcessing: {imname}")
        print("="*60)
        print(" Step 1: BSR (Bright Spot Removal)...")
    bsr = bsr_lab_opening(rgb, se_radius=params.get('bsr_se', 6))
    if verbose:
        print(" Step 2: Preprocessing (Bilateral Filter)...")
    prep = preprocess(bsr,
                      bilateral_d=params.get('bilateral_d',9),
                      bilateral_sigmaColor=params.get('bilateral_sigmaColor',75),
                      bilateral_sigmaSpace=params.get('bilateral_sigmaSpace',75))
    if verbose:
        print(" Step 3: Binarization (Otsu + Morphology)...")
    gray = cv2.cvtColor(prep, cv2.COLOR_BGR2GRAY)
    binary = binarize(gray, morph_radius=params.get('morph_radius',3))
    if verbose:
        print(" Step 4: Thinning (Skeletonize)...")
    skel = thin_mask(binary)
    if verbose:
        print(" Step 5: MSLD (Multi-Scale Line Detection)...")
    # Hough runs on Canny edges of the preprocessed gray image, not on the
    # binary mask
    edges = cv2.Canny(gray, 50, 150)
    lines = multi_scale_hough(edges, scale_factors=params.get('scales',[1.0,0.75,0.5]),
                              hough_params=params.get('hough_params', None))
    if verbose:
        print(f" - Detected {len(lines)} lines")
    if verbose:
        print(" Step 6: PLB (Parallel Line Bundling)...")
    restored_lines = plb_restore(lines, avg_gap=params.get('avg_gap', None),
                                 gap_thresh_factor=params.get('gap_factor', 1.25),
                                 angle_tol_deg=params.get('angle_tol_deg', 6),
                                 min_overlap_px=params.get('min_overlap_px', 12))
    if verbose:
        print(f" - Restored to {len(restored_lines)} lines")
        print(f" - Concealed hairs recovered: {len(restored_lines)-len(lines)}")
    if verbose:
        print(" Step 7: Merge lines mask with binary...")
    lines_mask = rasterize_lines_to_mask(restored_lines, rgb.shape, thickness=params.get('line_thickness',3))
    merged_foreground = cv2.bitwise_or(binary, lines_mask)
    if verbose:
        print(" Step 8: Clustering line segments into hairs...")
    # Choose clustering method
    clustering_method = params.get('clustering_method', 'agglomerative')  # 'agglomerative' or 'relaxation'
    if clustering_method == 'agglomerative':
        if verbose:
            print(" - Using Agglomerative Clustering...")
        labels, num_hairs = agglomerative_clustering(
            restored_lines,
            max_cluster_distance=params.get('cluster_max_dist', 80),
            max_angle_diff_deg=params.get('cluster_angle_diff', 25)
        )
    else:
        if verbose:
            print(" - Using Relaxation Labeling...")
        labels, num_hairs = relaxation_labeling_improved(
            restored_lines,
            max_iterations=params.get('rl_max_iter', 30),
            epsilon=params.get('rl_epsilon', 0.7),
            max_neighbor_dist=params.get('rl_neighbor_dist', 100),
            max_angle_diff_deg=params.get('rl_angle_diff', 25),
            convergence_threshold=params.get('rl_conv_threshold', 0.001)
        )
    if verbose:
        print(f" - Clustered into {num_hairs} hairs")
        print("="*60)
    # ========== VISUALIZATION ========== (keep same as before)
    if verbose:
        print(" Creating visualizations...")
    lines_vis = rgb.copy()
    for (x1,y1,x2,y2) in lines:
        cv2.line(lines_vis, (x1,y1), (x2,y2), (0,255,0), 2)
    restored_vis = rgb.copy()
    for (x1,y1,x2,y2) in restored_lines:
        cv2.line(restored_vis, (x1,y1), (x2,y2), (0,255,255), 2)
    labeled_vis = visualize_labeled_lines(rgb, restored_lines, labels,
                                          f"Hairs: {num_hairs}")
    # Gray intermediates converted to BGR so they can be tiled together
    binary_bgr = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
    skel_bgr = cv2.cvtColor(skel, cv2.COLOR_GRAY2BGR)
    lines_mask_bgr = cv2.cvtColor(lines_mask, cv2.COLOR_GRAY2BGR)
    merged_bgr = cv2.cvtColor(merged_foreground, cv2.COLOR_GRAY2BGR)
    edges_bgr = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
    target_size = (512, 512)
    def resize_and_label(img, text):
        # Uniform tile size with a white-over-dark caption
        resized = cv2.resize(img, target_size)
        cv2.putText(resized, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    0.7, (255, 255, 255), 2)
        cv2.putText(resized, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    0.7, (0, 0, 0), 1)
        return resized
    rgb_res = resize_and_label(rgb, "1. Original")
    bsr_res = resize_and_label(bsr, "2. BSR")
    prep_res = resize_and_label(prep, "3. Preprocessed")
    binary_res = resize_and_label(binary_bgr, "4. Binary")
    skel_res = resize_and_label(skel_bgr, "5. Skeleton")
    edges_res = resize_and_label(edges_bgr, "6. Edges")
    lines_vis_res = resize_and_label(lines_vis, f"7. Lines ({len(lines)})")
    restored_vis_res = resize_and_label(restored_vis, f"8. PLB ({len(restored_lines)})")
    lines_mask_res = resize_and_label(lines_mask_bgr, "9. Lines Mask")
    merged_res = resize_and_label(merged_bgr, "10. Merged")
    labeled_res = resize_and_label(labeled_vis, f"11. Labeled ({num_hairs})")
    # Text panel summarizing counts and parameters
    count_img = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
    count_img[:] = (40, 40, 40)
    method_name = "Agglomerative" if clustering_method == 'agglomerative' else "Relaxation"
    info_text = [
        f"Image: {imname}",
        "",
        f"Lines detected: {len(lines)}",
        f"After PLB: {len(restored_lines)}",
        f"Recovered: +{len(restored_lines)-len(lines)}",
        "",
        f"Hair Count: {num_hairs}",
        "",
        f"Method: {method_name}",
        f"Max distance: {params.get('cluster_max_dist', 80)}",
        f"Max angle diff: {params.get('cluster_angle_diff', 25)}°"
    ]
    y_offset = 50
    for text in info_text:
        cv2.putText(count_img, text, (20, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        y_offset += 35
    # 4x3 mosaic of all pipeline stages
    row1 = np.hstack([rgb_res, bsr_res, prep_res, binary_res])
    row2 = np.hstack([skel_res, edges_res, lines_vis_res, restored_vis_res])
    row3 = np.hstack([lines_mask_res, merged_res, labeled_res, count_img])
    combined = np.vstack([row1, row2, row3])
    out_vis_path = os.path.join(out_dir, "complete_pipeline_" + imname)
    cv2.imwrite(out_vis_path, combined)
    labeled_path = os.path.join(out_dir, "labeled_" + imname)
    cv2.imwrite(labeled_path, labeled_vis)
    cv2.imwrite(os.path.join(out_dir, "binary_" + imname), binary)
    cv2.imwrite(os.path.join(out_dir, "lines_mask_" + imname), lines_mask)
    cv2.imwrite(os.path.join(out_dir, "merged_foreground_" + imname), merged_foreground)
    if verbose:
        print(f" ✓ Visualization saved: {out_vis_path}")
        print(f" ✓ Labeled image saved: {labeled_path}")
    return {
        'image': image_path,
        'original': rgb,
        'lines': lines,
        'restored_lines': restored_lines,
        'labels': labels,
        'num_hairs': num_hairs,
        'binary': binary,
        'merged_foreground': merged_foreground,
        'vis_path': out_vis_path,
        'labeled_path': labeled_path
    }
# ----------------------------- Batch processing --------------------------------
def process_folder(input_folder, output_folder, params):
    """
    Process every image in *input_folder* through the complete pipeline.

    Writes all per-image outputs plus a hair_count_summary.csv into
    *output_folder*, prints a summary, and returns (results, failed) where
    *results* is the list of per-image result dicts and *failed* a list of
    (path, error-message) pairs.  A failure on one image is recorded and
    does not stop the batch.
    """
    ensure_dir(output_folder)
    # Find all image files (both lower- and upper-case extensions)
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.tif']
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(input_folder, ext)))
        image_files.extend(glob.glob(os.path.join(input_folder, ext.upper())))
    image_files = sorted(set(image_files))  # Remove duplicates and sort
    if len(image_files) == 0:
        print(f"No images found in {input_folder}")
        return [], []
    print(f"Found {len(image_files)} images in {input_folder}")
    print(f"Output will be saved to: {output_folder}\n")
    results = []
    failed = []
    # Process each image with progress bar
    for img_path in tqdm(image_files, desc="Processing images"):
        try:
            result = run_complete_pipeline(img_path, output_folder, params, verbose=False)
            results.append(result)
            # BUG FIX: the two counts used to be concatenated with no
            # separator ("1215 lines" instead of "12 → 15 lines").
            print(f"✓ {os.path.basename(img_path)}: "
                  f"{len(result['lines'])} → {len(result['restored_lines'])} lines → "
                  f"{result['num_hairs']} hairs")
        except Exception as e:
            failed.append((img_path, str(e)))
            print(f"✗ {os.path.basename(img_path)}: ERROR - {str(e)}")
    # Summary statistics
    print("\n" + "="*80)
    print(f"SUMMARY:")
    print(f" Total images: {len(image_files)}")
    print(f" Successfully processed: {len(results)}")
    print(f" Failed: {len(failed)}")
    if len(results) > 0:
        total_original = sum(len(r['lines']) for r in results)
        total_restored = sum(len(r['restored_lines']) for r in results)
        total_hairs = sum(r['num_hairs'] for r in results)
        print(f"\n Total lines detected: {total_original}")
        print(f" Total after PLB restoration: {total_restored}")
        print(f" Total concealed hairs recovered: {total_restored - total_original}")
        print(f"\n TOTAL HAIR COUNT: {total_hairs}")
        print(f" Average hairs per image: {total_hairs / len(results):.1f}")
        # Hair count distribution
        hair_counts = [r['num_hairs'] for r in results]
        print(f"\n Hair count statistics:")
        print(f" Min: {min(hair_counts)}")
        print(f" Max: {max(hair_counts)}")
        print(f" Mean: {np.mean(hair_counts):.1f}")
        print(f" Median: {np.median(hair_counts):.1f}")
        print(f" Std: {np.std(hair_counts):.1f}")
    if len(failed) > 0:
        print(f"\n Failed images:")
        for path, error in failed:
            print(f" - {os.path.basename(path)}: {error}")
    print("="*80)
    # Save per-image counts to CSV
    if len(results) > 0:
        import csv
        csv_path = os.path.join(output_folder, "hair_count_summary.csv")
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Image', 'Lines_Detected', 'Lines_After_PLB',
                             'Lines_Recovered', 'Hair_Count'])
            for r in results:
                writer.writerow([
                    os.path.basename(r['image']),
                    len(r['lines']),
                    len(r['restored_lines']),
                    len(r['restored_lines']) - len(r['lines']),
                    r['num_hairs']
                ])
        print(f"\n✓ Summary saved to: {csv_path}")
    return results, failed
if __name__ == "__main__":
    import argparse
    # CLI entry point: either a single image (--image) or a whole folder.
    parser = argparse.ArgumentParser(description='Complete Hair Counting Pipeline')
    parser.add_argument('--image', type=str, required=False)
    parser.add_argument('--folder', type=str, default="/Users/Admin/ScalpVision/datasets/data")
    parser.add_argument('--out', type=str, default="./complete_pipeline_out")
    # Clustering method
    parser.add_argument('--method', type=str, default='agglomerative',
                        choices=['agglomerative', 'relaxation'],
                        help="Clustering method: agglomerative (recommended) or relaxation")
    # Clustering parameters
    # NOTE(review): the default (200) is far above the 60-100 range the
    # help text recommends — confirm which is intended.
    parser.add_argument('--cluster-dist', type=float, default=200,
                        help="Max distance for clustering (recommended: 60-100)")
    parser.add_argument('--cluster-angle', type=float, default=30,
                        help="Max angle difference in degrees (recommended: 20-30)")
    args = parser.parse_args()
    ensure_dir(args.out)
    # Central parameter dict consumed by run_complete_pipeline/process_folder
    params = {
        'bsr_se': 5,
        'bilateral_d': 9,
        'bilateral_sigmaColor': 75,
        'bilateral_sigmaSpace': 75,
        'morph_radius': 3,
        'scales': [1.0, 0.75, 0.5],
        'hough_params': {
            'rho': 1,
            'theta': np.pi/180,
            'threshold': 33,
            'minLineLength': 30,
            'maxLineGap': 20
        },
        'avg_gap': None,
        'gap_factor': 1.25,
        'angle_tol_deg': 6,
        'min_overlap_px': 12,
        'line_thickness': 3,
        # Clustering parameters
        'clustering_method': args.method,
        'cluster_max_dist': args.cluster_dist,
        'cluster_angle_diff': args.cluster_angle,
        # Relaxation Labeling (if used)
        'rl_max_iter': 30,
        'rl_epsilon': 0.7,
        'rl_neighbor_dist': 100,
        'rl_angle_diff': 25,
        'rl_conv_threshold': 0.001
    }
    if args.image:
        result = run_complete_pipeline(args.image, args.out, params, verbose=True)
        print(f"\n{'='*60}")
        print(f"RESULTS:")
        # BUG FIX: the two counts used to be concatenated with no separator.
        print(f" Lines: {len(result['lines'])} → {len(result['restored_lines'])}")
        print(f"\n ★ HAIR COUNT: {result['num_hairs']} ★")
        print(f"{'='*60}")
    else:
        results, failed = process_folder(args.folder, args.out, params)