Nguyễn Thành Đạt committed on
Commit 036e7c4 · 0 Parent(s)

update code
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore Python cache files
+ *.pyc
+ __pycache__/
+ venv/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11.0
ML_model/__init__.py ADDED
File without changes
ML_model/data/X_train.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8104f7e7cbc69e960b07138448d538edc2d6856d02862a4182466da58ec07ab7
+ size 2775817
ML_model/data/clinical-2.csv ADDED
The diff for this file is too large to render. See raw diff
 
ML_model/data/selected_features.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8593b584c411a5f65c5a5265e987f077d54f8e17fbe7ae9b653a69d93b83955b
+ size 344
ML_model/model.py ADDED
@@ -0,0 +1,66 @@
+ # from sklearn.ensemble import RandomForestClassifier
+ import numpy as np
+ import joblib
+ import pandas as pd
+ from sklearn.preprocessing import MinMaxScaler
+ from lime import lime_tabular
+ 
+ class MLModel():
+     def __init__(self):
+         self.model = joblib.load('./ML_model/weights/xgboost_convnet_best.pkl')
+ 
+         self.clinical_data = pd.read_csv("./ML_model/data/clinical-2.csv")
+         del self.clinical_data['ID']
+         del self.clinical_data['SIDE']
+ 
+         self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
+         self.scaler = MinMaxScaler()
+         self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(self.clinical_data[self.columns_to_scale])
+ 
+         self.columns_to_convert = ['FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA', 'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS']
+         self.mapping_dict = {}
+         for column in self.columns_to_convert:
+             self.clinical_data[column], unique_values = pd.factorize(self.clinical_data[column])
+             self.mapping_dict[column] = unique_values
+ 
+         self.selected_features = joblib.load("./ML_model/data/selected_features.pkl")
+         self.X_train = joblib.load("./ML_model/data/X_train.pkl")
+         self.selected_feature_index = [self.X_train.columns.get_loc(col) for col in self.selected_features]
+ 
+         self.explainer = lime_tabular.LimeTabularExplainer(
+             self.X_train[self.selected_features].to_numpy(),
+             feature_names=self.selected_features,
+             class_names=['0', '1', '2', '3', '4'],
+             mode='classification'
+         )
+ 
+     def get_clinical_data(self, filename):
+         row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
+         return row.to_numpy()  # the FILENAME column (index 0) is dropped by the callers
+ 
+     def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
+         if clinical is None:
+             if filename is None:
+                 raise ValueError("Need clinical data or a filename in the OAI database")
+             clinical = self.get_clinical_data(filename)
+             assert clinical.shape[0] != 0, "clinical data not found"
+         else:
+             raise NotImplementedError("passing clinical data directly is not implemented yet")
+ 
+         x = np.concatenate((overal_diagnosis, clinical[0, 1:], np.array(jsws)))
+ 
+         return self.model.predict(x[self.selected_feature_index].reshape(1, -1))  # reshape to a single-sample batch, i.e. unsqueeze(0)
+ 
+     def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
+         if clinical is None:
+             if filename is None:
+                 raise ValueError("Need clinical data or a filename in the OAI database")
+             clinical = self.get_clinical_data(filename)
+             assert clinical.shape[0] != 0, "clinical data not found"
+         else:
+             raise NotImplementedError("passing clinical data directly is not implemented yet")
+ 
+         x = np.concatenate((overal_diagnosis, clinical[0, 1:], np.array(jsws)))[self.selected_feature_index]
+ 
+         exp = self.explainer.explain_instance(x, self.model.predict_proba, num_features=len(self.selected_features), top_labels=1)
+         return self.model.predict(x.reshape(1, -1))[0], exp
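
A minimal usage sketch (not part of the commit): the filename is a hypothetical OAI entry assumed to exist in `clinical-2.csv`'s FILENAME column, and `probs`/`jsws` stand in for the classifier probabilities and joint-space-width features that `app.py` computes.

```python
from ML_model.model import MLModel

ml_model = MLModel()
probs = [0.1] * 44   # stand-in for the 44-dim multi-label classifier output
jsws = [4.2, 3.1]    # stand-in for [jsw_m, jsw_mm]

# "9001695L.png" is a hypothetical filename assumed to be in the OAI table
label, exp = ml_model.predict_explain(probs, jsws, filename="9001695L.png")
print("Predicted grade:", label)
exp.save_to_file("explanation.html")  # LIME explanation as standalone HTML
```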
ML_model/weights/xgboost_convnet_best.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ee491112dc64c61cc66b9922262c1a5920db5270dc6cf1220355a737b1c30c0
+ size 286017
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Demo XDesco
+ emoji: 💻
+ colorFrom: red
+ colorTo: yellow
+ sdk: streamlit
+ sdk_version: 1.32.2
+ app_file: app.py
+ pinned: false
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,192 @@
+ import streamlit as st
+ import cv2
+ import matplotlib.pyplot as plt
+ import matplotlib.patheffects as path_effects
+ from PIL import Image
+ import numpy as np
+ import os
+ import streamlit.components.v1 as components
+ import pandas as pd
+ 
+ from utils import read_image, combine_mask, combine_prob_jsw, scale_coordinates, get_annotations
+ from segmentation.model import Segmenter
+ from jsw import get_JSW, calculate_diff, calculate_jsw_info
+ from classification.model import Classifier
+ from ML_model.model import MLModel
+ from gae import AnomalyExtractor
+ 
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+ 
+ st.set_page_config(layout="wide")
+ 
+ @st.cache_resource
+ def load_models():
+     seg_model = Segmenter()
+     classif_model = Classifier('convnet')
+     ml_model = MLModel()
+     anomaly_extractor = AnomalyExtractor()
+ 
+     return seg_model, classif_model, ml_model, anomaly_extractor
+ 
+ 
+ seg_model, classif_model, ml_model, anomaly_extractor = load_models()
+ 
+ st.sidebar.title("Demo xDesco System")
+ 
+ uploaded_file = st.sidebar.file_uploader("Choose a knee joint X-ray image", type=["jpg", "jpeg", "png"])
+ 
+ if uploaded_file is not None:
+     file_bytes = uploaded_file.read()
+     file_name = uploaded_file.name
+     img = read_image(file_bytes)
+ 
+     mask = seg_model.segment(img)
+ 
+     mask_image = combine_mask(img, mask)
+     col1, col2 = st.columns(2)
+ 
+     left_distances, right_distances, links = get_JSW(mask, dim=10, verbose=0)
+     left_links, right_links = links
+     jsw_m, jsw_mm = calculate_diff(left_distances, right_distances)
+     diff_percentage, mean_left, mean_right, side_min, index_min, value_min = calculate_jsw_info(left_distances, right_distances)
+ 
+     scaled_left_links = scale_coordinates([coord for pair in left_links for coord in pair], original_size=224, mask_size=640)
+     left_pairs = list(zip(scaled_left_links[::2], scaled_left_links[1::2]))
+     scaled_right_links = scale_coordinates([coord for pair in right_links for coord in pair], original_size=224, mask_size=640)
+     right_pairs = list(zip(scaled_right_links[::2], scaled_right_links[1::2]))
+ 
+     # st.write("Left JSW: [{}]".format(", ".join(f"{x:.2f}" for x in left_distances)))
+     # st.write("Right JSW: [{}]".format(", ".join(f"{x:.2f}" for x in right_distances)))
+     # st.write("$JSW_M$: ", jsw_m)
+     # st.write("$JSW_{MM}$: ", jsw_mm)
+ 
+     probabilities = classif_model.predict(img)
+     # st.write("Probability: [{}]".format(", ".join(f"{x:.2f}" for x in probabilities[0])))
+     # st.write("len: ", len(probabilities[0]))
+     annotations = get_annotations(probabilities[0])
+ 
+     predicted, exp = ml_model.predict_explain(probabilities[0], [jsw_m, jsw_mm], filename=file_name)
+ 
+     processed_img, anomaly = anomaly_extractor.extract(mask, img, verbose=0)
+ 
+     def plot_anomaly_with_clues(processed_img, anomaly, diff_percentage, mean_left, mean_right, side_min, index_min, value_min,
+                                 left_pairs, right_pairs, color='r', thickness=1):
+         # Create a matplotlib figure and axes
+         fig, ax = plt.subplots()
+         ax.imshow(processed_img, cmap='gray')
+         ax.imshow(anomaly, cmap='turbo', alpha=0.3)
+ 
+         # # Draw the segments from left_pairs / right_pairs
+         # for pairs in [left_pairs, right_pairs]:
+         #     for pair in pairs:
+         #         start_point = pair[0]
+         #         end_point = pair[1]
+         #         ax.plot([start_point[0], end_point[0]], [start_point[1], end_point[1]], color=color, linewidth=thickness)
+         #         ax.scatter(*start_point, c=color, s=thickness*2)  # draw the start point
+         #         ax.scatter(*end_point, c=color, s=thickness*2)    # draw the end point
+ 
+         # Put the diff_percentage text in the middle of the image
+         mid_x = anomaly.shape[1] // 2
+         mid_y = anomaly.shape[0] // 10
+         # ax.text(mid_x, mid_y, f'% difference between left & right joint space distance: {diff_percentage:.2f}%',
+         #         color='white', fontsize=8, ha='center', va='center')
+ 
+         # Add the mean distance for the left and right sides
+         # left_mid_index = len(left_pairs) // 2
+         # right_mid_index = len(right_pairs) // 2
+         # left_mid_point = left_pairs[left_mid_index][0]
+         # right_mid_point = right_pairs[right_mid_index][0]
+         # ax.text(left_mid_point[0], anomaly.shape[0] // 1.5, f'{mean_left:.2f}', color='yellow', fontsize=20, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=2, foreground='black'), path_effects.Normal()])
+         # ax.text(right_mid_point[0], anomaly.shape[0] // 1.5, f'{mean_right:.2f}', color='yellow', fontsize=20, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=2, foreground='black'), path_effects.Normal()])
+         # print((mean_left // value_min > 2) if side_min == 0 else (mean_right // value_min > 2))
+ 
+         if (diff_percentage > 0) or ((mean_left / value_min > 2) if side_min == 0 else (mean_right / value_min > 2)):
+             min_pairs = left_pairs if side_min == 0 else right_pairs
+             min_pair = min_pairs[index_min]
+             start_point = min_pair[0]
+             end_point = min_pair[1]
+ 
+             # Compute the coordinates of a bbox around the vertical segment, with padding
+             padding_x = 23  # padding width along x
+             padding_y = 13  # padding width along y
+             min_x = start_point[0] - padding_x
+             max_x = start_point[0] + padding_x
+             min_y = min(start_point[1], end_point[1]) - padding_y
+             max_y = max(start_point[1], end_point[1]) + padding_y
+ 
+             rect = plt.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y, linewidth=3, edgecolor='red', facecolor='none')
+             ax.add_patch(rect)
+ 
+             # Label the narrowest segment with value_min
+             # text = ax.text(start_point[0], min_y-padding_y//2, f'{value_min:.2f}', color='red', fontsize=18, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=1.5, foreground='black'), path_effects.Normal()])
+ 
+         # add annotations: osteophyte & jsn
+         ax.text(anomaly.shape[1] // 2, anomaly.shape[0] // 3.2, f'{annotations["osteophyte"]}', color='yellow', fontsize=19, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=1.5, foreground='black'), path_effects.Normal()])
+         ax.text(anomaly.shape[1] // 2, anomaly.shape[0] // 1.3, f'{annotations["jsn"]}', color='yellow', fontsize=19, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=1.5, foreground='black'), path_effects.Normal()])
+ 
+         # add annotation with jsw_m & jsw_mm info
+         ax.text(anomaly.shape[1] // 40, anomaly.shape[0] // 1.05, f'$JSW_{{Mean}}$: {jsw_m:.2f}', color='white', fontsize=10, ha='left', va='center', path_effects=[path_effects.Stroke(linewidth=1, foreground='black'), path_effects.Normal()])
+         ax.text(anomaly.shape[1] // 40, anomaly.shape[0] // 1.1, f'$JSW_{{MM}}$: {jsw_mm:.2f}', color='white', fontsize=10, ha='left', va='center', path_effects=[path_effects.Stroke(linewidth=1, foreground='black'), path_effects.Normal()])
+ 
+         ax.axis('off')
+         # ax.set_title('Annotated anomaly map')
+         return fig
+ 
+     with col1:
+         caption = "Uploaded Image"
+         st.markdown(
+             f"<h3 style='text-align: center;'>{caption}</h3>",
+             unsafe_allow_html=True
+         )
+         st.image(img, channels="BGR", use_column_width=True)
+         # plt.imshow(anomaly, cmap="turbo")
+         # plt.axis('off')
+         # st.pyplot()
+     with col2:
+         # st.image(mask_image, channels="BGR", caption='mask image', use_column_width=True)
+         fig = plot_anomaly_with_clues(
+             processed_img,
+             anomaly,
+             diff_percentage,
+             mean_left,
+             mean_right,
+             side_min, index_min,
+             value_min,
+             left_pairs,
+             right_pairs,
+             color='r',
+             thickness=1
+         )
+         caption = "Annotated Anomaly Map"
+         st.markdown(
+             f"<h3 style='text-align: center;'>{caption}</h3>",
+             unsafe_allow_html=True
+         )
+         st.pyplot(fig, bbox_inches='tight', pad_inches=0)
+ 
+     exp_html = exp.as_html()
+     components.html(exp_html, height=1400, width=None)
+ 
+     # def extract_explain(exp, label):
+     #     ans = exp.local_exp[label]
+     #     ans = [(exp.domain_mapper.feature_names[x[0]],
+     #             exp.domain_mapper.feature_values[x[0]],
+     #             exp.domain_mapper.discretized_feature_names[x[0]],
+     #             float(x[1])
+     #            ) for x in ans]
+     #     return ans
+ 
+     # explanation_list = extract_explain(exp, predicted)
+     # explanation_df = pd.DataFrame(explanation_list, columns=['Feature', 'Value', 'Explain', 'Weight'])
+     # st.table(explanation_df)
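
The commented-out `extract_explain` helper above can be made runnable essentially as-is; a sketch (assuming LIME's tabular `Explanation` object, whose `domain_mapper` exposes `feature_names`, `feature_values`, and `discretized_feature_names`):

```python
import pandas as pd

def extract_explain(exp, label):
    # Flatten one label's LIME explanation into (feature, value, rule, weight) rows.
    return [(exp.domain_mapper.feature_names[i],
             exp.domain_mapper.feature_values[i],
             exp.domain_mapper.discretized_feature_names[i],
             float(weight))
            for i, weight in exp.local_exp[label]]

# e.g. after predict_explain():
# df = pd.DataFrame(extract_explain(exp, predicted),
#                   columns=['Feature', 'Value', 'Explain', 'Weight'])
# st.table(df)
```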
classification/__init__.py ADDED
File without changes
classification/model.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ import torchvision
+ import cv2
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from PIL import Image
+ import numpy as np
+ 
+ from .utils import to_rgb
+ 
+ 
+ def create_effNetv2s():
+     model = torchvision.models.efficientnet_v2_s(weights='IMAGENET1K_V1')
+     num_features = model.classifier[1].in_features
+     model.classifier[1] = nn.Sequential(
+         nn.Linear(num_features, NUM_CLASSES),
+         nn.Sigmoid()
+     )
+     return model
+ 
+ def create_convnet():
+     model = torchvision.models.convnext_base(weights='IMAGENET1K_V1')
+     num_features = model.classifier[2].in_features
+     model.classifier[2] = nn.Sequential(
+         nn.Linear(num_features, NUM_CLASSES),
+         nn.Sigmoid()
+     )
+     return model
+ 
+ def create_model(model_name):
+     model = _MODEL[model_name]()
+     model.load_state_dict(torch.load(_WEIGHT[model_name], map_location=torch.device('cpu')))
+     model.to(DEVICE)
+     return model
+ 
+ def create_transform():
+     transform = transforms.Compose([
+         transforms.Resize((HEIGHT, WIDTH)),
+         transforms.ToTensor(),
+         transforms.Normalize((0.6078, 0.6078, 0.6078), (0.1932, 0.1932, 0.1932))
+     ])
+     return transform
+ 
+ 
+ _MODEL = {
+     "effNetv2s": create_effNetv2s,
+     "convnet": create_convnet
+ }
+ 
+ _WEIGHT = {
+     "effNetv2s": './classification/weights/effnetv2s.pt',
+     "convnet": './classification/weights/convnet.pt',
+ }
+ 
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ HEIGHT = 224
+ WIDTH = 224
+ NUM_CLASSES = 44
+ 
+ 
+ class Classifier():
+     def __init__(self, model_name="effNetv2s"):
+         self.model = create_model(model_name)
+         self.transform = create_transform()
+ 
+     def predict(self, image):
+         '''
+         input: cv2 image
+         output: multi-label probability vector
+         '''
+         image = to_rgb(image)
+         image = Image.fromarray(image)
+         image = self.transform(image)
+         self.model.eval()
+         with torch.no_grad():
+             out = self.model(image.unsqueeze(0).to(DEVICE)).cpu().numpy()
+         return out
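
A minimal sketch of running the classifier on one image (the image path is hypothetical; `convnet.pt` from this commit must be on disk):

```python
import cv2
from classification.model import Classifier

clf = Classifier('convnet')     # ConvNeXt-Base with a 44-unit sigmoid head
img = cv2.imread('knee.png')    # hypothetical X-ray, loaded as BGR
probs = clf.predict(img)        # shape (1, 44): per-finding probabilities
print(probs[0].round(2))
```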
classification/utils.py ADDED
@@ -0,0 +1,6 @@
+ import cv2
+ 
+ def to_rgb(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
classification/weights/convnet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe61c8220bbd41af3e536451985c1c49cc0d8b438de2da76f7fdbba8f6051741
+ size 350549706
gae/__init__.py ADDED
@@ -0,0 +1 @@
+ from .anomaly_extraction import AnomalyExtractor
gae/anomaly_extraction.py ADDED
@@ -0,0 +1,107 @@
+ import matplotlib.pyplot as plt
+ import streamlit as st
+ import tensorflow as tf
+ from PIL import Image
+ from skimage.filters import gaussian
+ 
+ from .model import DCGAN
+ from .modules import DropBlockNoise
+ from .utils import *
+ from .contour import get_contours_v2
+ 
+ 
+ IMAGE_SIZE = 224
+ 
+ 
+ class AnomalyExtractor():
+     def __init__(self):
+         print("================================load dcgan=================================\n")
+         dcgan = DCGAN(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
+                       architecture='two-stage',
+                       output_activation='sigmoid',
+                       noise=DropBlockNoise(rate=0.1, block_size=16),
+                       pretrain_weights=None,
+                       block_type='pixel-shuffle',
+                       kernel_initializer='glorot_uniform',
+                       C=1.)
+ 
+         self.restore_model = dcgan.generator
+         self.restore_model.load_weights("./gae/weights/gan_efficientunet_full_augment-hist_equal_generator.h5")
+         self.restore_model.trainable = False
+ 
+     def extract(self, mask, img, verbose=0):
+         img = to_gray(img)
+         img = np.array(Image.fromarray(img).resize((224, 224)))
+         img = preprocess(img)
+ 
+         if verbose:
+             # Show the original image
+             show_image(img, title="Original image")
+             plt.axis('off')
+             st.pyplot()
+ 
+             # Show the mask
+             show_image(mask, title="Mask")
+             plt.axis('off')
+             st.pyplot()
+ 
+         uc, lc = get_contours_v2(mask, verbose=0)
+ 
+         mask = np.zeros((640, 640)).astype('uint8')
+         mask = draw_points(mask, lc, thickness=1, color=(255, 255, 255))
+         mask = draw_points(mask, uc, thickness=1, color=(255, 255, 255))
+         mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
+         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+         mask = mask / 255.
+ 
+         if verbose:
+             show_image(mask, title="Contour")
+             plt.axis('off')
+             st.pyplot()
+ 
+         # use histogram equalization
+         mask = 1 - mask
+         dila = dilate(mask)
+         dilated = gaussian(dilate(mask), sigma=50, truncate=0.3)
+ 
+         if verbose:
+             show_image(dila, title="Dilated image")
+             plt.axis('off')
+             st.pyplot()
+ 
+             show_image(dilated, title="Blurred image")
+             plt.axis('off')
+             st.pyplot()
+ 
+         im = np.expand_dims(img * (1 - dilated), axis=0)
+         im = tf.convert_to_tensor(im, dtype=tf.float32)
+ 
+         restored_img = self.restore_model(im)
+ 
+         res = tf.squeeze(tf.squeeze(restored_img, axis=-1), axis=0)
+ 
+         if verbose:
+             show_image(im[0], title="Masked image")
+             plt.axis('off')
+             st.pyplot()
+ 
+             show_image(res, title="Reconstructed image")
+             plt.axis('off')
+             st.pyplot()
+ 
+             show_image(dilated * tf.abs(img - res), title="Anomaly map", cmap_type='turbo')
+             plt.axis('off')
+             st.pyplot()
+ 
+         # plt.imshow(img, cmap='gray')
+         # plt.imshow(dilated*tf.abs(img - res), cmap='turbo', alpha=0.3)
+         # plt.axis('off')
+         # plt.show()
+         # st.pyplot()
+ 
+         # show_image(dilated, title="dilated", cmap_type='gray')
+         # plt.axis('off')
+         # st.pyplot()
+ 
+         return img, dilated * tf.abs(img - res)
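
The anomaly map is simply the GAN's reconstruction error, gated to the blurred dilated-contour region around the joint space. A usage sketch (not part of the commit; `mask` and `img` are assumed to come from the Segmenter and the uploaded image, as in `app.py`):

```python
from gae import AnomalyExtractor

extractor = AnomalyExtractor()   # loads the LFS generator weights above
# mask: 640x640 label map (femur=1, tibia=2); img: the uploaded X-ray
processed_img, anomaly = extractor.extract(mask, img)
# anomaly == blurred_dilated_contour * |preprocessed_img - reconstruction|
```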
gae/contour.py ADDED
@@ -0,0 +1,64 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import cv2
+ 
+ from .utils import draw_points, show_image
+ 
+ 
+ def find_boundaries_v2(mask, top=True, verbose=0):
+     boundaries = []
+     height, width = mask.shape
+ 
+     contours, _ = cv2.findContours(255 * mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ 
+     areas = np.array([cv2.contourArea(cnt) for cnt in contours])
+     contour = contours[areas.argmax()]
+     contour = contour.reshape(-1, 2)
+     org_contour = contour.copy()
+     pos = (contour[:, 1].max() + contour[:, 1].min()) // 2
+     idx = np.where(contour[:, 1] == pos)
+     if contour[idx[0][0]][0] < contour[idx[0][-1]][0] and not top:
+         # start = contour[idx[0][0]]
+         # end = contour[idx[0][-1]]
+         start_idx = idx[0][0]
+         end_idx = idx[0][-1]
+     else:
+         # end = contour[idx[0][0]]
+         # start = contour[idx[0][-1]]
+         end_idx = idx[0][0]
+         start_idx = idx[0][-1]
+     # start_idx = ((start - contour) ** 2).sum(axis=-1).argmin()
+     # end_idx = ((end - contour) ** 2).sum(axis=-1).argmin()
+     if start_idx <= end_idx:
+         contour = contour[start_idx:end_idx + 1]
+     else:
+         contour = np.concatenate([contour[start_idx:], contour[:end_idx + 1]])
+     if verbose:
+         temp = draw_points(127 * mask.astype(np.uint8), contour, thickness=5)
+         temp = draw_points(temp, [org_contour[start_idx], org_contour[end_idx]], color=(155, 155, 155), thickness=15)
+         show_image(temp)
+ 
+     return np.array(contour), np.array(org_contour)
+ 
+ 
+ def get_contours_v2(mask, verbose=0):
+     upper_contour, full_upper = find_boundaries_v2(mask == 1, top=False, verbose=verbose)
+     lower_contour, full_lower = find_boundaries_v2(mask == 2, top=True, verbose=verbose)
+     if verbose:
+         temp = draw_points((127 * mask).astype(np.uint8), full_upper, thickness=3, color=(255, 0, 0))
+         temp = draw_points(temp, full_lower, thickness=3)
+         plt.imshow(temp)
+         plt.title("Segmentation")
+         plt.axis('off')
+         plt.show()
+         # st.pyplot()
+         # cv2.imwrite('full.png', temp)
+         # temp = draw_points(temp, limit_points, thickness=7, color=(0, 0, 255))
+         # show_image(temp)
+         # cv2.imwrite('limit_points.png', temp)
+     if verbose:
+         temp = draw_points((127 * mask).astype(np.uint8), upper_contour, thickness=3, color=(255, 0, 0))
+         temp = draw_points(temp, lower_contour, thickness=3)
+         show_image(temp)
+         # cv2.imwrite('cropped.png', temp)
+ 
+     return upper_contour, lower_contour
gae/model.py ADDED
@@ -0,0 +1,449 @@
+ import gc
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+ 
+ import tensorflow as tf
+ from tensorflow.keras.optimizers import Adam
+ from tensorflow.keras.models import Sequential, Model
+ 
+ from tensorflow.keras import layers
+ from tensorflow.keras.applications import EfficientNetV2S
+ from tensorflow.keras.layers import (
+     Dense, Flatten, Conv2D, Activation, BatchNormalization,
+     MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
+     Dropout, Input, concatenate, add, Conv2DTranspose, Lambda,
+     SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU,
+     ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
+ )
+ 
+ from .contour import get_contours_v2
+ from .modules import (
+     MultipleTrackers, DropBlockNoise, squeeze_excite_block, spatial_squeeze_excite_block,
+     channel_spatial_squeeze_excite, DoubleConv, UpSampling2D_block, Conv2DTranspose_block,
+     PixelShuffle_block
+ )
+ from .utils import mae
+ 
+ 
+ IMAGE_SIZE = 224
+ 
+ 
+ def adjust_pretrained_weights(model_cls, input_size, name=None):
+     weights_model = model_cls(weights='imagenet',
+                               include_top=False,
+                               input_shape=(*input_size, 3))
+     target_model = model_cls(weights=None,
+                              include_top=False,
+                              input_shape=(*input_size, 1))
+     weights = weights_model.get_weights()
+     weights[0] = np.sum(weights[0], axis=2, keepdims=True)
+     target_model.set_weights(weights)
+ 
+     del weights_model
+     tf.keras.backend.clear_session()
+     gc.collect()
+     if name:
+         target_model._name = name
+     return target_model
+ 
+ 
+ def get_efficient_unet(name=None,
+                        option='full',
+                        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
+                        encoder_weights=None,
+                        block_type='conv-transpose',
+                        output_activation='sigmoid',
+                        kernel_initializer='glorot_uniform'):
+ 
+     if encoder_weights == 'imagenet':
+         encoder = adjust_pretrained_weights(EfficientNetV2S, input_shape[:-1], name)
+     elif encoder_weights is None:
+         encoder = EfficientNetV2S(weights=None,
+                                   include_top=False,
+                                   input_shape=input_shape)
+         encoder._name = name
+     else:
+         raise ValueError(encoder_weights)
+ 
+     if option == 'encoder':
+         return encoder
+ 
+     MBConvBlocks = []
+ 
+     skip_candidates = ['1b', '2d', '3d', '4f']
+ 
+     for mbblock_nr in skip_candidates:
+         mbblock = encoder.get_layer('block{}_add'.format(mbblock_nr)).output
+         MBConvBlocks.append(mbblock)
+ 
+     head = encoder.get_layer('top_activation').output
+     blocks = MBConvBlocks + [head]
+ 
+     if block_type == 'upsampling':
+         UpBlock = UpSampling2D_block
+     elif block_type == 'conv-transpose':
+         UpBlock = Conv2DTranspose_block
+     elif block_type == 'pixel-shuffle':
+         UpBlock = PixelShuffle_block
+     else:
+         raise ValueError(block_type)
+ 
+     o = blocks.pop()
+     o = UpBlock(512, initializer=kernel_initializer, skip=blocks.pop())(o)
+     o = UpBlock(256, initializer=kernel_initializer, skip=blocks.pop())(o)
+     o = UpBlock(128, initializer=kernel_initializer, skip=blocks.pop())(o)
+     o = UpBlock(64, initializer=kernel_initializer, skip=blocks.pop())(o)
+     o = UpBlock(32, initializer=kernel_initializer, skip=None)(o)
+     o = Conv2D(input_shape[-1], (1, 1), padding='same', activation=output_activation, kernel_initializer=kernel_initializer)(o)
+ 
+     model = Model(encoder.input, o, name=name)
+ 
+     if option == 'full':
+         return model, encoder
+     elif option == 'model':
+         return model
+     else:
+         raise ValueError(option)
+ 
+ 
+ class DCGAN():
+     def __init__(self,
+                  input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
+                  architecture='two-stage',
+                  pretrain_weights=None,
+                  output_activation='sigmoid',
+                  block_type='conv-transpose',
+                  kernel_initializer='glorot_uniform',
+                  noise=None,
+                  C=1.):
+ 
+         self.C = C
+         # Build
+         kwargs = dict(input_shape=input_shape,
+                       output_activation=output_activation,
+                       encoder_weights=pretrain_weights,
+                       block_type=block_type,
+                       kernel_initializer=kernel_initializer)
+ 
+         if architecture == 'two-stage':
+             encoder = get_efficient_unet(name='dcgan_disc',
+                                          option='encoder',
+                                          **kwargs)
+             self.generator = get_efficient_unet(name='dcgan_gen', option='model', **kwargs)
+         elif architecture == 'shared':
+             self.generator, encoder = get_efficient_unet(name='dcgan', option='full', **kwargs)
+         else:
+             raise ValueError(f'Unsupported architecture: {architecture}')
+ 
+         gpooling = GlobalAveragePooling2D()(encoder.output)
+         prediction = Dense(1, activation='sigmoid')(gpooling)
+         self.discriminator = Model(encoder.input, prediction, name='dcgan_disc')
+ 
+         tf.keras.backend.clear_session()
+         _ = gc.collect()
+ 
+         if noise:
+             gen_inputs = self.generator.input
+             corrupted_inputs = noise(gen_inputs)
+             outputs = self.generator(corrupted_inputs)
+             self.generator = Model(gen_inputs, outputs, name='dcgan_gen')
+ 
+             tf.keras.backend.clear_session()
+             _ = gc.collect()
+ 
+         if output_activation == 'tanh':
+             self.process_input = layers.Lambda(lambda img: (img * 2. - 1.), name='dcgan_normalize')
+             self.process_output = layers.Lambda(lambda img: (img * 0.5 + 0.5), name='dcgan_denormalize')
+             gen_inputs = self.generator.input
+             process_inputs = self.process_input(gen_inputs)
+             process_inputs = self.generator(process_inputs)
+             gen_outputs = self.process_output(process_inputs)
+             self.generator = Model(gen_inputs, gen_outputs, name='dcgan_gen')
+ 
+             disc_inputs = self.discriminator.input
+             process_inputs = self.process_input(disc_inputs)
+             disc_outputs = self.discriminator(process_inputs)
+             self.discriminator = Model(disc_inputs, disc_outputs, name='dcgan_disc')
+ 
+             tf.keras.backend.clear_session()
+             _ = gc.collect()
+ 
+     def summary(self):
+         self.generator.summary()
+         self.discriminator.summary()
+ 
+     def compile(self,
+                 generator_optimizer=Adam(5e-4, 0.5),
+                 discriminator_optimizer=Adam(5e-4),
+                 reconstruction_loss=mae,
+                 discriminative_loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
+                 reconstruction_metrics=[],
+                 discriminative_metrics=[]):
+ 
+         self.discriminator_optimizer = discriminator_optimizer
+         self.discriminator.compile(optimizer=self.discriminator_optimizer)
+ 
+         self.generator_optimizer = generator_optimizer
+         self.generator.compile(optimizer=self.generator_optimizer)
+ 
+         self.loss = discriminative_loss
+         self.reconstruction_loss = reconstruction_loss
+         self.d_loss_tracker = tf.keras.metrics.Mean()
+         self.g_loss_tracker = tf.keras.metrics.Mean()
+         self.g_recon_tracker = tf.keras.metrics.Mean()
+         self.g_disc_tracker = tf.keras.metrics.Mean()
+ 
+         self.g_metric_trackers = [(tf.keras.metrics.Mean(), metric) for metric in reconstruction_metrics]
+         self.d_metric_trackers = [(tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), metric) for metric in discriminative_metrics]
+ 
+         all_trackers = [self.d_loss_tracker, self.g_loss_tracker, self.g_recon_tracker, self.g_disc_tracker] + \
+                        [tracker for tracker, _ in self.g_metric_trackers] + \
+                        [tracker for t in self.d_metric_trackers for tracker in t[:-1]]
+         self.all_trackers = MultipleTrackers(all_trackers)
+ 
+     def discriminator_loss(self, real_output, fake_output):
+         real_loss = self.loss(tf.ones_like(real_output), real_output)
+         fake_loss = self.loss(tf.zeros_like(fake_output), fake_output)
+         total_loss = 0.5 * (real_loss + fake_loss)
+         return total_loss
+ 
+     def generator_loss(self, fake_output):
+         return self.loss(tf.ones_like(fake_output), fake_output)
+ 
+     @tf.function
+     def train_step(self, images):
+         masked, original = images
+         n_samples = tf.shape(original)[0]
+ 
+         with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+             generated_images = self.generator(masked, training=True)
+ 
+             real_output = self.discriminator(original, training=True)
+             fake_output = self.discriminator(generated_images, training=True)
+ 
+             gen_disc_loss = self.generator_loss(fake_output)
+             recon_loss = self.reconstruction_loss(original, generated_images)
+             gen_loss = self.C * recon_loss + gen_disc_loss
+             disc_loss = self.discriminator_loss(real_output, fake_output)
+ 
+         gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
+         gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
+ 
+         self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
+         self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
+ 
+         self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
+         self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
+         self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
+         self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))
+ 
+         logs = {'d_loss': self.d_loss_tracker.result()}
+ 
+         for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
+             v_real = metric(tf.ones_like(real_output), real_output)
+             v_fake = metric(tf.zeros_like(fake_output), fake_output)
+             v = 0.5 * (v_real + v_fake)
+             tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
+             real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
+             fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))
+ 
+             metric_name = metric.__name__
+             logs['d_' + metric_name] = tracker.result()
+             logs['d_real_' + metric_name] = real_tracker.result()
+             logs['d_fake_' + metric_name] = fake_tracker.result()
+ 
+         logs['g_loss'] = self.g_loss_tracker.result()
+         logs['g_recon'] = self.g_recon_tracker.result()
+         logs['g_disc'] = self.g_disc_tracker.result()
+ 
+         for tracker, metric in self.g_metric_trackers:
+             v = metric(original, generated_images)
+             tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
+             logs['g_' + metric.__name__] = tracker.result()
+ 
+         return logs
+ 
+     @tf.function
+     def val_step(self, images):
+         masked, original = images
+         n_samples = tf.shape(original)[0]
+ 
+         generated_images = self.generator(masked, training=False)
+ 
+         real_output = self.discriminator(original, training=False)
+         fake_output = self.discriminator(generated_images, training=False)
+ 
+         gen_disc_loss = self.generator_loss(fake_output)
+         recon_loss = self.reconstruction_loss(original, generated_images)
+         gen_loss = self.C * recon_loss + gen_disc_loss
+         disc_loss = self.discriminator_loss(real_output, fake_output)
+ 
+         self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
+         self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
+         self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
+         self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))
+ 
+         logs = {'val_d_loss': self.d_loss_tracker.result()}
+ 
+         for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
+             v_real = metric(tf.ones_like(real_output), real_output)
+             v_fake = metric(tf.zeros_like(fake_output), fake_output)
+             v = 0.5 * (v_real + v_fake)
+             tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
+             real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
+             fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))
+ 
+             metric_name = metric.__name__
+             logs['val_d_' + metric_name] = tracker.result()
+             logs['val_d_real_' + metric_name] = real_tracker.result()
+             logs['val_d_fake_' + metric_name] = fake_tracker.result()
+ 
+         logs['val_g_loss'] = self.g_loss_tracker.result()
+         logs['val_g_recon'] = self.g_recon_tracker.result()
+         logs['val_g_disc'] = self.g_disc_tracker.result()
+ 
+         for tracker, metric in self.g_metric_trackers:
+             v = metric(original, generated_images)
+             tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
+             logs['val_g_' + metric.__name__] = tracker.result()
+ 
+         return logs
+ 
+     def fit(self,
+             trainset,
+             valset=None,
+             trainsize=-1,
+             valsize=-1,
+             epochs=1,
+             display_per_epochs=5,
+             generator_callbacks=[],
+             discriminator_callbacks=[]):
+ 
+         print('🌊🐉 Start Training 🐉🌊')
+         gen_callback_tracker = tf.keras.callbacks.CallbackList(
+             generator_callbacks, add_history=True, model=self.generator
+         )
+ 
+         disc_callback_tracker = tf.keras.callbacks.CallbackList(
+             discriminator_callbacks, add_history=True, model=self.discriminator
+         )
+ 
+         callbacks_tracker = MultipleTrackers([gen_callback_tracker, disc_callback_tracker])
+ 
+         logs = {}
+         callbacks_tracker.on_train_begin(logs=logs)
+ 
+         for epoch in range(epochs):
+             print(f'Epochs {epoch+1}/{epochs}:')
+             callbacks_tracker.on_epoch_begin(epoch, logs=logs)
+ 
+             batches = tqdm(trainset,
+                            desc="Train",
+                            total=trainsize,
+                            unit="step",
+                            position=0,
+                            leave=True)
+ 
+             for batch, image_batch in enumerate(batches):
+                 callbacks_tracker.on_batch_begin(batch, logs=logs)
+                 callbacks_tracker.on_train_batch_begin(batch, logs=logs)
+ 
+                 train_logs = {k: v.numpy() for k, v in self.train_step(image_batch).items()}
+                 logs.update(train_logs)
+ 
+                 callbacks_tracker.on_train_batch_end(batch, logs=logs)
+                 callbacks_tracker.on_batch_end(batch, logs=logs)
+                 batches.set_postfix({'d_loss': train_logs['d_loss'],
+                                      'g_loss': train_logs['g_loss']
+                                      })
+ 
+             # Presentation
+             stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items() if 'val_' not in k and 'loss' not in k)
+             print('Train:', stats)
+ 
+             batches.close()
+             if valset:
+                 self.all_trackers.reset_state()
+ 
+                 batches = tqdm(valset,
+                                desc="Valid",
+                                total=valsize,
+                                unit="step",
+                                position=0,
+                                leave=True)
+ 
+                 for batch, image_batch in enumerate(batches):
+                     callbacks_tracker.on_batch_begin(batch, logs=logs)
+                     callbacks_tracker.on_test_batch_begin(batch, logs=logs)
+                     val_logs = {k: v.numpy() for k, v in self.val_step(image_batch).items()}
+                     logs.update(val_logs)
+ 
+                     callbacks_tracker.on_test_batch_end(batch, logs=logs)
+                     callbacks_tracker.on_batch_end(batch, logs=logs)
+                     # Presentation
+                     batches.set_postfix({'val_d_loss': val_logs['val_d_loss'],
+                                          'val_g_loss': val_logs['val_g_loss']
+                                          })
+ 
+                 stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items() if 'val_' in k and 'loss' not in k)
+                 print('Valid:', stats)
+ 
+                 batches.close()
+ 
+             if epoch % display_per_epochs == 0:
+                 print('-' * 128)
+                 self.visualize_samples((image_batch[0][:2], image_batch[1][:2]))
+ 
+             self.all_trackers.reset_state()
+ 
+             callbacks_tracker.on_epoch_end(epoch, logs=logs)
+             # tf.keras.backend.clear_session()
+             _ = gc.collect()
+ 
+             if self.generator.stop_training or self.discriminator.stop_training:
+                 break
+             print('-' * 128)
+ 
+         callbacks_tracker.on_train_end(logs=logs)
+         tf.keras.backend.clear_session()
+         _ = gc.collect()
+         gen_history = None
+         for cb in gen_callback_tracker:
+             if isinstance(cb, tf.keras.callbacks.History):
+                 gen_history = cb
+                 gen_history.history = {k: v for k, v in cb.history.items() if 'd_' not in k}
+ 
+         disc_history = None
+         for cb in disc_callback_tracker:
+             if isinstance(cb, tf.keras.callbacks.History):
+                 disc_history = cb
+                 disc_history.history = {k: v for k, v in cb.history.items() if 'g_' not in k}
+ 
+         return {'generator': gen_history,
+                 'discriminator': disc_history}
+ 
+     def visualize_samples(self, samples, figsize=(12, 2)):
+         x, y = samples
+         y_pred = self.generator.predict(x[:2], verbose=0)
+         fig, axs = plt.subplots(1, 6, figsize=figsize)
+         for i in range(2):
+             pos = 3 * i
+             axs[pos].imshow(x[i], cmap='gray', vmin=0., vmax=1.)
+             axs[pos].set_title('Masked')
+             axs[pos].axis('off')
+             axs[pos + 1].imshow(y[i], cmap='gray', vmin=0., vmax=1.)
+             axs[pos + 1].set_title('Original')
+             axs[pos + 1].axis('off')
+             axs[pos + 2].imshow(y_pred[i], cmap='gray', vmin=0., vmax=1.)
+             axs[pos + 2].set_title('Predicted')
+             axs[pos + 2].axis('off')
+         plt.show()
+ 
+         # tf.keras.backend.clear_session()
+         del y_pred
+         _ = gc.collect()
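
A sketch of how this DCGAN wrapper is intended to be driven (not part of the commit; the `trainset`/`valset` pipelines are hypothetical and must yield `(masked, original)` image batches scaled to [0, 1]):

```python
from gae.model import DCGAN
from gae.modules import DropBlockNoise
from gae.utils import inv_ssim

dcgan = DCGAN(input_shape=(224, 224, 1),
              architecture='two-stage',
              noise=DropBlockNoise(rate=0.1, block_size=16),
              block_type='pixel-shuffle')
dcgan.compile(reconstruction_metrics=[inv_ssim])
# trainset / valset: tf.data pipelines yielding (masked, original) batches
# history = dcgan.fit(trainset, valset=valset, trainsize=steps, epochs=50)
```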
gae/modules.py ADDED
@@ -0,0 +1,241 @@
+ import tensorflow as tf
+ from tensorflow.keras.__internal__.layers import BaseRandomLayer
+ from tensorflow.keras.layers import (
+     Dense, Flatten, Conv2D, Activation, BatchNormalization,
+     MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
+     Dropout, Input, concatenate, add, Conv2DTranspose, Lambda,
+     SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU,
+     ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
+ )
+ from keras import backend as K
+ 
+ from .utils import normalize_tuple
+ 
+ 
+ class MultipleTrackers():
+     def __init__(self, callback_lists: list):
+         self.callbacks_list = callback_lists
+ 
+     def __getattr__(self, attr):
+         def helper(*arg, **kwarg):
+             # fan the call out to every tracked object
+             for cb in self.callbacks_list:
+                 getattr(cb, attr)(*arg, **kwarg)
+         if attr in self.__class__.__dict__:
+             return getattr(self, attr)
+         else:
+             return helper
+ 
+ 
+ class DropBlockNoise(BaseRandomLayer):
+     def __init__(
+         self,
+         rate,
+         block_size,
+         seed=None,
+         **kwargs,
+     ):
+         super().__init__(seed=seed, **kwargs)
+         if not 0.0 <= rate <= 1.0:
+             raise ValueError(
+                 f"rate must be a number between 0 and 1. " f"Received: {rate}"
+             )
+ 
+         self._rate = rate
+         (
+             self._dropblock_height,
+             self._dropblock_width,
+         ) = normalize_tuple(
+             value=block_size, n=2, name="block_size", allow_zero=False
+         )
+         self.seed = seed
+ 
+     def call(self, x, training=None):
+         if not training or self._rate == 0.0:
+             return x
+ 
+         _, height, width, _ = tf.split(tf.shape(x), 4)
+ 
+         # Unnest scalar values
+         height = tf.squeeze(height)
+         width = tf.squeeze(width)
+ 
+         dropblock_height = tf.math.minimum(self._dropblock_height, height)
+         dropblock_width = tf.math.minimum(self._dropblock_width, width)
+ 
+         gamma = (
+             self._rate
+             * tf.cast(width * height, dtype=tf.float32)
+             / tf.cast(dropblock_height * dropblock_width, dtype=tf.float32)
+             / tf.cast(
+                 (width - self._dropblock_width + 1)
+                 * (height - self._dropblock_height + 1),
+                 tf.float32,
+             )
+         )
+ 
+         # Forces the block to be inside the feature map.
+         w_i, h_i = tf.meshgrid(tf.range(width), tf.range(height))
+         valid_block = tf.logical_and(
+             tf.logical_and(
+                 w_i >= int(dropblock_width // 2),
+                 w_i < width - (dropblock_width - 1) // 2,
+             ),
+             tf.logical_and(
+                 h_i >= int(dropblock_height // 2),
+                 h_i < height - (dropblock_height - 1) // 2,
+             ),
+         )
+ 
+         valid_block = tf.reshape(valid_block, [1, height, width, 1])
+ 
+         random_noise = self._random_generator.random_uniform(
+             tf.shape(x), dtype=tf.float32
+         )
+         valid_block = tf.cast(valid_block, dtype=tf.float32)
+         seed_keep_rate = tf.cast(1 - gamma, dtype=tf.float32)
+         block_pattern = (1 - valid_block + seed_keep_rate + random_noise) >= 1
+         block_pattern = tf.cast(block_pattern, dtype=tf.float32)
+ 
+         window_size = [1, self._dropblock_height, self._dropblock_width, 1]
+ 
+         # Double negation with max_pool is essentially min-pooling
+         block_pattern = -tf.nn.max_pool(
+             -block_pattern,
+             ksize=window_size,
+             strides=[1, 1, 1, 1],
+             padding="SAME",
+         )
+ 
+         return (
+             x * tf.cast(block_pattern, x.dtype)
+         )
+ 
+ 
+ def squeeze_excite_block(input, ratio=16):
+     ''' Create a channel-wise squeeze-excite block
+ 
+     Args:
+         input: input tensor
+         ratio: reduction ratio of the bottleneck
+ 
+     Returns: a keras tensor
+ 
+     References
+     -   [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
+     '''
+     init = input
+     channel_axis = 1 if K.image_data_format() == "channels_first" else -1
+     filters = int(init.shape[channel_axis])
+     se_shape = (1, 1, filters)
+ 
+     se = GlobalAveragePooling2D()(init)
+     se = Reshape(se_shape)(se)
+     se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
+     se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)
+ 
+     if K.image_data_format() == 'channels_first':
+         se = Permute((3, 1, 2))(se)
+ 
+     x = Multiply()([init, se])
+     return x
+ 
+ 
+ def spatial_squeeze_excite_block(input):
+     ''' Create a spatial squeeze-excite block
+ 
+     Args:
+         input: input tensor
+ 
+     Returns: a keras tensor
+ 
+     References
+     -   [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
+     '''
+     se = Conv2D(1, (1, 1), activation='sigmoid', use_bias=False,
+                 kernel_initializer='he_normal')(input)
+ 
+     x = Multiply()([input, se])
+     return x
+ 
+ 
+ def channel_spatial_squeeze_excite(input, ratio=16):
+     ''' Create a concurrent channel and spatial squeeze-excite block
+ 
+     Args:
+         input: input tensor
+         ratio: reduction ratio of the channel branch
+ 
+     Returns: a keras tensor
+ 
+     References
+     -   [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
+     -   [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
+     '''
+     cse = squeeze_excite_block(input, ratio)
+     sse = spatial_squeeze_excite_block(input)
+ 
+     x = Add()([cse, sse])
+     return x
+ 
+ 
+ def DoubleConv(filters, kernel_size, initializer='glorot_uniform'):
+     def layer(x):
+         x = Conv2D(filters, kernel_size, padding='same', kernel_initializer=initializer)(x)
+         x = BatchNormalization()(x)
+         x = Activation('swish')(x)
+         x = Conv2D(filters, kernel_size, padding='same', kernel_initializer=initializer)(x)
+         x = BatchNormalization()(x)
+         x = Activation('swish')(x)
+         return x
+     return layer
+ 
+ 
+ def UpSampling2D_block(filters, kernel_size=(3, 3), upsample_rate=(2, 2), interpolation='bilinear',
+                        initializer='glorot_uniform', skip=None):
+     def layer(input_tensor):
+         x = UpSampling2D(size=upsample_rate, interpolation=interpolation)(input_tensor)
+         if skip is not None:
+             x = Concatenate()([x, skip])
+         x = DoubleConv(filters, kernel_size, initializer=initializer)(x)
+         x = channel_spatial_squeeze_excite(x)
+         return x
+     return layer
+ 
+ 
+ def Conv2DTranspose_block(filters, transpose_kernel_size=(3, 3), upsample_rate=(2, 2),
+                           initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
+     def layer(input_tensor):
+         x = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate, padding='same')(input_tensor)
+         if skip is not None:
+             x = Concatenate()([x, skip])
+         x = DoubleConv(filters, transpose_kernel_size, initializer=initializer)(x)
+         x = channel_spatial_squeeze_excite(x)
+         return x
+     return layer
+ 
+ 
+ def PixelShuffle_block(filters, kernel_size=(3, 3), upsample_rate=2,
+                        initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
+     def layer(input_tensor):
+         x = Conv2D(filters * (upsample_rate ** 2), kernel_size, padding="same",
+                    activation="swish", kernel_initializer='Orthogonal')(input_tensor)
+         x = tf.nn.depth_to_space(x, upsample_rate)
+         if skip is not None:
+             x = Concatenate()([x, skip])
+         x = DoubleConv(filters, kernel_size, initializer=initializer)(x)
+         x = channel_spatial_squeeze_excite(x)
+         return x
+     return layer
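
A quick sketch of what `DropBlockNoise` does at train time: it zeroes random `block_size`×`block_size` patches so the generator must inpaint them (assumes the TF/Keras version this repo pins, since the layer relies on Keras-internal `BaseRandomLayer`):

```python
import tensorflow as tf
from gae.modules import DropBlockNoise

noise = DropBlockNoise(rate=0.1, block_size=16)
x = tf.random.uniform((1, 224, 224, 1))
corrupted = noise(x, training=True)  # ~10% of pixels dropped in 16x16 blocks
```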
gae/utils.py ADDED
@@ -0,0 +1,180 @@
+ import random
+ 
+ import cv2
+ import numpy as np
+ from scipy import ndimage
+ import matplotlib.pyplot as plt
+ 
+ from skimage import exposure
+ 
+ import tensorflow as tf
+ import tensorflow_addons as tfa
+ 
+ 
+ IMAGE_SIZE = 224
+ 
+ 
+ def acc(y_true, y_pred, threshold=0.5):
+     threshold = tf.cast(threshold, y_pred.dtype)
+     y_pred = tf.cast(y_pred > threshold, y_pred.dtype)
+     return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
+ 
+ def mae(y_true, y_pred):
+     return tf.reduce_mean(tf.abs(y_true - y_pred))
+ 
+ def inv_ssim(y_true, y_pred):
+     return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
+ 
+ def inv_msssim(y_true, y_pred):
+     return 1 - tf.reduce_mean(tf.image.ssim_multiscale(y_true, y_pred, 1.0, filter_size=4))
+ 
+ def inv_msssim_l1(y_true, y_pred, alpha=0.8):
+     return alpha * inv_msssim(y_true, y_pred) + (1 - alpha) * mae(y_true, y_pred)
+ 
+ def inv_msssim_gaussian_l1(y_true, y_pred, alpha=0.8):
+     l1_diff = tf.abs(y_true - y_pred)
+     gaussian_l1 = tfa.image.gaussian_filter2d(l1_diff, filter_shape=(11, 11), sigma=1.5)
+     return alpha * inv_msssim(y_true, y_pred) + (1 - alpha) * gaussian_l1
+ 
+ def psnr(y_true, y_pred):
+     return tf.reduce_mean(tf.image.psnr(y_true, y_pred, 1.0))
+ 
+ 
+ def show_image(image, title='Image', cmap_type='gray'):
+     plt.imshow(image, cmap=cmap_type)
+     plt.title(title)
+     plt.axis('off')
+     plt.show()
+ 
+ 
+ # invert images whose intensities are reversed (bright background)
+ def remove_negative(img):
+     outside = np.mean(img[:, 0])
+     inside = np.mean(img[:, int(IMAGE_SIZE / 2)])
+     if outside < inside:
+         return img
+     else:
+         return 1 - img
+ 
+ # preprocessing choice: raw image, histogram equalization, or CLAHE
+ def preprocess(img):
+     img = remove_negative(img)
+     img = exposure.equalize_hist(img)
+     img = exposure.equalize_adapthist(img)
+     img = exposure.equalize_hist(img)
+     return img
+ 
+ 
+ # dilate the contour
+ def dilate(mask_img):
+     kernel_size = 2 * 22 + 1
+     kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
+     return ndimage.binary_dilation(mask_img == 0, structure=kernel)
+ 
+ 
+ def normalize_tuple(value, n, name, allow_zero=False):
+     """Transforms non-negative/positive integer/integers into an integer tuple.
+ 
+     Args:
+         value: The value to validate and convert. Could be an int, or any
+             iterable of ints.
+         n: The size of the tuple to be returned.
+         name: The name of the argument being validated, e.g. "strides" or
+             "kernel_size". This is only used to format error messages.
+         allow_zero: Defaults to False. A ValueError will be raised if zero is
+             received and this param is False.
+     Returns:
+         A tuple of n integers.
+     Raises:
+         ValueError: If something other than an int or an iterable of ints, or
+             a negative value, is passed.
+     """
+     error_msg = (
+         f"The `{name}` argument must be a tuple of {n} "
+         f"integers. Received: {value}"
+     )
+ 
+     if isinstance(value, int):
+         value_tuple = (value,) * n
+     else:
+         try:
+             value_tuple = tuple(value)
+         except TypeError:
+             raise ValueError(error_msg)
+         if len(value_tuple) != n:
+             raise ValueError(error_msg)
+         for single_value in value_tuple:
+             try:
+                 int(single_value)
+             except (ValueError, TypeError):
+                 error_msg += (
+                     f"including element {single_value} of "
+                     f"type {type(single_value)}"
+                 )
+                 raise ValueError(error_msg)
+ 
+     if allow_zero:
+         unqualified_values = {v for v in value_tuple if v < 0}
+         req_msg = ">= 0"
+     else:
+         unqualified_values = {v for v in value_tuple if v <= 0}
+         req_msg = "> 0"
+ 
+     if unqualified_values:
+         error_msg += (
+             f" including {unqualified_values}"
+             f" that does not satisfy the requirement `{req_msg}`."
+         )
+         raise ValueError(error_msg)
+ 
+     return value_tuple
+ 
+ 
+ def draw_points(image, points, color=None, random_color=False, same=True, thickness=1):
+     if color is None and not random_color:
+         color = (0, 255, 0)  # default color is green (BGR)
+     if random_color:
+         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ 
+     image = to_color(image)
+ 
+     for point in points:
+         if random_color and not same:
+             color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ 
+         x, y = point
+         image = cv2.circle(image, (x, y), thickness, color, -1)  # draw the point on the image
+     return image
+ 
+ 
+ def draw_lines(image, pairs, color=None, random_color=False, same=True, thickness=1):
+     image_with_line = to_color(np.copy(image))
+ 
+     if color is None and not random_color:
+         color = (0, 255, 0)  # default color is green (BGR)
+     if random_color:
+         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ 
+     # Draw line segments from the list of point pairs
+     for pair in pairs:
+         if random_color and not same:
+             color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ 
+         start_point = pair[0]
+         end_point = pair[1]
+         image_with_line = cv2.line(image_with_line, start_point, end_point, color, thickness)
+         image_with_line = cv2.circle(image_with_line, start_point, thickness + 1, color, -1)
+         image_with_line = cv2.circle(image_with_line, end_point, thickness + 1, color, -1)
+ 
+     return image_with_line
+ 
+ def to_color(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return image
+     return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+ 
+ 
+ def to_gray(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     return image
gae/weights/gan_efficientunet_full_augment-hist_equal_generator.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6b3f57e8b8c2b6d56f1b0ac9eaceaa693600ccd52ba637293aee23eaf40a819
+ size 230002208
jsw/__init__.py ADDED
@@ -0,0 +1 @@
+ from .cal_jsw import get_JSW, calculate_diff, calculate_jsw_info
jsw/cal_jsw.py ADDED
@@ -0,0 +1,117 @@
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+
+ from .utils import *
+ from .contour import get_contours
+ from .distance import distance
+
+
+ def pooling_array(array, n, mode='mean'):
+     if mode == 'mean':
+         pool = lambda x: np.mean(x)
+     elif mode == 'min':
+         pool = lambda x: np.min(x)
+     elif mode == 'sum':
+         pool = lambda x: np.sum(x)
+     else:
+         raise ValueError(f"Unsupported pooling mode: {mode}")
+
+     if n == 1:
+         return pool(array)
+
+     array_length = len(array)
+     if array_length < n:
+         return array
+     segment_length = array_length // n
+     remaining_elements = array_length % n
+
+     if remaining_elements == 0:
+         segments = np.split(array, n)
+     else:
+         mid = remaining_elements * (segment_length + 1)
+         segments = np.split(array[:mid], remaining_elements)
+         segments += np.split(array[mid:], n - remaining_elements)
+
+     segments = [pool(segment) for segment in segments]
+
+     return np.array(segments)
+
+
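+ # Worked example (illustrative): pooling_array(np.array([1, 2, 3, 4, 5]), 2, 'mean')
+ # splits the array into [1, 2, 3] and [4, 5] (remainder elements widen the leading
+ # segments) and returns array([2.0, 4.5]).
+
+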
+ def pool_links(links, dim, mode="mean"):
+     '''
+     links is a list of (upper, lower) coordinate pairs spanning the knee joint space
+     between upper_contour and lower_contour, e.g.:
+     [(array([436, 421], dtype=int32), array([436, 451], dtype=int32)), ...]
+     First collect the x coordinates of these pairs, pool them the same way the
+     distances are pooled, then filter the original links to keep one pair for each
+     pooled x value.
+     '''
+     pooled_x = pooling_array(np.array(links)[:, 0, 0], dim, mode)
+     filtered_pairs = []
+     for x in pooled_x:
+         for pair in links:
+             if pair[0][0] == int(x):
+                 filtered_pairs.append(pair)
+                 break  # take one pair and stop
+     return filtered_pairs
+
+
+ def get_JSW(mask, dim=None, pool='mean', p=0.3, verbose=0):
+     '''
+     input: mask (H, W) with femur as 1 and tibia as 2
+     output: two distance vectors (left, right), aggregated to <dim> entries by <pool>, plus links (a list of coordinate pairs for each side)
+     '''
+     if isinstance(mask, str):
+         mask = cv2.imread(mask, 0)
+     if mask is None:
+         return np.zeros(10), np.zeros(10), ([], [])
+     uc, lc = get_contours(mask, verbose=verbose)
+     left_distances, right_distances, links, contours = distance(mask, uc, lc, p=p, verbose=verbose)
+     if verbose:
+         temp = draw_points(mask * 127, contours[0], thickness=3, color=(255, 0, 0))
+         temp = draw_points(temp, contours[1], thickness=5, color=(255, 0, 0))
+         temp = draw_points(temp, contours[2], thickness=5, color=(0, 255, 0))
+         temp = draw_points(temp, contours[3], thickness=5, color=(0, 255, 0))
+         temp = draw_lines(temp, links[0][::6] + links[1][::6], color=(0, 0, 255), thickness=2)  # every 6th link per side
+         # cv2_imshow(temp)
+         # cv2.imwrite("drawn_lines.png", temp)
+     if dim:
+         left_distances = pooling_array(left_distances, dim, pool)
+         right_distances = pooling_array(right_distances, dim, pool)
+         links = pool_links(links[0], dim, mode="mean"), pool_links(links[1], dim, mode="mean")
+     return left_distances, right_distances, links
+
+
+ def calculate_diff(left_jsw, right_jsw):
+     '''
+     input: left_distances and right_distances vectors
+     output: jsw_m (difference of the side means) and jsw_mm (max-min spread), both normalized by the global max JSW
+     '''
+     jsw_max = max(np.max(left_jsw), np.max(right_jsw))
+     jsw_max_side = np.argmax([np.max(left_jsw), np.max(right_jsw)])
+     if jsw_max_side == 0:
+         jsw_min = np.min(right_jsw)
+     else:
+         jsw_min = np.min(left_jsw)
+
+     diff_mean = abs(np.mean(left_jsw) - np.mean(right_jsw)) / jsw_max
+     diff_max_min = (jsw_max - jsw_min) / jsw_max
+
+     return diff_mean, diff_max_min
+
+
+ def calculate_jsw_info(left_jsw, right_jsw):
+     '''
+     input: left_distances and right_distances vectors
+     output: % diff, mean_left, mean_right, side_min (0 is left, 1 is right), index of the minimum (on the min side), minimum value (on the min side)
+     '''
+     mean_left = np.mean(left_jsw)
+     mean_right = np.mean(right_jsw)
+
+     side_min, index_min, value_min = (0, np.argmin(left_jsw), np.min(left_jsw)) if mean_left <= mean_right else (1, np.argmin(right_jsw), np.min(right_jsw))
+
+     diff_percentage = np.abs((mean_left - mean_right) / mean_right) * 100
+
+     return diff_percentage, mean_left, mean_right, side_min, index_min, value_min
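A minimal sketch of how these helpers chain together, assuming a femur/tibia mask is already on disk; the file name and dim=10 are illustrative, not taken from this repo:

from jsw import get_JSW, calculate_diff, calculate_jsw_info

left, right, links = get_JSW("mask.png", dim=10, pool="min")   # pooled per-column gap widths
jsw_m, jsw_mm = calculate_diff(left, right)                    # normalized mean / max-min gaps
diff_pct, mean_l, mean_r, side, idx, val = calculate_jsw_info(left, right)
print(f"narrower side: {'left' if side == 0 else 'right'}, min width {val} px")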
jsw/contour.py ADDED
@@ -0,0 +1,127 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import cv2
+
+ from .utils import *
+
+
+ def detect_limit_points(mask, verbose=0):
+     # find the limit points on both sides of the knee joint
+     '''
+     input: mask (H, W) with femur as 1 and tibia as 2
+     output: 4 points that serve as limit points for extracting the upper and lower contours from the full femur and tibia contours
+     '''
+     h, w = mask.shape
+     res = []
+
+     left_slice = slice(0, w // 2)
+     right_slice = slice(w // 2, None)
+     center_slice = slice(int(0.2 * h), int(0.8 * h))
+
+     left = np.zeros_like(mask)
+     left[center_slice, left_slice] = mask[center_slice, left_slice]
+
+     right = np.zeros_like(mask)
+     right[center_slice, right_slice] = mask[center_slice, right_slice]
+
+     if verbose:
+         cv2_imshow([left, right])
+
+     # femur limit point in the left half
+     pivot = np.array([0, w])
+     coords = np.argwhere(left == 1)
+     distances = ((coords - pivot) ** 2).sum(axis=-1)
+     point = coords[distances.argmax()][::-1]
+     res.append(point)
+
+     # femur limit point in the right half
+     pivot = np.array([0, 0])
+     coords = np.argwhere(right == 1)
+     distances = ((coords - pivot) ** 2).sum(axis=-1)
+     point = coords[distances.argmax()][::-1]
+     res.append(point)
+
+     # tibia limit point in the left half
+     pivot = np.array([h, w])
+     coords = np.argwhere(left == 2)
+     distances = ((coords - pivot) ** 2).sum(axis=-1)
+     point = coords[distances.argmax()][::-1]
+     res.append(point)
+
+     # tibia limit point in the right half
+     pivot = np.array([h, 0])
+     coords = np.argwhere(right == 2)
+     distances = ((coords - pivot) ** 2).sum(axis=-1)
+     point = coords[distances.argmax()][::-1]
+     res.append(point)
+
+     if verbose:
+         cv2_imshow(draw_points(127 * mask, res))
+
+     return res
+
+
+ def find_boundaries(mask, start, end, top=True, verbose=0):
+     # if top is True, trace the topmost boundary from left to right
+     # if top is False, trace the bottommost boundary from left to right
+     '''
+     input:
+         mask (H, W) of the femur or tibia
+         start, end: limit points used to extract the upper_contour/lower_contour from the full femur/tibia contour
+     output: upper_contour/lower_contour
+     use top=True when extracting the lower contour from the tibia mask
+     '''
+     contours, _ = cv2.findContours(255 * mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+
+     areas = np.array([cv2.contourArea(cnt) for cnt in contours])
+     contour = contours[areas.argmax()]
+     contour = contour.reshape(-1, 2)
+     org_contour = contour.copy()
+
+     start_idx = ((start - contour) ** 2).sum(axis=-1).argmin()
+     end_idx = ((end - contour) ** 2).sum(axis=-1).argmin()
+     if start_idx <= end_idx:
+         contour = contour[start_idx:end_idx + 1]
+     else:
+         contour = np.concatenate([contour[start_idx:], contour[:end_idx + 1]])
+
+     if top:
+         sorted_indices = np.argsort(contour[:, 1])[::-1]
+     else:
+         sorted_indices = np.argsort(contour[:, 1])
+     contour = contour[sorted_indices]
+
+     # keep one point per x column, then restore left-to-right order
+     unique_indices = sorted(np.unique(contour[:, 0], return_index=True)[1])
+     contour = contour[unique_indices]
+     sorted_indices = np.argsort(contour[:, 0])
+     contour = contour[sorted_indices]
+     if verbose:
+         temp = draw_points(127 * mask.astype(np.uint8), contour, thickness=5)
+         temp = draw_points(temp, [start, end], color=[155, 155], thickness=15)
+         # cv2_imshow(temp)
+
+     return np.array(contour), np.array(org_contour)
+
+
+ def get_contours(mask, verbose=0):
+     '''
+     input: mask (H, W) with femur as 1 and tibia as 2
+     output: upper_contour, lower_contour
+     '''
+     limit_points = detect_limit_points(mask, verbose=verbose)
+     upper_contour, full_upper = find_boundaries(mask == 1, limit_points[0], limit_points[1], top=False, verbose=verbose)
+     lower_contour, full_lower = find_boundaries(mask == 2, limit_points[3], limit_points[2], top=True, verbose=verbose)
+     if verbose:
+         temp = draw_points(127 * mask, full_upper, thickness=3, color=(255, 0, 0))
+         temp = draw_points(temp, full_lower, thickness=3)
+         # cv2_imshow(temp)
+         # cv2.imwrite('full.png', temp)
+         temp = draw_points(temp, limit_points, thickness=7, color=(0, 0, 255))
+         # cv2_imshow(temp)
+         # cv2.imwrite('limit_points.png', temp)
+     if verbose:
+         temp = draw_points(127 * mask, upper_contour, thickness=3, color=(255, 0, 0))
+         temp = draw_points(temp, lower_contour, thickness=3)
+         # cv2_imshow(temp)
+         # cv2.imwrite('cropped.png', temp)
+
+     return upper_contour, lower_contour
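A toy run of get_contours on a synthetic two-band mask (femur block above, tibia block below); the shapes and coordinates are invented purely for illustration:

import numpy as np

mask = np.zeros((200, 200), dtype=np.uint8)
mask[40:90, 30:170] = 1     # femur band
mask[120:170, 30:170] = 2   # tibia band
upper, lower = get_contours(mask)
# upper and lower are the joint-facing edges of the femur and tibia,
# each an (N, 2) array of (x, y) points with unique x values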
jsw/distance.py ADDED
@@ -0,0 +1,98 @@
+ from sklearn.linear_model import LinearRegression
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+
+ from .utils import *
+
+
+ def distance(mask, upper_contour, lower_contour, p=0.12, verbose=0):
+     '''
+     input:
+         mask (H, W) with femur as 1 and tibia as 2
+         upper_contour, lower_contour
+     output: left and right distance vectors, links (per-side coordinate pairs), and the cropped contours
+     note: p is currently unused here
+     '''
+     left, right = getMiddle(mask, lower_contour, verbose=verbose)
+     x_min = max(lower_contour[0, 0], upper_contour[0, 0])
+     x_max = min(lower_contour[-1, 0], upper_contour[-1, 0])
+
+     left_lower_contour = lower_contour[(lower_contour[:, 0] <= left) & (lower_contour[:, 0] >= x_min)]
+     right_lower_contour = lower_contour[(lower_contour[:, 0] >= right) & (lower_contour[:, 0] <= x_max)][::-1]
+
+     left_upper_contour = upper_contour[(upper_contour[:, 0] <= left) & (upper_contour[:, 0] >= x_min)]
+     right_upper_contour = upper_contour[(upper_contour[:, 0] >= right) & (upper_contour[:, 0] <= x_max)][::-1]
+
+     if verbose == 1:
+         temp = draw_points(mask * 127, left_lower_contour, color=(0, 255, 0), thickness=5)
+         temp = draw_points(temp, right_lower_contour, color=(0, 255, 0), thickness=5)
+         temp = draw_points(temp, left_upper_contour, color=(255, 0, 0), thickness=5)
+         temp = draw_points(temp, right_upper_contour, color=(255, 0, 0), thickness=5)
+         # cv2_imshow(temp)
+         # cv2.imwrite('center_cropped.png', temp)
+     links = list(zip(left_upper_contour, left_lower_contour)), list(zip(right_upper_contour, right_lower_contour))
+
+     temp = left_upper_contour, right_upper_contour, left_lower_contour, right_lower_contour
+
+     return left_lower_contour[:, 1] - left_upper_contour[:, 1], right_lower_contour[:, 1] - right_upper_contour[:, 1], links, temp
+
+
+ def getMiddle(mask, contour, verbose=0):
+     '''
+     use Linear Regression to construct a straight line used to exclude the central bone spine segment from upper_contour and lower_contour
+     input:
+         mask (H, W) with femur as 1 and tibia as 2
+         lower_contour
+     output: 2 points (left and right) defining the boundary within which every point of the bone spine segment lies
+     '''
+     X = contour[:, 0].reshape(-1, 1)
+     y = contour[:, 1]
+     reg = LinearRegression().fit(X, y)
+     i_min = np.argmin(y[int(len(y) * 0.2):int(len(y) * 0.8)]) + int(len(y) * 0.2)
+     left = i_min - 1
+     right = i_min + 1
+     left_check = False
+     right_check = False
+     if verbose == 1:
+         cmask = draw_points(mask, contour, thickness=3, color=(255, 0, 0))
+         cmask = draw_points(cmask, np.hstack([X, reg.predict(X).reshape(-1, 1).astype('int')]), thickness=5, color=(0, 255, 0))
+         # cv2.imwrite("lr_mask.png", cmask)
+         # cv2_imshow(cmask)
+
+     while True:
+         # walk outward until the contour rises above the fitted line on each side
+         while not left_check:
+             if y[left] > reg.predict(X[left].reshape(-1, 1)):
+                 break
+             left -= 1
+         while not right_check:
+             if y[right] > reg.predict(X[right].reshape(-1, 1)):
+                 break
+             right += 1
+         if verbose == 1:
+             cmask = draw_points(cmask, [contour[left]], thickness=10, color=(255, 255, 0))
+             cmask = draw_points(cmask, [contour[right]], thickness=10, color=(0, 255, 255))
+             # cv2.imwrite("lr.png", cmask)
+             # cv2_imshow(cmask)
+
+         left_min = np.argmin(y[int(len(y) * 0.2):left]) + int(len(y) * 0.2) if int(len(y) * 0.2) < left else left
+         right_min = np.argmin(y[right:int(len(y) * 0.8)]) + right if right < int(len(y) * 0.8) else right
+         if y[left_min] > reg.predict(X[left_min].reshape(-1, 1)):
+             left_check = True
+         if y[right_min] > reg.predict(X[right_min].reshape(-1, 1)):
+             right_check = True
+         if right_check and left_check:
+             break
+         left = left_min - 1
+         right = right_min + 1
+     return min(X.flatten()[left], X.flatten()[right]), max(X.flatten()[left], X.flatten()[right])
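Conceptually, distance() reduces to a per-column subtraction once getMiddle() has cut out the central spine region: for each shared x, the joint-space width is y_lower - y_upper. A hand-worked sketch with invented contour fragments:

import numpy as np

upper = np.array([[10, 40], [11, 41], [12, 41]])   # (x, y) along the femur edge
lower = np.array([[10, 55], [11, 54], [12, 56]])   # (x, y) along the tibia edge
widths = lower[:, 1] - upper[:, 1]                 # -> array([15, 13, 15]) pixels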
jsw/utils.py ADDED
@@ -0,0 +1,100 @@
+ import random
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ def draw_points(image, points, color=None, random_color=False, same=True, thickness=1):
+     if color is None and not random_color:
+         color = (0, 255, 0)  # default color: green (BGR)
+     if random_color:
+         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+     image = to_color(image)
+
+     for point in points:
+         if random_color and not same:
+             color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+         x, y = point
+         image = cv2.circle(image, (x, y), thickness, color, -1)  # draw the point on the image
+     return image
+
+
+ def draw_lines(image, pairs, color=None, random_color=False, same=True, thickness=1):
+     image_with_line = to_color(np.copy(image))
+
+     if color is None and not random_color:
+         color = (255, 0, 0)  # default color: blue (BGR)
+     if random_color:
+         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+     # draw a line segment for each pair of points
+     for pair in pairs:
+         if random_color and not same:
+             color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+         start_point = pair[0]
+         end_point = pair[1]
+         image_with_line = cv2.line(image_with_line, start_point, end_point, color, thickness)
+         image_with_line = cv2.circle(image_with_line, start_point, thickness + 1, color, -1)
+         image_with_line = cv2.circle(image_with_line, end_point, thickness + 1, color, -1)
+
+     return image_with_line
+
+
+ def center(contour):
+     idx = len(contour) // 2
+     return contour[idx]
+
+
+ def to_color(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return image
+     return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+
+ def to_gray(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     return image
+
+
+ def cv2_imshow(images):
+     if not isinstance(images, list):
+         images = [images]
+
+     num_images = len(images)
+
+     # display a single image directly with imshow
+     if num_images == 1:
+         image = images[0]
+         if len(image.shape) == 3 and image.shape[2] == 3:
+             # color image: convert BGR to RGB for matplotlib
+             image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             plt.imshow(image_rgb)
+         else:
+             # grayscale image
+             plt.imshow(image, cmap='gray')
+
+         plt.axis("off")
+         plt.show()
+     else:
+         # display multiple images stacked in a single column
+         fig, ax = plt.subplots(num_images, 1, figsize=(4, 4 * num_images))
+
+         for i in range(num_images):
+             image = images[i]
+             if len(image.shape) == 3 and image.shape[2] == 3:
+                 # color image: convert BGR to RGB for matplotlib
+                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                 ax[i].imshow(image_rgb)
+             else:
+                 # grayscale image
+                 ax[i].imshow(image, cmap='gray')
+
+             ax[i].axis("off")
+
+         plt.tight_layout()
+         plt.show()
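A small sketch of the display helper above (the arrays are invented; works wherever matplotlib can open a window or render inline):

import numpy as np
import cv2

gray = np.random.randint(0, 256, (64, 64), dtype=np.uint8)   # grayscale image
bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)                 # 3-channel BGR image
cv2_imshow(gray)          # single image, shown directly
cv2_imshow([gray, bgr])   # list of images, stacked in one column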
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ streamlit==1.32.2
+ opencv-python==4.9.0.80
+ ultralytics
+ numpy==1.26.4
+ torch==2.2.2
+ scikit-learn==1.2.2
+
+ keras_cv~=0.5.0
+ tqdm==4.66.2
+ keras-core==0.1.7
+ tensorflow==2.14
+ tensorflow-addons==0.22.0
+ scipy==1.12.0
+ matplotlib==3.8.3
+ scikit-image==0.22.0
+ Pillow==10.2.0
+
+ xgboost==2.1.2
+ lime==0.2.0.1
segmentation/__init__.py ADDED
File without changes
segmentation/model.py ADDED
@@ -0,0 +1,37 @@
+ from ultralytics import YOLO
+ import cv2
+ import torch
+
+ BEST_WEIGHT = './segmentation/weights/oai_s_best4.pt'
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def create_model(weight_path=None):
+     if weight_path:
+         return YOLO(weight_path)
+     else:
+         return YOLO(BEST_WEIGHT)
+
+
+ class Segmenter():
+     def __init__(self, weight_path=None):
+         self.model = create_model(weight_path).to(DEVICE)
+
+     def segment(self, img):
+         """
+         input: image (H, W, C)
+         output: mask (H, W) with femur as 1 and tibia as 2
+         """
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # boost bone/soft-tissue contrast: global histogram equalization, then CLAHE
+         eimg = cv2.equalizeHist(img)
+         clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+         eimg = clahe.apply(eimg)
+
+         eimg = cv2.cvtColor(eimg, cv2.COLOR_GRAY2RGB)
+
+         res = self.model(eimg, verbose=False)
+
+         # merge the two instance masks; class id + 1 becomes the label (femur = 1, tibia = 2)
+         mask = res[0].masks.data[0] * (res[0].boxes.cls[0] + 1) + res[0].masks.data[1] * (res[0].boxes.cls[1] + 1)
+         mask = mask.cpu().numpy()
+
+         return mask
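A minimal usage sketch for Segmenter; the file name is hypothetical:

import cv2

seg = Segmenter()                     # loads BEST_WEIGHT onto DEVICE
img = cv2.imread("knee_xray.png")     # hypothetical (H, W, 3) radiograph
mask = seg.segment(img)               # (H, W) array with femur == 1, tibia == 2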
segmentation/weights/oai_s_best4.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0159a91db899213c34db9cc93db1b5a31576eb3b78eb61bd1d9a29a4ef92e843
+ size 6771320
utils.py ADDED
@@ -0,0 +1,79 @@
+ import cv2
+ import numpy as np
+ # from shared_queue import image_queue
+
+
+ def read_image(file_bytes):
+     np_img = np.frombuffer(file_bytes, np.uint8)
+     img = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
+     return img
+
+
+ def combine_mask(image, mask, label2color={1: (255, 255, 0), 2: (0, 255, 255)}, alpha=0.1):
+     image = to_color(image)
+     # cv2.resize expects (width, height), while mask.shape is (height, width)
+     image = cv2.resize(image, (mask.shape[1], mask.shape[0]))
+     mask_image = np.zeros_like(image)
+     for label, color in label2color.items():
+         mask_image[mask == label] = color
+
+     mask_image = cv2.addWeighted(image, 1 - alpha, mask_image, alpha, 0)
+     return mask_image
+
+
+ def to_color(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return image
+     return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+
+ def to_gray(image):
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     return image
+
+
+ def prob_process(prob_list):
+     # fold the smallest class probability into the largest, then map the winning
+     # class index (0, 1, 2) onto a single increasing scalar score
+     prob = prob_list[0]
+     max_idx = np.argmax(prob)
+     min_idx = np.argmin(prob)
+     prob[max_idx] += prob[min_idx]
+     prob[min_idx] = 0
+
+     if max_idx == 0:
+         return 1 - prob[max_idx]
+     elif max_idx == 1:
+         return prob[max_idx]
+     else:
+         return 1 + prob[max_idx]
+
+
+ def combine_prob_jsw(prob, jsw_m, jsw_mm):
+     processed_prob = prob_process(prob)
+     temp = np.array([processed_prob, jsw_m, jsw_mm])
+     if temp.any():
+         temp = np.hstack(temp)
+     else:
+         temp = temp.flatten()
+
+     return temp
+
+
+ def scale_coordinates(links, original_size, mask_size):
+     scale_factor = original_size / mask_size
+     scaled_links = [(int(x * scale_factor), int(y * scale_factor)) for x, y in links]
+     return scaled_links
+
+
+ def get_annotations(probabilities):
+     # osteophyte scores sit at indices 5:8 and JSN scores at 8:12 of the combined vector
+     OSTEOPHYTE_LEVELS = {
+         0: "Definite osteophytes",
+         1: "No osteophytes",
+         2: "Possible osteophytes",
+     }
+
+     JSN_LEVELS = {
+         0: "Definite JSN",
+         1: "Mild JSN",
+         2: "No JSN",
+         3: "Severe JSN",
+     }
+
+     return {
+         "osteophyte": OSTEOPHYTE_LEVELS[np.argmax(probabilities[5:8])],
+         "jsn": JSN_LEVELS[np.argmax(probabilities[8:12])],
+     }
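Taken together, these helpers suggest the app's end-to-end flow; a hedged sketch (the file name, alpha, and dim values are illustrative, and the Streamlit wiring itself is outside this commit):

import numpy as np

from segmentation.model import Segmenter
from jsw import get_JSW, calculate_diff

with open("knee_xray.png", "rb") as f:            # hypothetical uploaded file
    img = read_image(f.read())

mask = Segmenter().segment(img)                   # femur == 1, tibia == 2
overlay = combine_mask(img, mask.astype(np.uint8), alpha=0.4)

left, right, links = get_JSW(mask, dim=10, pool="min")
jsw_m, jsw_mm = calculate_diff(left, right)       # inputs for combine_prob_jsw(...)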