File size: 5,194 Bytes
69591a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from skimage.morphology import skeletonize
import numpy as np
from skimage.measure import label
from tqdm.contrib.concurrent import thread_map  # or thread_map
def extract_fiber_properties(mask):
    """Measure the red and green skeleton length of every two-colour fiber.

    The foreground of ``mask`` (values > 0) is skeletonized and the skeleton
    is split into 8-connected components ("fibers"). For each fiber, the
    skeleton pixels whose mask value is 1 ("R") and 2 ("G") are counted.
    Fibers that lack either value are dropped.

    Args:
        mask: Array in which 0 is background and the values 1 and 2 mark the
            two classes that are measured.

    Returns:
        A dict with:
            "R", "G": integer arrays, one count per retained fiber.
            "ratio": ``R / G`` for each retained fiber.
            "ids": label id (in "label") of each retained fiber, so the
                per-fiber rows can be traced back to the label map.
            "label": the labelled skeleton (0 = background).
    """
    mask = np.asarray(mask)
    skeleton = skeletonize(mask > 0)
    labeled_skeleton, n_fibers = label(skeleton, connectivity=2, return_num=True)

    # One bincount per colour counts the pixels of every fiber in a single
    # pass, instead of comparing the whole image against each label in turn.
    # Bin 0 collects the off-skeleton pixels and is discarded.
    n_bins = n_fibers + 1
    red_counts = np.bincount(labeled_skeleton[mask == 1], minlength=n_bins)[1:]
    green_counts = np.bincount(labeled_skeleton[mask == 2], minlength=n_bins)[1:]

    # Only fibers containing both colours have a meaningful ratio.
    keep = (red_counts > 0) & (green_counts > 0)
    red = red_counts[keep]
    green = green_counts[keep]

    return {
        "R": red,
        "G": green,
        "ratio": red / green,
        "ids": np.arange(1, n_bins)[keep],
        "label": labeled_skeleton,
    }


def filter_non_commons_fibers(properties):
    """Drop, in every image, the fibers that do not touch the common region.

    A pixel is "common" when its label is non-zero in every label map. A
    fiber is kept when at least one of its pixels is common.

    Args:
        properties: Sequence of dicts, one per image. Each dict holds
            "label" (integer label map, 0 = background; all maps must share
            one shape) and the per-fiber arrays "R", "G" and "ratio". An
            optional "ids" array gives the label id described by each
            per-fiber row; without it, row ``k`` is taken to describe
            label ``k + 1``.

    Returns:
        A new list with one dict per input dict. Each has the keys "R", "G",
        "ratio" and "ids" restricted to the kept fibers, and "label" with
        every dropped fiber set to 0. The input dicts are not modified.

    Raises:
        ValueError: If a dict has no "ids" entry and its per-fiber arrays do
            not contain exactly one row per label, so rows cannot be matched
            to labels.
    """
    properties = list(properties)
    if not properties:
        return []

    label_maps = [np.asarray(p["label"]) for p in properties]
    # Pixels that belong to some fiber in every image.
    common_pixels = np.logical_and.reduce([lm > 0 for lm in label_maps])

    filtered_properties = []
    for p, label_map in zip(properties, label_maps):
        # Ids of the fibers of this image that own at least one common pixel
        # (all values are > 0 because common pixels are labelled everywhere).
        common_ids = np.unique(label_map[common_pixels])

        n_rows = len(p["R"])
        if "ids" in p:
            row_ids = np.asarray(p["ids"])
        else:
            n_labels = int(label_map.max()) if label_map.size else 0
            if n_rows != n_labels:
                raise ValueError(
                    "cannot match per-fiber rows to labels: got "
                    f"{n_rows} rows for {n_labels} labels and no 'ids' entry"
                )
            row_ids = np.arange(1, n_rows + 1)

        keep_rows = np.isin(row_ids, common_ids)
        filtered_properties.append({
            "R": np.asarray(p["R"])[keep_rows],
            "G": np.asarray(p["G"])[keep_rows],
            "ratio": np.asarray(p["ratio"])[keep_rows],
            "ids": row_ids[keep_rows],
            "label": np.where(np.isin(label_map, common_ids), label_map, 0),
        })

    return filtered_properties

def skeletonize_mask(mask):
    """Return the skeleton of ``mask`` with the original pixel values kept.

    Every non-zero pixel counts as foreground for the skeletonization; the
    result holds the input value on skeleton pixels and 0 everywhere else.
    """
    # Skeletonize the binarized foreground.
    thinned = skeletonize(mask > 0)
    # Multiplying by the input restores the per-pixel values on the skeleton.
    return thinned * mask


def skeletonize_data_dict(data_dict):
    """Skeletonize every mask of a nested ``{annotator: {image_type: masks}}`` dict.

    Returns a new dict with the same two levels of keys in which each
    collection of masks is replaced by the result of mapping
    ``skeletonize_mask`` over it with ``thread_map`` (8 workers).
    """
    return {
        annotator: {
            image_type: thread_map(skeletonize_mask, masks, max_workers=8)
            for image_type, masks in images.items()
        }
        for annotator, images in data_dict.items()
    }


def extract_properties_from_datadict(data_dict, with_common_analysis=True, min_fiber_size=10):
    """Collect per-fiber measurements for every annotator, image type and mask.

    ``data_dict`` maps annotator -> image type -> list of masks. In a mask,
    0 is background, 1 marks red pixels and 2 marks green pixels. Each
    8-connected foreground component of a mask is one fiber. Fibers smaller
    than ``min_fiber_size`` pixels, or lacking red or green pixels, are
    skipped entirely.

    Args:
        data_dict: ``{annotator: {image_type: [mask, ...]}}``. When
            ``with_common_analysis`` is true, every annotator must provide
            the same image types with masks of matching index and shape.
        with_common_analysis: If true, record for each fiber whether it
            overlaps the foreground of the corresponding mask of every other
            annotator.
        min_fiber_size: Minimum pixel count for a fiber to be measured.

    Returns:
        A dict of equally long lists (one entry per retained fiber):
            "annotator", "image_type": origin of the fiber.
            "red", "green": number of red / green pixels in the fiber.
            "ratio": green / red.
            "fiber_type": "single", "double", "multiple" or "unknown" for
                1, 2, more than 2, or 0 same-valued connected segments.
            one key per annotator: True for the annotator that drew the
                fiber; for the others, whether their mask overlaps the fiber,
                or None when ``with_common_analysis`` is false.
    """
    all_annotators = list(data_dict.keys())
    properties = dict(annotator=[], image_type=[], red=[], green=[], ratio=[], fiber_type=[])
    # One "found by <annotator>" column per annotator.
    properties.update({a: [] for a in all_annotators})

    for annotator, images in data_dict.items():
        other_annotators = [a for a in all_annotators if a != annotator]
        for image_type, masks in images.items():
            for mask_index, mask in enumerate(masks):
                if with_common_analysis:
                    # Foreground of the same image as drawn by the others.
                    others_masks = [
                        data_dict[other][image_type][mask_index] > 0
                        for other in other_annotators
                    ]

                # Colour maps do not depend on the fiber; compute them once.
                red_pixels = mask == 1
                green_pixels = mask == 2

                labels, num = label(mask > 0, connectivity=2, return_num=True)
                for fiber_id in range(1, num + 1):
                    fiber = labels == fiber_id
                    if np.sum(fiber) < min_fiber_size:
                        continue

                    red_length = np.sum(fiber & red_pixels)
                    green_length = np.sum(fiber & green_pixels)
                    # Decide whether to skip BEFORE appending anything, so
                    # every column receives exactly one value per fiber.
                    if red_length == 0 or green_length == 0:
                        continue

                    properties["annotator"].append(annotator)
                    properties["image_type"].append(image_type)

                    # Check for common fibers
                    properties[annotator].append(True)
                    if with_common_analysis:
                        for other_annotator, other_mask in zip(other_annotators, others_masks):
                            properties[other_annotator].append(np.any(fiber & other_mask))
                    else:
                        # Overlap not evaluated; keep the columns aligned.
                        for other_annotator in other_annotators:
                            properties[other_annotator].append(None)

                    # red_length is > 0 here; the epsilon is only a
                    # defensive guard against division by zero.
                    properties["ratio"].append(green_length / (red_length + 1e-7))
                    properties["red"].append(red_length)
                    properties["green"].append(green_length)

                    # Count the same-valued segments inside this fiber on the
                    # 2-D image. Labelling ``mask[fiber]`` would operate on a
                    # flattened 1-D array and count raster-order runs instead
                    # of spatially connected segments.
                    _, count = label(mask * fiber, connectivity=2, return_num=True)
                    if count == 1:
                        properties["fiber_type"].append("single")
                    elif count == 2:
                        properties["fiber_type"].append("double")
                    elif count > 2:
                        properties["fiber_type"].append("multiple")
                    else:
                        properties["fiber_type"].append("unknown")

    return properties