|
|
import os |
|
|
import argparse |
|
|
import csv |
|
|
import sys |
|
|
import glob |
|
|
import json |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import warnings |
|
|
from tqdm import tqdm |
|
|
from PIL import Image |
|
|
from skimage import io, transform |
|
|
from torch.autograd import Variable |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from segmentation.data_loader import RescaleT, ToTensorLab, SalObjDataset |
|
|
from segmentation.model import U2NET, U2NETP |
|
|
except ImportError: |
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'segmentation')) |
|
|
from data_loader import RescaleT, ToTensorLab, SalObjDataset |
|
|
from model import U2NET, U2NETP |
|
|
|
|
|
from segment_anything import sam_model_registry, SamPredictor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make modules located in "<current working directory>/alopecia" importable.
# Note that this resolves against os.getcwd(), not against this file's location
# or the pipeline's root_dir. NOTE(review): nothing in the visible part of this
# file imports from that directory after this point — confirm it is still needed.
sys.path.append(os.path.join(os.getcwd(), 'alopecia'))
|
|
|
|
|
|
|
|
|
|
|
class ScalpPipeline:
    """Scalp-image analysis pipeline made of on-disk stages.

    Stages, in the order ``run_pipeline`` executes them:

    1. ``run_u2net_segmentation``  - U2NET masks for every file in ``data_dir``.
    2. ``generate_sam_guides``     - prompt points / boxes derived from those masks.
    3. ``run_sam_prediction``      - SAM masks driven by the prompt points.
    4. ``create_ensemble_mask``    - pixel-wise AND of the U2NET and SAM masks.
    5. ``calculate_hair_thickness``- per-hair width measurements.
    6. ``calculate_hair_count``    - watershed-based hair counting.

    Each stage reads its inputs from, and writes its outputs to, directories
    below ``root_dir``, so stages can also be run individually.
    """

    # Extensions tried, in this order, when locating a source image by stem.
    _DATA_EXTS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir=".", pixel_ratio=2.54):
        """Resolve all paths under ``root_dir`` and create the output directories.

        Args:
            root_dir: project root; every input/output path is derived from it.
            pixel_ratio: factor applied to pixel distances in
                ``calculate_hair_thickness`` (the CLI describes it as a
                pixel-to-micrometer ratio).
        """
        self.root_dir = os.path.abspath(root_dir)
        self.pixel_ratio = pixel_ratio
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Input images and per-stage output directories.
        self.data_dir = os.path.join(self.root_dir, "datasets", "data")
        self.seg_train_dir = os.path.join(self.root_dir, "datasets", "seg_train")
        self.sam_val_dir = os.path.join(self.root_dir, "prediction", "sam_result", "sam_val")
        self.ensemble_val_dir = os.path.join(self.root_dir, "prediction", "ensemble_result", "ensemble_val")
        self.thickness_result_dir = os.path.join(self.root_dir, "alopecia", "thickness_result")
        self.count_result_dir = os.path.join(self.root_dir, "alopecia", "count_result")

        # Model weights.
        self.u2net_model_path = os.path.join(self.root_dir, "segmentation", "model", "U2NET.pth")
        self.sam_checkpoint = os.path.join(self.root_dir, "sam_vit_h_4b8939.pth")

        for d in [self.seg_train_dir, self.sam_val_dir, self.ensemble_val_dir,
                  self.thickness_result_dir, self.count_result_dir]:
            os.makedirs(d, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Generic helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _glob_sorted(directory, patterns):
        """Return the sorted, de-duplicated union of glob ``patterns`` inside ``directory``."""
        found = set()
        for pattern in patterns:
            found.update(glob.glob(os.path.join(directory, pattern)))
        return sorted(found)

    def _find_data_image(self, stem):
        """Return the first existing ``data_dir/<stem><ext>`` for ext in ``_DATA_EXTS``, else None."""
        for ext in self._DATA_EXTS:
            candidate = os.path.join(self.data_dir, stem + ext)
            if os.path.isfile(candidate):
                return candidate
        return None

    @staticmethod
    def _filter_components(binary, min_area):
        """Keep only the 8-connected components of ``binary`` with area >= ``min_area``.

        Any non-zero pixel counts as foreground. Returns a new uint8 mask whose
        kept pixels are 255 and everything else 0.
        """
        nlabels, labels, stats, _ = cv2.connectedComponentsWithStats(
            binary, None, None, None, 8, cv2.CV_32S
        )
        keep = np.zeros(nlabels, dtype=bool)
        # Label 0 is the background and is never kept.
        keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= min_area
        return np.where(keep[labels], 255, 0).astype(np.uint8)

    @staticmethod
    def _morph_skeleton(mask):
        """Morphological skeleton of a 0/255 mask via repeated erosion (3x3 cross).

        The input is not modified.
        """
        element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
        skel = np.zeros(mask.shape, np.uint8)
        current = mask.copy()
        while True:
            opened = cv2.morphologyEx(current, cv2.MORPH_OPEN, element)
            skel = cv2.bitwise_or(skel, cv2.subtract(current, opened))
            current = cv2.erode(current, element)
            if cv2.countNonZero(current) == 0:
                break
        return skel

    def _skeleton_boxes(self, skel, filter_size):
        """Put a ``filter_size`` box around every skeleton pixel and thin them with NMS.

        Returns an ``(k, 4)`` array of ``[x1, y1, x2, y2]`` boxes (or ``[]`` when
        the skeleton is empty). Boxes may extend past the image border.
        """
        rows, cols = np.where(skel == 255)
        half_w, half_h = filter_size[0] // 2, filter_size[1] // 2
        boxes = np.column_stack((cols - half_w, rows - half_h, cols + half_w, rows + half_h))
        return self.nms(boxes, thresh=0.1)

    @staticmethod
    def _crop_box(image, box):
        """Crop ``image`` to ``box`` = (x1, y1, x2, y2), clipping negative coordinates to 0.

        Clipping matters: a negative slice start would wrap around to the far
        edge of the array instead of meaning "before the border".

        Returns:
            ``(patch, x0, y0)`` where ``(x0, y0)`` is the patch's top-left corner
            in ``image`` coordinates.
        """
        x1, y1, x2, y2 = (int(v) for v in box)
        x0, y0 = max(x1, 0), max(y1, 0)
        x2, y2 = max(x2, 0), max(y2, 0)
        return image[y0:y2, x0:x2], x0, y0

    @staticmethod
    def _local_direction(pixels):
        """PCA of the non-zero pixel coordinates of ``pixels``.

        Returns:
            ``None`` if fewer than two pixels are set, otherwise
            ``(direction, (cx, cy))`` where ``direction`` is the first eigenvector
            in (row, col) order and ``(cx, cy)`` is the mean position in
            patch-local x/y coordinates.
        """
        coords = np.float32(np.column_stack(np.nonzero(pixels)))
        if len(coords) < 2:
            return None
        mean, eigenvectors = cv2.PCACompute(coords, mean=None)
        return eigenvectors[0], (mean[0, 1], mean[0, 0])

    @staticmethod
    def _perpendicular_unit(direction):
        """Unit (x, y) vector perpendicular to a (row, col) ``direction``; None if it is zero.

        A hair running along (d_row, d_col) runs along (d_col, d_row) in x/y, so
        (d_row, -d_col) is perpendicular to it. Using a vector instead of a slope
        keeps exactly horizontal and vertical hairs measurable.
        """
        d_row, d_col = float(direction[0]), float(direction[1])
        norm = float(np.hypot(d_row, d_col))
        if norm == 0.0:
            return None
        return d_row / norm, -d_col / norm

    @staticmethod
    def _measure_width(mask, center, unit, threshold, step=100, search_len=50):
        """Walk from ``center`` along +/-``unit`` while ``mask`` stays above ``threshold``.

        Samples are taken every ``1/step`` px, for at most ``search_len`` px per
        side. A side stops at the first sample that is outside the image (row and
        column 0 count as outside) or at/below ``threshold``.

        Returns:
            ``(p1, p2, dist)`` - the last passing sample on each side as an (x, y)
            float pair (``None`` if the very first step failed), and their
            Euclidean distance in pixels (0.0 unless both sides were found).
        """
        height, width = mask.shape[:2]
        cx, cy = center
        ux, uy = unit
        ends = []
        for sign in (1, -1):
            last_inside = None
            for k in range(1, step * search_len):
                d = sign * k / step
                px, py = cx - d * ux, cy - d * uy
                ix, iy = int(px), int(py)
                if 0 < ix < width and 0 < iy < height and mask[iy][ix] > threshold:
                    last_inside = (px, py)
                else:
                    break
            ends.append(last_inside)
        p1, p2 = ends
        if p1 is None or p2 is None:
            return p1, p2, 0.0
        return p1, p2, float(np.hypot(p1[0] - p2[0], p1[1] - p2[1]))

    # ------------------------------------------------------------------ #
    # Stage 1: U2NET
    # ------------------------------------------------------------------ #

    def normPRED(self, d):
        """Min-max normalise tensor ``d`` to [0, 1].

        A constant tensor (max == min) yields all zeros instead of NaNs.
        """
        ma = torch.max(d)
        mi = torch.min(d)
        span = ma - mi
        if span <= 0:
            return torch.zeros_like(d)
        return (d - mi) / span

    def save_output(self, image_name, pred, d_dir):
        """Save a U2NET prediction as ``<d_dir>/<stem>.jpg`` at the source image's size.

        Args:
            image_name: path of the image the prediction was made for; it is
                read only to obtain its dimensions.
            pred: prediction tensor, expected in [0, 1]; singleton dims are squeezed.
            d_dir: output directory.
        """
        predict_np = pred.squeeze().cpu().data.numpy()
        im = Image.fromarray(predict_np * 255).convert('RGB')

        image = io.imread(image_name)
        imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

        # Everything before the last '.' of the file name (the whole name if none).
        stem = os.path.splitext(os.path.basename(image_name))[0]
        imo.save(os.path.join(d_dir, stem + '.jpg'))

    def run_u2net_segmentation(self):
        """Stage 1: run U2NET on every file in ``data_dir`` and save masks to ``seg_train_dir``."""
        print("\n🔹 Running U2NET Segmentation...")

        # Sorted for a deterministic order; directories cannot be read as images.
        img_name_list = sorted(
            p for p in glob.glob(os.path.join(self.data_dir, '*')) if os.path.isfile(p)
        )
        if not img_name_list:
            print(f"No images found in {self.data_dir}")
            return

        test_salobj_dataset = SalObjDataset(
            img_name_list=img_name_list,
            lbl_name_list=[],
            transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
        )
        test_salobj_dataloader = DataLoader(
            test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1
        )

        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
        # Load on CPU first so any checkpoint works on any device, then move.
        net.load_state_dict(torch.load(self.u2net_model_path, map_location='cpu'))
        net.to(self.device)
        net.eval()

        # Inference only: no_grad avoids building autograd graphs per image.
        with torch.no_grad():
            for i_test, data_test in enumerate(test_salobj_dataloader):
                img_path = img_name_list[i_test]
                print("inferencing:", os.path.basename(img_path))

                inputs_test = data_test['image'].type(torch.FloatTensor).to(self.device)
                d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

                # Only the first output (d1) is used as the prediction.
                pred = self.normPRED(d1[:, 0, :, :])
                self.save_output(img_path, pred, self.seg_train_dir)
                del d1, d2, d3, d4, d5, d6, d7

        print("✅ U2NET Segmentation Complete.\n")

    # ------------------------------------------------------------------ #
    # Stage 2: SAM guides
    # ------------------------------------------------------------------ #

    def nms(self, boxes, thresh):
        """Greedy non-maximum suppression on ``[x1, y1, x2, y2]`` boxes.

        Boxes are visited in descending ``y2`` order. Every remaining box whose
        intersection covers more than ``thresh`` of *its own* area is dropped.

        Returns:
            The kept rows of ``boxes`` (``[]`` when ``boxes`` is empty).
        """
        if len(boxes) == 0:
            return []
        pick = []
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        area = (x2 - x1 + 1) * (y2 - y1 + 1)
        idxs = np.argsort(y2)
        while len(idxs) > 0:
            last = len(idxs) - 1
            i = idxs[last]
            pick.append(i)
            xx1 = np.maximum(x1[i], x1[idxs[:last]])
            yy1 = np.maximum(y1[i], y1[idxs[:last]])
            xx2 = np.minimum(x2[i], x2[idxs[:last]])
            yy2 = np.minimum(y2[i], y2[idxs[:last]])
            w = np.maximum(0, xx2 - xx1 + 1)
            h = np.maximum(0, yy2 - yy1 + 1)
            overlap = (w * h) / area[idxs[:last]]
            idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > thresh)[0])))
        return boxes[pick]

    def cluster(self, img_path, im, save_dir):
        """Derive SAM prompt points and boxes from a U2NET mask file.

        The mask is thresholded at 127, and components under 250 px are removed.
        The result is skeletonised, and skeleton fragments under 2 px are removed.
        A 10x10 box is placed around each remaining skeleton pixel, and the boxes
        are thinned with NMS. The PCA mean of the skeleton pixels inside each box
        becomes a prompt point.

        Args:
            img_path: path of the mask image.
            im, save_dir: accepted for call compatibility; not used.

        Returns:
            ``(pts_group, bbox_group)`` - ``[x, y]`` int points strictly inside
            the image and their ``[x1, y1, x2, y2]`` int boxes (same order).
            Both are empty if the file cannot be read.
        """
        imgray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if imgray is None:
            return [], []
        img_h, img_w = imgray.shape[:2]

        _, binary_map = cv2.threshold(imgray, 127, 255, cv2.THRESH_BINARY)
        mask = self._filter_components(binary_map, 250)
        skel = self._filter_components(self._morph_skeleton(mask), 2)

        pts_group, bbox_group = [], []
        for box in self._skeleton_boxes(skel, (10, 10)):
            patch, x0, y0 = self._crop_box(skel, box)
            local = self._local_direction(patch)
            if local is None:
                continue
            _, (cx, cy) = local
            px, py = cx + x0, cy + y0
            # Keep only points strictly inside the actual image bounds.
            if 0 < px < img_w and 0 < py < img_h:
                pts_group.append([int(px), int(py)])
                bbox_group.append([int(v) for v in box])
        return pts_group, bbox_group

    def generate_sam_guides(self):
        """Stage 2: write ``datasets/train_seg_points.json`` and ``train_bbox_points.json``.

        Each JSON maps a mask file name from ``seg_train_dir`` to its prompt
        points / boxes (see ``cluster``). Files yielding no points are omitted.
        """
        print("\n🔹 Generating SAM Guides (Points/BBox)...")
        mask_dir = self.seg_train_dir
        save_json_dir = os.path.join(self.root_dir, "datasets")
        save_img_dir = os.path.join(save_json_dir, "output")
        os.makedirs(save_img_dir, exist_ok=True)

        files = self._glob_sorted(
            mask_dir, ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
        )
        print(f"Found {len(files)} files in {mask_dir}")

        file_dict = {}
        bbox_dict = {}
        for filepath in tqdm(files):
            filename = os.path.basename(filepath)
            pts, bbox = self.cluster(filepath, filename, save_img_dir)
            if pts:
                file_dict[filename] = pts
                bbox_dict[filename] = bbox

        with open(os.path.join(save_json_dir, 'train_seg_points.json'), 'w') as json_file:
            json.dump(file_dict, json_file)
        with open(os.path.join(save_json_dir, 'train_bbox_points.json'), 'w') as json_file:
            json.dump(bbox_dict, json_file)

        print("✅ SAM Guides Generated.\n")

    # ------------------------------------------------------------------ #
    # Stage 3: SAM
    # ------------------------------------------------------------------ #

    def run_sam_prediction(self):
        """Stage 3: predict a SAM mask per image from the stage-2 prompt points.

        The masks are written as ``<stem>.jpg`` into ``sam_val_dir``. Randomness
        (point subset and border points) comes from the global ``np.random``
        state, so results vary between runs unless that state is seeded.
        """
        print("\n🔹 Running SAM Prediction...")
        points_file = os.path.join(self.root_dir, 'datasets', 'train_seg_points.json')
        if not os.path.exists(points_file):
            print(f"Points file not found: {points_file}")
            return

        with open(points_file, 'r') as f:
            points = json.load(f)

        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=self.sam_checkpoint)
        sam.to(device=self.device)
        predictor = SamPredictor(sam)

        border_width = 50
        for full_name in tqdm(points.keys()):
            name, _ = os.path.splitext(full_name)
            sample_points = (points.get(full_name) or points.get(f'{name}.png')
                             or points.get(f'{name}.jpg') or points.get(f'{name}.jpeg') or [])

            src_path = self._find_data_image(name)
            image = cv2.imread(src_path) if src_path else None
            if image is None or image.size == 0:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if len(sample_points) == 0:
                # NOTE(review): this writes the photo itself into the mask directory;
                # unreachable for JSON produced by generate_sam_guides (it omits
                # empty entries) - confirm intended for hand-made JSON.
                cv2.imwrite(os.path.join(self.sam_val_dir, f"{name}.jpg"),
                            cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                continue

            positives = np.array(sample_points)
            positives = positives[positives.min(axis=1) > 0]
            if len(positives) == 0:
                continue

            # Encode only once we know the image will actually be prompted.
            predictor.set_image(np.ascontiguousarray(image))

            rand_idx = np.random.choice(len(positives), max(1, len(positives) // 2), replace=False)
            input_point = positives[rand_idx]

            # Ten random points from a border band, avoiding any positive point.
            img_height, img_width = image.shape[:2]
            band_y = min(border_width, img_height)  # keep the band inside tiny images
            band_x = min(border_width, img_width)
            taken = {tuple(p) for p in positives.tolist()}
            neg_list = []
            while len(neg_list) < 10:
                side = np.random.choice(['top', 'bottom', 'left', 'right'])
                if side == 'top':
                    xy = [np.random.randint(img_width), np.random.randint(0, band_y)]
                elif side == 'bottom':
                    xy = [np.random.randint(img_width),
                          np.random.randint(max(0, img_height - border_width), img_height)]
                elif side == 'left':
                    xy = [np.random.randint(0, band_x), np.random.randint(img_height)]
                else:
                    xy = [np.random.randint(max(0, img_width - border_width), img_width),
                          np.random.randint(img_height)]
                if tuple(xy) not in taken:
                    neg_list.append(xy)

            neg_arr = np.array(neg_list)
            final_point = np.concatenate([input_point, neg_arr]).reshape(-1, 2)
            # Hair points get label 0 and border points label 1, so SAM segments the
            # region around the border points; the mask is inverted further below.
            input_label = np.array([0] * len(input_point) + [1] * len(neg_arr))

            masks, scores, logits = predictor.predict(
                point_coords=final_point,
                point_labels=input_label,
                multimask_output=True,
            )

            sam_mask = masks[np.argmax(scores)]
            if sam_mask.ndim > 2:
                sam_mask = sam_mask.squeeze()
            if sam_mask.shape != (img_height, img_width):
                sam_mask = cv2.resize(sam_mask.astype(np.uint8), (img_width, img_height))

            # Invert: pixels outside the predicted region become the foreground.
            binary_map = np.where(sam_mask > 0, 0, 255).astype(np.uint8)
            result = self._filter_components(binary_map, 400)
            cv2.imwrite(os.path.join(self.sam_val_dir, f"{name}.jpg"), result)

        print("✅ SAM Prediction Complete.\n")

    # ------------------------------------------------------------------ #
    # Stage 4: ensemble
    # ------------------------------------------------------------------ #

    def create_ensemble_mask(self):
        """Stage 4: AND the U2NET and SAM masks that share a file stem.

        Components under 400 px are removed, and the result is saved in
        ``ensemble_val_dir`` under the SAM mask's file name.
        """
        print("\n🔹 Creating Ensemble Masks...")
        seg_full_path = self._glob_sorted(self.seg_train_dir, ['*.png', '*.jpg', '*.jpeg'])
        sam_full_path = self._glob_sorted(self.sam_val_dir, ['*.jpg', '*.png', '*.jpeg'])

        # Later entries in sorted order win when two files share a stem.
        seg_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in seg_full_path}
        sam_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in sam_full_path}
        matched_pairs = [(seg_dict[n], sam_dict[n]) for n in seg_dict if n in sam_dict]

        for seg, sam in tqdm(matched_pairs):
            seg_img = cv2.imread(seg)
            sam_img = cv2.imread(sam)
            if seg_img is None or sam_img is None:
                continue
            if seg_img.shape != sam_img.shape:
                sam_img = cv2.resize(sam_img, (seg_img.shape[1], seg_img.shape[0]))

            added_img = cv2.bitwise_and(seg_img, sam_img)
            binary_map = cv2.cvtColor(added_img, cv2.COLOR_BGR2GRAY)
            result = self._filter_components(binary_map, 400)
            cv2.imwrite(os.path.join(self.ensemble_val_dir, os.path.basename(sam)), result)

        print("✅ Ensemble Masks Created.\n")

    # ------------------------------------------------------------------ #
    # Stage 5: thickness
    # ------------------------------------------------------------------ #

    def calculate_hair_thickness(self):
        """Stage 5: measure hair widths on every ``*.jpg`` ensemble mask.

        Widths are measured perpendicular to the local skeleton direction. For
        each mask this writes:
        ``<stem>.npy`` with the sorted widths multiplied by ``pixel_ratio``, and
        an empty array for all-black or all-white masks.
        ``<stem>_vis.png``, an overlay of the skeleton, the measurement lines
        and the average width.
        """
        print("\n🔹 Calculating Hair Thickness...")
        save_path = self.thickness_result_dir

        for im_path in tqdm(sorted(glob.glob(os.path.join(self.ensemble_val_dir, '*.jpg')))):
            imgray = cv2.imread(im_path, cv2.IMREAD_GRAYSCALE)
            if imgray is None:
                continue
            img_name = os.path.splitext(os.path.basename(im_path))[0]

            if np.all(imgray == 255) or np.all(imgray == 0):
                np.save(os.path.join(save_path, img_name), np.array([]))
                continue

            _, binary_map = cv2.threshold(imgray, 127, 255, cv2.THRESH_BINARY)
            hair_mask = self._filter_components(binary_map, 250)
            skel = self._filter_components(self._morph_skeleton(hair_mask), 5)

            filtered_image = cv2.cvtColor(hair_mask, cv2.COLOR_GRAY2BGR)
            filtered_image[skel == 255] = [0, 255, 0]

            thicknesses = []
            for box in self._skeleton_boxes(skel, (20, 20)):
                patch, x0, y0 = self._crop_box(skel, box)
                local = self._local_direction(patch)
                if local is None:
                    continue
                direction, (cx, cy) = local
                unit = self._perpendicular_unit(direction)
                if unit is None:
                    continue

                p1, p2, dst = self._measure_width(hair_mask, (cx + x0, cy + y0), unit, 200)
                if dst != 0:
                    thicknesses.append(dst * self.pixel_ratio)

                found = [p for p in (p1, p2) if p is not None]
                if len(found) == 2:
                    cv2.line(filtered_image,
                             (int(found[0][0]), int(found[0][1])),
                             (int(found[1][0]), int(found[1][1])),
                             (0, 255, 255), 1)
                for pt in found:
                    cv2.circle(filtered_image, (int(pt[0]), int(pt[1])), 3, (0, 0, 255), -1)

            if thicknesses:
                avg_thickness = np.mean(thicknesses)
                cv2.putText(filtered_image, f"Avg thickness: {avg_thickness:.2f} um",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)

            cv2.imwrite(os.path.join(save_path, f"{img_name}_vis.png"), filtered_image)
            np.save(os.path.join(save_path, img_name), np.sort(thicknesses))

        print("✅ Hair Thickness Calculation Complete.\n")

    # ------------------------------------------------------------------ #
    # Stage 6: count
    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_segment_mask(img_path):
        """Read a mask as grayscale, open it with a 5x5 ellipse, and threshold at 127.

        Returns None if the file is missing or unreadable.
        """
        if not os.path.exists(img_path):
            return None
        img_gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img_gray is None:
            return None
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        opened = cv2.morphologyEx(img_gray, cv2.MORPH_OPEN, kernel)
        _, binary = cv2.threshold(opened, 127, 255, cv2.THRESH_BINARY)
        return binary

    def _load_original(self, name, shape):
        """Load the source photo for mask ``name`` as BGR, sized to ``shape[:2]``.

        ``data_dir/<name>`` is tried first, then ``data_dir/<stem><ext>`` for the
        extensions in ``_DATA_EXTS``. The photo is resized when its size differs,
        because watershed needs the image and marker sizes to match. Returns a
        black image if nothing can be read.
        """
        h, w = shape[:2]
        path = os.path.join(self.data_dir, name)
        if not os.path.isfile(path):
            path = self._find_data_image(os.path.splitext(name)[0])
        image = cv2.imread(path) if path else None
        if image is None:
            return np.zeros((h, w, 3), dtype=np.uint8)
        if image.shape[:2] != (h, w):
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
        return image

    @staticmethod
    def _watershed_markers(binary_img, original_img, sep_factor):
        """One watershed pass whose seeds are pixels with distance > sep_factor * max.

        Returns an int32 marker image. It uses the conventions of the code below:
        1 marks the background region, values > 1 mark objects, and cv2.watershed
        writes -1 on boundaries.
        """
        dist_transform = cv2.distanceTransform(binary_img, cv2.DIST_L2, 5)
        _, sure_fg = cv2.threshold(dist_transform, sep_factor * dist_transform.max(), 255, cv2.THRESH_BINARY)
        sure_fg = np.uint8(sure_fg)
        sure_bg = cv2.dilate(binary_img, np.ones((3, 3), np.uint8), iterations=3)
        unknown = cv2.subtract(sure_bg, sure_fg)

        _, markers = cv2.connectedComponents(sure_fg)
        markers = markers + 1            # shift so the background becomes 1, not 0
        markers[unknown == 255] = 0      # 0 = to be decided by watershed

        if original_img.ndim == 2:
            original_color = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
        else:
            original_color = original_img.copy()
        markers_w = markers.astype(np.int32)
        cv2.watershed(original_color, markers_w)
        return markers_w

    @staticmethod
    def _region_props(mask_uint8):
        """Shape statistics for each external contour of a 0/255 mask.

        Each dict has 'area', 'major', 'minor', 'aspect' (major / minor),
        'centroid' and 'contour'. The axes come from an ellipse fit when the
        contour has at least 5 points and from the bounding rectangle otherwise.
        """
        cnts, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        props = []
        for cnt in cnts:
            area = cv2.contourArea(cnt)
            if area <= 0:
                continue
            if len(cnt) >= 5:
                try:
                    (x, y), (ax1, ax2), _ = cv2.fitEllipse(cnt)
                except cv2.error:
                    ax1 = ax2 = 0
                    x = y = 0
            else:
                x, y, w, h = cv2.boundingRect(cnt)
                ax1, ax2 = w, h
            # fitEllipse does not order its axes, so sort them explicitly.
            major = max(ax1, ax2)
            minor = min(ax1, ax2)
            if minor <= 0:
                minor = 1e-6
            props.append({
                'area': area,
                'major': major,
                'minor': minor,
                'aspect': float(major) / (minor + 1e-6),
                'centroid': (float(x), float(y)),
                'contour': cnt,
            })
        return props

    def _hierarchical_watershed(self, binary_img, original_img, min_area, min_aspect_ratio,
                                min_length, separation_factor=0.2, hierarchy_levels=3):
        """Coarse-to-fine watershed followed by ellipse-based hair filtering.

        Watershed is run at ``hierarchy_levels`` seed thresholds spread around
        ``separation_factor``. A region from a coarser level is replaced by its
        finer-level children only when at least two of those children pass the
        area, length and aspect filters.

        Returns:
            ``(count, hairs)``. Each hair dict has 'centroid', 'ellipse',
            'length', 'thickness', 'area' and 'label'.
        """
        if hierarchy_levels <= 1:
            sep_levels = [separation_factor]
        else:
            low = max(0.01, separation_factor * 0.7)
            high = separation_factor * 1.6
            sep_levels = list(np.linspace(low, high, hierarchy_levels))

        markers_levels = [self._watershed_markers(binary_img, original_img, s) for s in sep_levels]
        current = markers_levels[0].copy().astype(np.int32)
        next_label = int(current.max()) + 1

        for finer in markers_levels[1:]:
            new_current = current.copy()
            for parent_label in np.unique(current):
                if parent_label <= 1:  # background / boundary / undecided
                    continue
                parent_mask = current == parent_label
                if not parent_mask.any():
                    continue
                overlapped = finer[parent_mask]
                child_labels = np.unique(overlapped[overlapped > 1])
                if len(child_labels) <= 1:
                    continue

                accepted_children = []
                for cl in child_labels:
                    child_u8 = np.logical_and(finer == cl, parent_mask).astype(np.uint8) * 255
                    props = self._region_props(child_u8)
                    if not props:
                        continue
                    best = max(props, key=lambda p: p['area'])
                    if (best['area'] >= min_area and best['major'] >= min_length
                            and best['aspect'] >= min_aspect_ratio):
                        accepted_children.append(child_u8)

                if len(accepted_children) >= 2:
                    new_current[parent_mask] = 0
                    for child_u8 in accepted_children:
                        new_current[child_u8 == 255] = next_label
                        next_label += 1
            current = new_current

        valid_hairs = []
        for label in np.unique(current):
            if label <= 1:
                continue
            mask = (current == label).astype(np.uint8) * 255
            cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in cnts:
                area = cv2.contourArea(cnt)
                if area < min_area or len(cnt) < 5:
                    continue
                try:
                    (x, y), (ax1, ax2), angle = cv2.fitEllipse(cnt)
                except cv2.error:
                    continue
                major_axis = max(ax1, ax2)
                minor_axis = min(ax1, ax2)
                aspect_ratio = major_axis / (minor_axis + 1e-6)
                if major_axis >= min_length and aspect_ratio >= min_aspect_ratio:
                    valid_hairs.append({
                        'centroid': (x, y),
                        'ellipse': ((x, y), (ax1, ax2), angle),
                        'length': major_axis,
                        'thickness': minor_axis,
                        'area': area,
                        'label': int(label),
                    })
        return len(valid_hairs), valid_hairs

    @staticmethod
    def _save_count_visualization(true_original, background, hair_info, filename, save_dir):
        """Write ``vis_<filename>``: the photo | the background with fitted ellipses, under a header."""
        h, w = true_original.shape[:2]
        overlay = background.copy()
        if overlay.shape[:2] != (h, w):
            overlay = cv2.resize(overlay, (w, h), interpolation=cv2.INTER_LINEAR)
        for i, info in enumerate(hair_info):
            cv2.ellipse(overlay, info['ellipse'], (0, 255, 0), 2)
            if w > 300:  # index labels only on images wider than 300 px
                cx, cy = map(int, info['centroid'])
                cv2.putText(overlay, str(i), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)

        spacer = np.zeros((h, 5, 3), dtype=np.uint8)
        combined = np.hstack([true_original, spacer, overlay])
        header = np.zeros((50, combined.shape[1], 3), dtype=np.uint8)
        info_text = f"{filename} | Count: {len(hair_info)}"
        cv2.putText(header, info_text, (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        cv2.imwrite(os.path.join(save_dir, f'vis_{filename}'), np.vstack([header, combined]))

    def calculate_hair_count(self):
        """Stage 6: count hairs in every ensemble mask.

        The following files are written to ``count_result_dir``:
        ``hair_count.csv`` with the image name and hair count.
        ``density.json`` with the count, average thickness and average length per
        image.
        ``visualizations/vis_<name>``, one per image.
        """
        print("\n🔹 Calculating Hair Count...")
        img_folder = self.ensemble_val_dir
        # NOTE(review): the visualisation background is read from the ensemble
        # directory (same as img_folder), not sam_val_dir - confirm intended.
        sam_folder = self.ensemble_val_dir
        save_path = self.count_result_dir
        vis_dir = os.path.join(save_path, 'visualizations')
        os.makedirs(vis_dir, exist_ok=True)

        min_area = 1500
        min_length = 20
        min_ratio = 1.0
        separation_factor = 0.3
        hierarchy_levels = 2

        img_names = [os.path.basename(p)
                     for p in self._glob_sorted(img_folder, ['*.jpg', '*.png', '*.jpeg'])]

        results = {}
        density_results = {}
        for im in tqdm(img_names, desc="Processing"):
            binary = self._load_segment_mask(os.path.join(img_folder, im))
            if binary is None:
                continue

            true_original = self._load_original(im, binary.shape)
            sam_background = cv2.imread(os.path.join(sam_folder, im))
            if sam_background is None:
                sam_background = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

            hair_count, hair_info = self._hierarchical_watershed(
                binary,
                true_original,
                min_area=min_area,
                min_aspect_ratio=min_ratio,
                min_length=min_length,
                separation_factor=separation_factor,
                hierarchy_levels=hierarchy_levels,
            )

            results[im] = hair_count
            density_results[im] = {
                'count': hair_count,
                'avg_thickness': float(np.mean([h['thickness'] for h in hair_info]) if hair_info else 0),
                'avg_length': float(np.mean([h['length'] for h in hair_info]) if hair_info else 0),
            }

            self._save_count_visualization(true_original, sam_background, hair_info, im, vis_dir)

        with open(os.path.join(save_path, 'hair_count.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['image_name', 'hair_count'])
            for k, v in results.items():
                writer.writerow([k, v])

        with open(os.path.join(save_path, 'density.json'), 'w') as f:
            json.dump(density_results, f, indent=2)

        print("✅ Hair Count Calculation Complete.\n")

    def run_pipeline(self):
        """Run all six stages in order."""
        print("🚀 Starting ScalpPipeline...")
        self.run_u2net_segmentation()
        self.generate_sam_guides()
        self.run_sam_prediction()
        self.create_ensemble_mask()
        self.calculate_hair_thickness()
        self.calculate_hair_count()
        print("🎉 Pipeline Completed Successfully!")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the pipeline from CLI options and run every stage.
    cli = argparse.ArgumentParser(description="ScalpVision Pipeline")
    cli.add_argument("--root_dir", type=str, default=".", help="Root directory of the project")
    cli.add_argument("--pixel_ratio", type=float, default=2.54, help="Pixel to micrometer ratio (default: 2.54)")
    options = cli.parse_args()

    # The option names match ScalpPipeline's keyword arguments one-to-one.
    ScalpPipeline(**vars(options)).run_pipeline()
|
|
|