Nguyễn Thành Đạt
committed on
Commit
·
036e7c4
0
Parent(s):
update code
Browse files- .gitattributes +35 -0
- .gitignore +4 -0
- .python-version +1 -0
- ML_model/__init__.py +0 -0
- ML_model/data/X_train.pkl +3 -0
- ML_model/data/clinical-2.csv +0 -0
- ML_model/data/selected_features.pkl +3 -0
- ML_model/model.py +66 -0
- ML_model/weights/xgboost_convnet_best.pkl +3 -0
- README.md +12 -0
- __init__.py +0 -0
- app.py +192 -0
- classification/__init__.py +0 -0
- classification/model.py +77 -0
- classification/utils.py +6 -0
- classification/weights/convnet.pt +3 -0
- gae/__init__.py +1 -0
- gae/anomaly_extraction.py +107 -0
- gae/contour.py +64 -0
- gae/model.py +449 -0
- gae/modules.py +241 -0
- gae/utils.py +180 -0
- gae/weights/gan_efficientunet_full_augment-hist_equal_generator.h5 +3 -0
- jsw/__init__.py +1 -0
- jsw/cal_jsw.py +117 -0
- jsw/contour.py +127 -0
- jsw/distance.py +98 -0
- jsw/utils.py +100 -0
- requirements.txt +19 -0
- segmentation/__init__.py +0 -0
- segmentation/model.py +37 -0
- segmentation/weights/oai_s_best4.pt +3 -0
- utils.py +79 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore Python cache files
|
| 2 |
+
*.pyc
|
| 3 |
+
__pycache__/
|
| 4 |
+
venv/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11.0
|
ML_model/__init__.py
ADDED
|
File without changes
|
ML_model/data/X_train.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8104f7e7cbc69e960b07138448d538edc2d6856d02862a4182466da58ec07ab7
|
| 3 |
+
size 2775817
|
ML_model/data/clinical-2.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ML_model/data/selected_features.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8593b584c411a5f65c5a5265e987f077d54f8e17fbe7ae9b653a69d93b83955b
|
| 3 |
+
size 344
|
ML_model/model.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from sklearn.ensemble import RandomForestClassifier
|
| 2 |
+
import numpy as np
|
| 3 |
+
import joblib
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 6 |
+
from lime import lime_tabular
|
| 7 |
+
|
| 8 |
+
class MLModel():
    """XGBoost-based classifier over CNN probabilities + OAI clinical features + JSW stats.

    Loads the trained model, the preprocessed clinical table, the selected-feature
    list, and a LIME explainer fitted on the selected training features.
    """

    def __init__(self):
        # Trained classifier (pickled sklearn-compatible model).
        self.model = joblib.load('./ML_model/weights/xgboost_convnet_best.pkl')

        # Clinical table; ID/SIDE are identifiers, not features.
        self.clinical_data = pd.read_csv("./ML_model/data/clinical-2.csv")
        del self.clinical_data['ID']
        del self.clinical_data['SIDE']

        # Min-max scale the continuous columns.
        self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
        self.scaler = MinMaxScaler()
        self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(
            self.clinical_data[self.columns_to_scale])

        # Integer-encode categorical columns; keep the code->label mapping of each.
        self.columns_to_convert = ['FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA', 'SWELLING',
                                   'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS']
        self.mapping_dict = {}
        for column in self.columns_to_convert:
            self.clinical_data[column], unique_values = pd.factorize(self.clinical_data[column])
            self.mapping_dict[column] = unique_values

        # Feature-selection artifacts produced at training time.
        self.selected_features = joblib.load("./ML_model/data/selected_features.pkl")
        self.X_train = joblib.load("./ML_model/data/X_train.pkl")
        self.selected_feature_index = [self.X_train.columns.get_loc(col) for col in self.selected_features]

        # LIME explainer over the selected training features (5 output grades).
        self.explainer = lime_tabular.LimeTabularExplainer(
            self.X_train[self.selected_features].to_numpy(),
            feature_names=self.selected_features,
            class_names=['0', '1', '2', '3', '4'],
            mode='classification'
        )

    def get_clinical_data(self, filename):
        """Return the clinical row(s) matching *filename* as a numpy array.

        Note: the FILENAME column is still present as column 0; callers slice it off.
        """
        row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
        return row.to_numpy()

    def _build_feature_vector(self, overal_diagnosis, jsws, clinical, filename):
        """Assemble and validate the raw feature vector shared by predict/predict_explain.

        Raises:
            ValueError: neither clinical data nor a filename was supplied, or the
                filename has no row in the OAI table.
            NotImplementedError: caller-supplied clinical data is not supported yet.
        """
        # BUG FIX: the original used `if not clinical`, which raises on a numpy
        # array (ambiguous truthiness); test identity against None instead.
        if clinical is None:
            if filename is None:
                # BUG FIX: `raise "..."` raises a str, which is itself a TypeError
                # in Python 3; raise a real exception instead.
                raise ValueError("Need clinical data or filename in OAI database")
            clinical = self.get_clinical_data(filename)
            if clinical.shape[0] == 0:
                # Was an `assert`, which is stripped under `python -O`.
                raise ValueError("clinical data not found")
        else:
            raise NotImplementedError("passing clinical data directly is not implemented yet")

        # clinical[0, 1:] drops the FILENAME column.
        return np.concatenate((overal_diagnosis, clinical[0, 1:], np.array(jsws)))

    def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Predict the class for one sample built from CNN probs + clinical row + JSW stats.

        Returns the model's prediction array for a single-row input.
        """
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical, filename)
        # BUG FIX: pass the feature matrix positionally — the sklearn/xgboost
        # predict parameter is named `X`, so the old `predict(x=...)` keyword
        # call would fail. reshape(1, -1) makes it a single-row batch.
        return self.model.predict(x[self.selected_feature_index].reshape(1, -1))

    def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Predict and return ``(label, lime_explanation)`` for one sample."""
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical, filename)[self.selected_feature_index]
        exp = self.explainer.explain_instance(
            x, self.model.predict_proba,
            num_features=len(self.selected_features), top_labels=1)
        return self.model.predict(x.reshape(1, -1))[0], exp
|
ML_model/weights/xgboost_convnet_best.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ee491112dc64c61cc66b9922262c1a5920db5270dc6cf1220355a737b1c30c0
|
| 3 |
+
size 286017
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Demo XDesco
|
| 3 |
+
emoji: 💻
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.32.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
__init__.py
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit demo app: knee x-ray segmentation, JSW measurement, classification,
anomaly mapping, and LIME explanation of the final grade prediction."""
import streamlit as st
import cv2
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from PIL import Image
import numpy as np
import os
import streamlit.components.v1 as components
import pandas as pd

from utils import read_image, combine_mask, combine_prob_jsw, scale_coordinates, get_annotations
from segmentation.model import Segmenter
from jsw import get_JSW, calculate_diff, calculate_jsw_info
from classification.model import Classifier
from ML_model.model import MLModel
from gae import AnomalyExtractor

import os  # NOTE(review): duplicate import, harmless; kept for byte-compatibility
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow info/warning logs

st.set_page_config(layout="wide")

@st.cache_resource
def load_models():
    """Load all heavy models once per process (cached by streamlit)."""
    seg_model = Segmenter()
    classif_model = Classifier('convnet')
    ml_model = MLModel()
    anomaly_extractor = AnomalyExtractor()

    return seg_model, classif_model, ml_model, anomaly_extractor


seg_model, classif_model, ml_model, anomaly_extractor = load_models()

st.sidebar.title("Demo xDesco System")

uploaded_file = st.sidebar.file_uploader("Choose a knee join x-ray image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    file_bytes = uploaded_file.read()
    file_name = uploaded_file.name
    img = read_image(file_bytes)

    # Bone segmentation mask for the joint.
    mask = seg_model.segment(img)

    mask_image = combine_mask(img, mask)
    col1, col2 = st.columns(2)

    # Joint-space-width distances on both sides plus the measured point pairs.
    left_distances, right_distances, links = get_JSW(mask, dim = 10, verbose = 0)
    left_links, right_links = links
    jsw_m, jsw_mm = calculate_diff(left_distances, right_distances)
    diff_percentage, mean_left, mean_right, side_min, index_min, value_min = calculate_jsw_info(left_distances, right_distances)

    # Rescale link coordinates from the 640px mask space to the 224px display space.
    scaled_left_links = scale_coordinates([coord for pair in left_links for coord in pair], original_size=224, mask_size=640)
    left_pairs = list(zip(scaled_left_links[::2], scaled_left_links[1::2]))
    scaled_right_links = scale_coordinates([coord for pair in right_links for coord in pair], original_size=224, mask_size=640)
    right_pairs = list(zip(scaled_right_links[::2], scaled_right_links[1::2]))

    # Multi-label CNN probabilities, mapped to textual annotations.
    probabilites = classif_model.predict(img)
    annotations = get_annotations(probabilites[0])

    # Final grade prediction + LIME explanation from the tabular model.
    predicted, exp = ml_model.predict_explain(probabilites[0], [jsw_m, jsw_mm], filename=file_name)

    # GAN-based reconstruction anomaly map over the joint region.
    processed_img, anomaly = anomaly_extractor.extract(mask, img, verbose=0)

    def plot_anomaly_with_clues(processed_img, anomaly, diff_percentage, mean_left, mean_right, side_min, index_min, value_min,
                                left_pairs, right_pairs, color='r', thickness=1):
        """Overlay the anomaly map on the image and annotate JSW clues.

        Also reads `annotations`, `jsw_m`, `jsw_mm` from the enclosing scope.
        Returns the matplotlib figure.
        """
        # Create a matplotlib figure and axes.
        fig, ax = plt.subplots()
        ax.imshow(processed_img, cmap = 'gray')
        ax.imshow(anomaly, cmap='turbo', alpha = 0.3)

        # Image midpoints (currently only used by disabled annotations).
        mid_x = anomaly.shape[1] // 2
        mid_y = anomaly.shape[0] // 10

        # Highlight the narrowest joint-space link when the sides differ or the
        # minimum distance is less than half the side's mean.
        if (diff_percentage > 0) or ((mean_left / value_min > 2) if side_min == 0 else (mean_right / value_min > 2)):
            min_pairs = left_pairs if side_min == 0 else right_pairs
            min_pair = min_pairs[index_min]
            start_point = min_pair[0]
            end_point = min_pair[1]

            # Bounding box around the narrowest vertical link, with padding.
            padding_x = 23  # horizontal padding
            padding_y = 13  # vertical padding
            min_x = start_point[0] - padding_x
            max_x = start_point[0] + padding_x
            min_y = min(start_point[1], end_point[1]) - padding_y
            max_y = max(start_point[1], end_point[1]) + padding_y

            rect = plt.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y, linewidth=3, edgecolor='red', facecolor='none')
            ax.add_patch(rect)

        # Annotate osteophyte & joint-space-narrowing labels.
        ax.text(anomaly.shape[1] // 2, anomaly.shape[0] // 3.2, f'{annotations["osteophyte"]}', color='yellow', fontsize=19, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=1.5, foreground='black'), path_effects.Normal()])
        ax.text(anomaly.shape[1] // 2, anomaly.shape[0] // 1.3, f'{annotations["jsn"]}', color='yellow', fontsize=19, ha='center', va='center', path_effects=[path_effects.Stroke(linewidth=1.5, foreground='black'), path_effects.Normal()])

        # Annotate the JSW mean / min-max summary values.
        ax.text(anomaly.shape[1] // 40, anomaly.shape[0] // 1.05, f'$JSW_{{Mean}}$: {jsw_m:.2f}', color='white', fontsize=10, ha='left', va='center', path_effects=[path_effects.Stroke(linewidth=1, foreground='black'), path_effects.Normal()])
        ax.text(anomaly.shape[1] // 40, anomaly.shape[0] // 1.1, f'$JSW_{{MM}}$: {jsw_mm:.2f}', color='white', fontsize=10, ha='left', va='center', path_effects=[path_effects.Stroke(linewidth=1, foreground='black'), path_effects.Normal()])

        ax.axis('off')
        return fig

    with col1:
        caption = "Uploaded Image"
        st.markdown(
            f"<h3 style='text-align: center;'>{caption}</h3>",
            unsafe_allow_html=True
        )
        st.image(img, channels="BGR", use_column_width=True)
    with col2:
        fig = plot_anomaly_with_clues(
            processed_img,
            anomaly,
            diff_percentage,
            mean_left,
            mean_right,
            side_min, index_min,
            value_min,
            left_pairs,
            right_pairs,
            color='r',
            thickness=1
        )
        caption = "Annotated Anomaly Map"
        st.markdown(
            f"<h3 style='text-align: center;'>{caption}</h3>",
            unsafe_allow_html=True
        )
        st.pyplot(fig, bbox_inches='tight', pad_inches=0)


    # Render the LIME explanation as embedded HTML.
    exp_html = exp.as_html()
    components.html(exp_html, height=1400, width=None)
|
| 192 |
+
|
classification/__init__.py
ADDED
|
File without changes
|
classification/model.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
import cv2
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from .utils import to_rgb
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_effNetv2s():
    """EfficientNetV2-S backbone with a NUM_CLASSES sigmoid multi-label head."""
    net = torchvision.models.efficientnet_v2_s(weights='IMAGENET1K_V1')
    head_in = net.classifier[1].in_features
    net.classifier[1] = nn.Sequential(
        nn.Linear(head_in, NUM_CLASSES),
        nn.Sigmoid(),
    )
    return net


def create_convnet():
    """ConvNeXt-Base backbone with a NUM_CLASSES sigmoid multi-label head."""
    net = torchvision.models.convnext_base(weights='IMAGENET1K_V1')
    head_in = net.classifier[2].in_features
    net.classifier[2] = nn.Sequential(
        nn.Linear(head_in, NUM_CLASSES),
        nn.Sigmoid(),
    )
    return net


def create_model(model_name):
    """Build the architecture named *model_name*, load its checkpoint on CPU, move to DEVICE."""
    net = _MODEL[model_name]()
    state = torch.load(_WEIGHT[model_name], map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.to(DEVICE)
    return net


def create_transform():
    """Preprocessing pipeline: resize -> tensor -> per-channel normalize."""
    return transforms.Compose([
        transforms.Resize((HEIGHT, WEIGHT)),
        transforms.ToTensor(),
        transforms.Normalize((0.6078, 0.6078, 0.6078), (0.1932, 0.1932, 0.1932)),
    ])


# Registry of supported architectures and their checkpoint paths.
_MODEL = {
    "effNetv2s": create_effNetv2s,
    "convnet": create_convnet,
}

_WEIGHT = {
    "effNetv2s": './classification/weights/effnetv2s.pt',
    "convnet": './classification/weights/convnet.pt',
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HEIGHT = 224
WEIGHT = 224  # image width (name kept as-is for backward compatibility)
NUM_CLASSES = 44


class Classifier():
    """Multi-label x-ray classifier wrapping a torchvision backbone."""

    def __init__(self, model_name="effNetv2s"):
        self.model = create_model(model_name)
        self.transform = create_transform()

    def predict(self, image):
        '''
        input: cv2 image
        output: multi-label probability vector
        '''
        rgb = Image.fromarray(to_rgb(image))
        batch = self.transform(rgb).unsqueeze(0).to(DEVICE)
        self.model.eval()
        with torch.no_grad():
            return self.model(batch).cpu().numpy()
|
classification/utils.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
def to_rgb(image):
    """Convert a cv2 image (BGR 3-channel or single-channel grayscale) to RGB."""
    is_bgr = len(image.shape) == 3 and image.shape[-1] == 3
    conversion = cv2.COLOR_BGR2RGB if is_bgr else cv2.COLOR_GRAY2RGB
    return cv2.cvtColor(image, conversion)
|
classification/weights/convnet.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe61c8220bbd41af3e536451985c1c49cc0d8b438de2da76f7fdbba8f6051741
|
| 3 |
+
size 350549706
|
gae/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .anomaly_extraction import AnomalyExtractor
|
gae/anomaly_extraction.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from skimage.filters import gaussian
|
| 6 |
+
|
| 7 |
+
from .model import DCGAN
|
| 8 |
+
from .modules import DropBlockNoise
|
| 9 |
+
from .utils import *
|
| 10 |
+
from .contour import get_contours_v2
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
IMAGE_SIZE = 224
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AnomalyExtractor():
    """Extract an anomaly heat map from a knee x-ray via GAN reconstruction.

    The joint-space region (derived from the segmentation mask contours) is
    blanked out, a DCGAN generator reconstructs it, and the absolute
    reconstruction error inside that region is returned as the anomaly map.
    """

    def __init__(self):
        print("================================load dcgan=================================\n")
        # Generator architecture must match the saved weights below.
        dcgan = DCGAN(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                      architecture='two-stage',
                      output_activation='sigmoid',
                      noise=DropBlockNoise(rate=0.1, block_size=16),
                      pretrain_weights=None,
                      block_type='pixel-shuffle',
                      kernel_initializer='glorot_uniform',
                      C=1.)

        # Only the generator is used (frozen) for inpainting-style restoration.
        self.restore_model = dcgan.generator
        self.restore_model.load_weights("./gae/weights/gan_efficientunet_full_augment-hist_equal_generator.h5")
        self.restore_model.trainable = False

    def extract(self, mask, img, verbose=0):
        """Return ``(processed_img, anomaly_map)`` for one image.

        Args:
            mask: segmentation label mask in 640x640 space (see get_contours_v2).
            img: input x-ray image (BGR or grayscale).
            verbose: when truthy, render intermediate images via streamlit.
        """
        # Grayscale, resize to the generator's 224x224 input, then preprocess
        # (presumably normalization / histogram equalization — see gae.utils).
        img = to_gray(img)
        img = np.array(Image.fromarray(img).resize((224, 224)))
        img = preprocess(img)

        if verbose:
            # Show the original image.
            show_image(img, title="Original image")
            plt.axis('off')
            st.pyplot()

            # Show the segmentation mask.
            show_image(mask, title="Mask")
            plt.axis('off')
            st.pyplot()

        # Upper/lower joint contours from the segmentation mask.
        uc, lc = get_contours_v2(mask, verbose=0)

        # Rasterize both contours into a 640x640 binary image, then downscale
        # to the 224x224 model space and normalize to [0, 1].
        mask = np.zeros((640, 640)).astype('uint8')
        mask = draw_points(mask, lc, thickness=1, color=(255, 255, 255))
        mask = draw_points(mask, uc, thickness=1, color=(255, 255, 255))
        mask = cv2.resize(mask, (224, 224), cv2.INTER_NEAREST)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        mask = mask / 255.

        if verbose:
            show_image(mask, title = "Contour")
            plt.axis('off')
            st.pyplot()

        # Invert so the contour band is 0, then dilate + blur to get a soft
        # region-of-interest weighting around the joint space.
        mask = 1 - mask
        dila = dilate(mask)
        dilated = gaussian(dilate(mask), sigma=50, truncate=0.3)

        if verbose:
            show_image(dila, title="dilated Image")
            plt.axis('off')
            st.pyplot()

            show_image(dilated, title="blur Image")
            plt.axis('off')
            st.pyplot()

        # Blank the joint region and let the generator reconstruct it.
        im = np.expand_dims(img * (1 - dilated), axis=0)
        im = tf.convert_to_tensor(im, dtype=tf.float32)

        restored_img = self.restore_model(im)

        # Drop the batch and channel axes -> (224, 224).
        res = tf.squeeze(tf.squeeze(restored_img, axis=-1), axis=0)

        if verbose:
            show_image(im[0], title="Masked Image")
            plt.axis('off')
            st.pyplot()

            show_image(res, title="Reconstructed image")
            plt.axis('off')
            st.pyplot()

            show_image(dilated*tf.abs(img - res), title="Anomaly map", cmap_type='turbo')
            plt.axis('off')
            st.pyplot()

        # Anomaly = reconstruction error weighted by the soft joint-region mask.
        return img, dilated*tf.abs(img - res)
|
gae/contour.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _draw_points(image, points, color=(255, 0, 0), thickness=3):
    """Debug helper: draw each (x, y) point as a filled circle on a BGR copy.

    Replaces the previously-missing ``draw_points`` dependency that the
    verbose paths below referenced without importing (NameError).
    """
    canvas = np.asarray(image).copy()
    if canvas.ndim == 2:
        canvas = cv2.cvtColor(canvas.astype(np.uint8), cv2.COLOR_GRAY2BGR)
    # Pad/truncate the color to 3 channels (some call sites passed 2 values).
    bgr = tuple(int(c) for c in (tuple(color) + (0, 0, 0))[:3])
    for x, y in np.asarray(points).reshape(-1, 2):
        cv2.circle(canvas, (int(x), int(y)), int(thickness), bgr, -1)
    return canvas


def find_boundaries_v2(mask, top=True, verbose=0):
    """Extract the joint-facing half of the largest contour in a binary *mask*.

    Args:
        mask: 2-D boolean/0-1 array for one bone region.
        top: keep the top boundary of the region when True, else the bottom.
        verbose: draw debug output when truthy.

    Returns:
        ``(cropped_contour, full_contour)`` as (N, 2) integer arrays of (x, y).
    """
    contours, _ = cv2.findContours(255 * mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    # Keep only the largest connected contour (the bone outline).
    areas = np.array([cv2.contourArea(cnt) for cnt in contours])
    contour = contours[areas.argmax()]
    contour = contour.reshape(-1, 2)
    org_contour = contour.copy()

    # Split the closed contour at its vertical midline; the direction of
    # traversal at that height decides which index is the start.
    pos = (contour[:, 1].max() + contour[:, 1].min()) // 2
    idx = np.where(contour[:, 1] == pos)
    if contour[idx[0][0]][0] < contour[idx[0][-1]][0] and not top:
        start_idx = idx[0][0]
        end_idx = idx[0][-1]
    else:
        end_idx = idx[0][0]
        start_idx = idx[0][-1]

    # BUG FIX: capture the split points before cropping — the old verbose
    # branch referenced `start`/`end`, which were only defined in
    # commented-out code (NameError when verbose was enabled).
    start = org_contour[start_idx]
    end = org_contour[end_idx]

    if start_idx <= end_idx:
        contour = contour[start_idx:end_idx + 1]
    else:
        # Wrap around the closed contour's index origin.
        contour = np.concatenate([contour[start_idx:], contour[:end_idx + 1]])

    if verbose:
        # BUG FIX: the old code called undefined `draw_points`/`cv2_imshow`;
        # use the local helper + matplotlib (consistent with get_contours_v2).
        temp = _draw_points(127 * mask.astype(np.uint8), contour, thickness=5)
        temp = _draw_points(temp, [start, end], color=(155, 155, 155), thickness=15)
        plt.imshow(temp)
        plt.axis('off')
        plt.show()

    return np.array(contour), np.array(org_contour)


def get_contours_v2(mask, verbose=0):
    """Return ``(upper_contour, lower_contour)`` of the joint space.

    Mask convention (per the calls below): label 1 = upper bone, label 2 = lower bone.
    """
    upper_contour, full_upper = find_boundaries_v2(mask == 1, top=False, verbose=verbose)
    lower_contour, full_lower = find_boundaries_v2(mask == 2, top=True, verbose=verbose)
    if verbose:
        # Full bone outlines over the label mask.
        temp = _draw_points(127 * mask, full_upper, thickness=3, color=(255, 0, 0))
        temp = _draw_points(temp, full_lower, thickness=3)
        plt.imshow(temp)
        plt.title("Segmentation")
        plt.axis('off')
        plt.show()
    if verbose:
        # Cropped joint-facing halves only.
        temp = _draw_points(127 * mask, upper_contour, thickness=3, color=(255, 0, 0))
        temp = _draw_points(temp, lower_contour, thickness=3)
        plt.imshow(temp)
        plt.axis('off')
        plt.show()

    return upper_contour, lower_contour
|
gae/model.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow.keras.optimizers import Adam
|
| 8 |
+
from tensorflow.keras.models import Sequential, Model
|
| 9 |
+
|
| 10 |
+
from tensorflow.keras import layers
|
| 11 |
+
from tensorflow.keras.applications import EfficientNetV2S
|
| 12 |
+
from tensorflow.keras.layers import (
|
| 13 |
+
Dense, Flatten, Conv2D, Activation, BatchNormalization,
|
| 14 |
+
MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
|
| 15 |
+
Dropout, Input, concatenate, add, Conv2DTranspose, Lambda,
|
| 16 |
+
SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU,
|
| 17 |
+
ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .contour import get_contours_v2
|
| 21 |
+
from .modules import (
|
| 22 |
+
MultipleTrackers, DropBlockNoise, squeeze_excite_block, spatial_squeeze_excite_block,
|
| 23 |
+
channel_spatial_squeeze_excite, DoubleConv, UpSampling2D_block, Conv2DTranspose_block,
|
| 24 |
+
PixelShuffle_block
|
| 25 |
+
)
|
| 26 |
+
from .utils import mae
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
IMAGE_SIZE = 224
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def adjust_pretrained_weights(model_cls, input_size, name=None):
    """Build a single-channel encoder initialized from ImageNet RGB weights.

    Instantiates ``model_cls`` twice — once with ImageNet weights and a
    3-channel input, once uninitialized with a 1-channel input — then
    transfers the pretrained weights, collapsing the RGB stem kernel.

    Args:
        model_cls: a ``tf.keras.applications`` model constructor.
        input_size: spatial size ``(H, W)`` of the input.
        name: optional name assigned to the returned model.

    Returns:
        The 1-channel model carrying the adapted pretrained weights.
    """
    rgb_model = model_cls(weights='imagenet',
                          include_top=False,
                          input_shape=(*input_size, 3))
    gray_model = model_cls(weights=None,
                           include_top=False,
                           input_shape=(*input_size, 1))

    pretrained = rgb_model.get_weights()
    # Collapse the first conv kernel over its RGB input axis so it accepts a
    # single channel while preserving the overall activation scale.
    pretrained[0] = np.sum(pretrained[0], axis=2, keepdims=True)
    gray_model.set_weights(pretrained)

    # Free the throwaway RGB model eagerly — these encoders are large.
    del rgb_model
    tf.keras.backend.clear_session()
    gc.collect()

    if name:
        gray_model._name = name
    return gray_model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_efficient_unet(name=None,
                       option='full',
                       input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                       encoder_weights=None,
                       block_type='conv-transpose',
                       output_activation='sigmoid',
                       kernel_initializer='glorot_uniform'):
    """Assemble a U-Net with an EfficientNetV2-S backbone.

    Args:
        name: model name.
        option: 'encoder' returns the backbone only, 'model' the full U-Net,
            'full' the tuple ``(unet, encoder)``.
        input_shape: input tensor shape ``(H, W, C)``.
        encoder_weights: None or 'imagenet' (adapted to the input channels).
        block_type: decoder upsampling style — 'upsampling',
            'conv-transpose' or 'pixel-shuffle'.
        output_activation: activation of the final 1x1 convolution.
        kernel_initializer: initializer used throughout the decoder.

    Returns:
        Depends on ``option`` (see above).

    Raises:
        ValueError: for an unknown ``encoder_weights``, ``block_type`` or
            ``option`` value.
    """
    if encoder_weights == 'imagenet':
        encoder = adjust_pretrained_weights(EfficientNetV2S, input_shape[:-1], name)
    elif encoder_weights is None:
        encoder = EfficientNetV2S(weights=None,
                                  include_top=False,
                                  input_shape=input_shape)
        encoder._name = name
    else:
        raise ValueError(encoder_weights)

    if option == 'encoder':
        return encoder

    # Skip connections tapped at the end of each EfficientNetV2-S stage,
    # shallowest first.
    skips = [encoder.get_layer(f'block{stage}_add').output
             for stage in ('1b', '2d', '3d', '4f')]
    head = encoder.get_layer('top_activation').output

    decoder_blocks = {'upsampling': UpSampling2D_block,
                      'conv-transpose': Conv2DTranspose_block,
                      'pixel-shuffle': PixelShuffle_block}
    UpBlock = decoder_blocks.get(block_type)
    if UpBlock is None:
        raise ValueError(block_type)

    # Decode from the bottleneck, consuming skips deepest-first.
    x = head
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = UpBlock(filters, initializer=kernel_initializer, skip=skip)(x)
    x = UpBlock(32, initializer=kernel_initializer, skip=None)(x)
    x = Conv2D(input_shape[-1], (1, 1), padding='same', activation=output_activation, kernel_initializer=kernel_initializer)(x)

    model = Model(encoder.input, x, name=name)

    if option == 'full':
        return model, encoder
    elif option == 'model':
        return model
    else:
        raise ValueError(option)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class DCGAN():
    """Reconstruction GAN built on EfficientNet-U-Net components.

    The generator maps a corrupted/masked image back to the original image;
    the discriminator classifies images as real vs. generated. The generator
    is trained on ``C * reconstruction_loss + adversarial_loss``.
    """

    def __init__(self,
                 input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                 architecture='two-stage',
                 pretrain_weights=None,
                 output_activation='sigmoid',
                 block_type='conv-transpose',
                 kernel_initializer='glorot_uniform',
                 noise=None,
                 C=1.):
        """Build generator and discriminator.

        Args:
            input_shape: image shape ``(H, W, C)``.
            architecture: 'two-stage' builds independent encoders for the
                generator and discriminator; 'shared' reuses the generator's
                encoder for the discriminator head.
            pretrain_weights: None or 'imagenet' for the encoder backbone.
            output_activation: generator output activation. 'tanh' wraps both
                networks with [0,1] <-> [-1,1] (de)normalization lambdas.
            block_type: decoder block style (see ``get_efficient_unet``).
            kernel_initializer: initializer for decoder layers.
            noise: optional corruption layer applied to the generator input
                (e.g. ``DropBlockNoise``).
            C: weight of the reconstruction term in the generator loss.
        """
        self.C = C
        # Build
        kwargs = dict(input_shape=input_shape,
                      output_activation=output_activation,
                      encoder_weights=pretrain_weights,
                      block_type=block_type,
                      kernel_initializer=kernel_initializer)

        if architecture == 'two-stage':
            # Discriminator gets its own encoder; generator is a full U-Net.
            encoder = get_efficient_unet(name='dcgan_disc',
                                         option='encoder',
                                         **kwargs)
            self.generator = get_efficient_unet(name='dcgan_gen', option='model', **kwargs)
        elif architecture == 'shared':
            # Discriminator head is attached to the generator's encoder.
            self.generator, encoder = get_efficient_unet(name='dcgan', option='full', **kwargs)
        else:
            raise ValueError(f'Unsupport architecture: {architecture}')

        gpooling = GlobalAveragePooling2D()(encoder.output)
        prediction = Dense(1, activation='sigmoid')(gpooling)
        self.discriminator = Model(encoder.input, prediction, name='dcgan_disc')

        tf.keras.backend.clear_session()
        _ = gc.collect()

        if noise:
            # Bake the corruption layer into the generator graph so callers
            # feed clean images and the model corrupts them internally.
            gen_inputs = self.generator.input
            corrupted_inputs = noise(gen_inputs)
            outputs = self.generator(corrupted_inputs)
            self.generator = Model(gen_inputs, outputs, name='dcgan_gen')

            tf.keras.backend.clear_session()
            _ = gc.collect()

        if output_activation == 'tanh':
            # Keep the external contract in [0, 1] while the networks operate
            # in [-1, 1] internally.
            self.process_input = layers.Lambda(lambda img: (img*2.-1.), name='dcgan_normalize')
            self.process_output = layers.Lambda(lambda img: (img*0.5+0.5), name='dcgan_denormalize')
            gen_inputs = self.generator.input
            process_inputs = self.process_input(gen_inputs)
            process_inputs = self.generator(process_inputs)
            gen_outputs = self.process_output(process_inputs)
            self.generator = Model(gen_inputs, gen_outputs, name='dcgan_gen')

            disc_inputs = self.discriminator.input
            process_inputs = self.process_input(disc_inputs)
            disc_outputs = self.discriminator(process_inputs)
            self.discriminator = Model(disc_inputs, disc_outputs, name='dcgan_disc')

            tf.keras.backend.clear_session()
            _ = gc.collect()

    def summary(self):
        """Print Keras summaries of both networks."""
        self.generator.summary()
        self.discriminator.summary()

    def compile(self,
                generator_optimizer=None,
                discriminator_optimizer=None,
                reconstruction_loss=mae,
                discriminative_loss=None,
                reconstruction_metrics=None,
                discriminative_metrics=None):
        """Attach optimizers, losses and metric trackers.

        FIX: the previous defaults created ``Adam``/``BinaryCrossentropy``
        objects (and metric lists) at import time, so optimizer state was
        silently shared across every ``DCGAN`` instance. ``None`` now means
        "create a fresh default per call" with identical hyperparameters.

        Args:
            generator_optimizer: optimizer for the generator
                (default ``Adam(5e-4, 0.5)``).
            discriminator_optimizer: optimizer for the discriminator
                (default ``Adam(5e-4)``).
            reconstruction_loss: pixel loss between original and generated
                images (default MAE).
            discriminative_loss: real/fake loss
                (default ``BinaryCrossentropy(from_logits=False)``).
            reconstruction_metrics: callables ``metric(y_true, y_pred)``
                tracked on the generator output.
            discriminative_metrics: callables tracked on discriminator
                outputs (overall / real-only / fake-only).
        """
        if generator_optimizer is None:
            generator_optimizer = Adam(5e-4, 0.5)
        if discriminator_optimizer is None:
            discriminator_optimizer = Adam(5e-4)
        if discriminative_loss is None:
            discriminative_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        reconstruction_metrics = reconstruction_metrics or []
        discriminative_metrics = discriminative_metrics or []

        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator.compile(optimizer=self.discriminator_optimizer)

        self.generator_optimizer = generator_optimizer
        self.generator.compile(optimizer=self.generator_optimizer)

        self.loss = discriminative_loss
        self.reconstruction_loss = reconstruction_loss
        # Running means accumulated over an epoch, reset via all_trackers.
        self.d_loss_tracker = tf.keras.metrics.Mean()
        self.g_loss_tracker = tf.keras.metrics.Mean()
        self.g_recon_tracker = tf.keras.metrics.Mean()
        self.g_disc_tracker = tf.keras.metrics.Mean()

        self.g_metric_trackers = [(tf.keras.metrics.Mean(), metric) for metric in reconstruction_metrics]
        # (combined, real-only, fake-only) trackers per discriminator metric.
        self.d_metric_trackers = [(tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), metric) for metric in discriminative_metrics]

        all_trackers = [self.d_loss_tracker, self.g_loss_tracker, self.g_recon_tracker, self.g_disc_tracker] + \
                       [tracker for tracker, _ in self.g_metric_trackers] + \
                       [tracker for t in self.d_metric_trackers for tracker in t[:-1]]
        self.all_trackers = MultipleTrackers(all_trackers)

    def discriminator_loss(self, real_output, fake_output):
        """Mean of the real (label 1) and fake (label 0) classification losses."""
        real_loss = self.loss(tf.ones_like(real_output), real_output)
        fake_loss = self.loss(tf.zeros_like(fake_output), fake_output)
        total_loss = 0.5*(real_loss + fake_loss)
        return total_loss

    def generator_loss(self, fake_output):
        """Non-saturating adversarial loss: fool the discriminator into 1s."""
        return self.loss(tf.ones_like(fake_output), fake_output)

    @tf.function
    def train_step(self, images):
        """One simultaneous optimization step for both networks.

        Args:
            images: tuple ``(masked, original)`` batches.

        Returns:
            Dict of running-mean scalars (d_loss, g_loss, g_recon, g_disc,
            plus any configured metrics).
        """
        masked, original = images
        n_samples = tf.shape(original)[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(masked, training=True)

            real_output = self.discriminator(original, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_disc_loss = self.generator_loss(fake_output)
            recon_loss = self.reconstruction_loss(original, generated_images)
            gen_loss = self.C*recon_loss + gen_disc_loss
            disc_loss = self.discriminator_loss(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        # Repeat each scalar n_samples times so the Mean trackers weight
        # batches by their size.
        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'d_loss': self.d_loss_tracker.result()}

        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))

            metric_name = metric.__name__
            logs['d_' + metric_name] = tracker.result()
            logs['d_real_' + metric_name] = real_tracker.result()
            logs['d_fake_' + metric_name] = fake_tracker.result()

        logs['g_loss'] = self.g_loss_tracker.result()
        logs['g_recon'] = self.g_recon_tracker.result()
        logs['g_disc'] = self.g_disc_tracker.result()

        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['g_' + metric.__name__] = tracker.result()

        return logs

    @tf.function
    def val_step(self, images):
        """Evaluation twin of ``train_step``: same losses/metrics, no updates.

        Returns the same log dict with a ``val_`` key prefix.
        """
        masked, original = images
        n_samples = tf.shape(original)[0]

        generated_images = self.generator(masked, training=False)

        real_output = self.discriminator(original, training=False)
        fake_output = self.discriminator(generated_images, training=False)

        gen_disc_loss = self.generator_loss(fake_output)
        recon_loss = self.reconstruction_loss(original, generated_images)
        gen_loss = self.C*recon_loss + gen_disc_loss
        disc_loss = self.discriminator_loss(real_output, fake_output)

        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'val_d_loss': self.d_loss_tracker.result()}

        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))

            metric_name = metric.__name__
            logs['val_d_' + metric_name] = tracker.result()
            logs['val_d_real_' + metric_name] = real_tracker.result()
            logs['val_d_fake_' + metric_name] = fake_tracker.result()

        logs['val_g_loss'] = self.g_loss_tracker.result()
        logs['val_g_recon'] = self.g_recon_tracker.result()
        logs['val_g_disc'] = self.g_disc_tracker.result()

        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['val_g_' + metric.__name__] = tracker.result()

        return logs

    def fit(self,
            trainset,
            valset=None,
            trainsize=-1,
            valsize=-1,
            epochs=1,
            display_per_epochs=5,
            generator_callbacks=None,
            discriminator_callbacks=None):
        """Custom training loop (Keras ``fit`` cannot drive two optimizers here).

        Args:
            trainset: iterable of ``(masked, original)`` batches.
            valset: optional validation iterable of the same form.
            trainsize/valsize: step counts for the progress bars (-1 unknown).
            epochs: number of epochs.
            display_per_epochs: visualize two samples every N epochs.
            generator_callbacks/discriminator_callbacks: Keras callbacks
                driven against the respective model (``None`` → none; was a
                shared mutable ``[]`` default before).

        Returns:
            ``{'generator': History, 'discriminator': History}`` with each
            history filtered to its own network's keys.
        """
        generator_callbacks = generator_callbacks or []
        discriminator_callbacks = discriminator_callbacks or []

        print('🌊🐉 Start Training 🐉🌊')
        gen_callback_tracker = tf.keras.callbacks.CallbackList(
            generator_callbacks, add_history=True, model=self.generator
        )

        disc_callback_tracker = tf.keras.callbacks.CallbackList(
            discriminator_callbacks, add_history=True, model=self.discriminator
        )

        callbacks_tracker = MultipleTrackers([gen_callback_tracker, disc_callback_tracker])

        logs = {}
        callbacks_tracker.on_train_begin(logs=logs)

        for epoch in range(epochs):
            print(f'Epochs {epoch+1}/{epochs}:')
            callbacks_tracker.on_epoch_begin(epoch, logs=logs)

            batches = tqdm(trainset,
                           desc="Train",
                           total=trainsize,
                           unit="step",
                           position=0,
                           leave=True)

            for batch, image_batch in enumerate(batches):
                callbacks_tracker.on_batch_begin(batch, logs=logs)
                callbacks_tracker.on_train_batch_begin(batch, logs=logs)

                train_logs = {k: v.numpy() for k, v in self.train_step(image_batch).items()}
                logs.update(train_logs)

                callbacks_tracker.on_train_batch_end(batch, logs=logs)
                callbacks_tracker.on_batch_end(batch, logs=logs)
                batches.set_postfix({'d_loss': train_logs['d_loss'],
                                     'g_loss': train_logs['g_loss']
                                     })

            # Presentation
            stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items() if 'val_' not in k and 'loss' not in k)
            print('Train:', stats)

            batches.close()
            if valset:
                # Reset running means so validation stats are not mixed with
                # the training epoch's.
                self.all_trackers.reset_state()

                batches = tqdm(valset,
                               desc="Valid",
                               total=valsize,
                               unit="step",
                               position=0,
                               leave=True)

                for batch, image_batch in enumerate(batches):
                    callbacks_tracker.on_batch_begin(batch, logs=logs)
                    callbacks_tracker.on_test_batch_begin(batch, logs=logs)
                    val_logs = {k: v.numpy() for k, v in self.val_step(image_batch).items()}
                    logs.update(val_logs)

                    callbacks_tracker.on_test_batch_end(batch, logs=logs)
                    callbacks_tracker.on_batch_end(batch, logs=logs)
                    # Presentation
                    batches.set_postfix({'val_d_loss': val_logs['val_d_loss'],
                                         'val_g_loss': val_logs['val_g_loss']
                                         })

                stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items() if 'val_' in k and 'loss' not in k)
                print('Valid:', stats)

                batches.close()

            if epoch % display_per_epochs == 0:
                print('-'*128)
                # Visualize the first two samples of the last seen batch.
                self.visualize_samples((image_batch[0][:2], image_batch[1][:2]))

            self.all_trackers.reset_state()

            callbacks_tracker.on_epoch_end(epoch, logs=logs)
            # tf.keras.backend.clear_session()
            _ = gc.collect()

            if self.generator.stop_training or self.discriminator.stop_training:
                break
            print('-'*128)

        callbacks_tracker.on_train_end(logs=logs)
        tf.keras.backend.clear_session()
        _ = gc.collect()

        # Split the shared log history so each model's History only carries
        # its own keys.
        gen_history = None
        for cb in gen_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                gen_history = cb
                gen_history.history = {k: v for k, v in cb.history.items() if 'd_' not in k}

        disc_history = None
        for cb in disc_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                disc_history = cb
                disc_history.history = {k: v for k, v in cb.history.items() if 'g_' not in k}

        return {'generator': gen_history,
                'discriminator': disc_history}

    def visualize_samples(self, samples, figsize=(12, 2)):
        """Plot (masked, original, predicted) triplets for two samples."""
        # FIX: pyplot was referenced as `plt` but never imported in this
        # module (NameError at runtime). Imported locally so the training
        # path works without making matplotlib a hard import-time dependency.
        import matplotlib.pyplot as plt

        x, y = samples
        y_pred = self.generator.predict(x[:2], verbose=0)
        fig, axs = plt.subplots(1, 6, figsize=figsize)
        for i in range(2):
            pos = 3*i
            axs[pos].imshow(x[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos].set_title('Masked')
            axs[pos].axis('off')
            axs[pos+1].imshow(y[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+1].set_title('Original')
            axs[pos+1].axis('off')
            axs[pos+2].imshow(y_pred[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+2].set_title('Predicted')
            axs[pos+2].axis('off')
        plt.show()

        # tf.keras.backend.clear_session()
        del y_pred
        _ = gc.collect()
gae/modules.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras.__internal__.layers import BaseRandomLayer
|
| 3 |
+
from tensorflow.keras.layers import (
|
| 4 |
+
Dense, Flatten, Conv2D, Activation, BatchNormalization,
|
| 5 |
+
MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
|
| 6 |
+
Dropout, Input, concatenate, add, Conv2DTranspose, Lambda,
|
| 7 |
+
SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU,
|
| 8 |
+
ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
|
| 9 |
+
)
|
| 10 |
+
from keras import backend as K
|
| 11 |
+
|
| 12 |
+
from .utils import normalize_tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MultipleTrackers():
    """Fan a method call out to several objects (e.g. Keras CallbackLists).

    Any attribute looked up on this wrapper resolves to a function that
    invokes the same-named method on every wrapped object in order,
    discarding return values (the helper itself returns ``None``).
    """

    def __init__(self, callback_lists: list):
        self.callbacks_list = callback_lists

    def __getattr__(self, attr):
        # __getattr__ only fires when normal attribute lookup fails, so every
        # name reaching here is a method to forward. (The original also
        # checked `attr in self.__class__.__dict__`, but that branch was
        # unreachable: class-dict attributes are found by normal lookup and
        # never trigger __getattr__.)
        def helper(*arg, **kwarg):
            for cb in self.callbacks_list:
                getattr(cb, attr)(*arg, **kwarg)
        return helper
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DropBlockNoise(BaseRandomLayer):
    """Zero out random contiguous blocks of the input (DropBlock-style noise).

    Unlike classic DropBlock regularization, surviving activations are NOT
    rescaled — this is a corruption layer (e.g. for denoising-autoencoder /
    inpainting training), active only when ``training`` is truthy.

    Args:
        rate: expected fraction of the feature map to drop, in [0, 1].
        block_size: int or (height, width) of each dropped block.
        seed: RNG seed.
    """

    def __init__(
        self,
        rate,
        block_size,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        if not 0.0 <= rate <= 1.0:
            raise ValueError(
                f"rate must be a number between 0 and 1. " f"Received: {rate}"
            )

        self._rate = rate
        (
            self._dropblock_height,
            self._dropblock_width,
        ) = normalize_tuple(
            value=block_size, n=2, name="block_size", allow_zero=False
        )
        self.seed = seed

    def call(self, x, training=None):
        # Identity at inference time or when dropping is disabled.
        if not training or self._rate == 0.0:
            return x

        # x is assumed NHWC (batch, height, width, channels).
        _, height, width, _ = tf.split(tf.shape(x), 4)

        # Unnest scalar values
        height = tf.squeeze(height)
        width = tf.squeeze(width)

        dropblock_height = tf.math.minimum(self._dropblock_height, height)
        dropblock_width = tf.math.minimum(self._dropblock_width, width)

        # Seed probability chosen so the expected dropped fraction matches
        # `rate` (DropBlock, Ghiasi et al. 2018, eq. 1).
        gamma = (
            self._rate
            * tf.cast(width * height, dtype=tf.float32)
            / tf.cast(dropblock_height * dropblock_width, dtype=tf.float32)
            / tf.cast(
                (width - self._dropblock_width + 1)
                * (height - self._dropblock_height + 1),
                tf.float32,
            )
        )

        # Forces the block to be inside the feature map.
        w_i, h_i = tf.meshgrid(tf.range(width), tf.range(height))
        valid_block = tf.logical_and(
            tf.logical_and(
                w_i >= int(dropblock_width // 2),
                w_i < width - (dropblock_width - 1) // 2,
            ),
            tf.logical_and(
                h_i >= int(dropblock_height // 2),
                # BUG FIX: the row bound must compare against `height`; the
                # original used `width`, mis-sizing the valid region for
                # non-square feature maps.
                h_i < height - (dropblock_height - 1) // 2,
            ),
        )

        valid_block = tf.reshape(valid_block, [1, height, width, 1])

        random_noise = self._random_generator.random_uniform(
            tf.shape(x), dtype=tf.float32
        )
        valid_block = tf.cast(valid_block, dtype=tf.float32)
        seed_keep_rate = tf.cast(1 - gamma, dtype=tf.float32)
        # A position becomes a block seed when it is valid AND its noise
        # exceeds the keep rate; block_pattern is 1 where kept.
        block_pattern = (1 - valid_block + seed_keep_rate + random_noise) >= 1
        block_pattern = tf.cast(block_pattern, dtype=tf.float32)

        window_size = [1, self._dropblock_height, self._dropblock_width, 1]

        # Double negative and max_pool is essentially min_pooling: expand
        # each seed into a full block of zeros.
        block_pattern = -tf.nn.max_pool(
            -block_pattern,
            ksize=window_size,
            strides=[1, 1, 1, 1],
            padding="SAME",
        )

        return (
            x * tf.cast(block_pattern, x.dtype)
        )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def squeeze_excite_block(input, ratio=16):
    ''' Create a channel-wise squeeze-excite block

    Args:
        input: input tensor
        ratio: channel reduction ratio for the bottleneck Dense layer

    Returns: a keras tensor

    References
    -   [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    '''
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    n_channels = int(input.shape[channel_axis])

    # Squeeze: global spatial average -> (1, 1, C) descriptor.
    excitation = GlobalAveragePooling2D()(input)
    excitation = Reshape((1, 1, n_channels))(excitation)
    # Excite: bottleneck MLP producing per-channel gates in (0, 1).
    excitation = Dense(n_channels // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(excitation)
    excitation = Dense(n_channels, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(excitation)

    if K.image_data_format() == 'channels_first':
        excitation = Permute((3, 1, 2))(excitation)

    return Multiply()([input, excitation])
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def spatial_squeeze_excite_block(input):
    ''' Create a spatial squeeze-excite block

    Args:
        input: input tensor

    Returns: a keras tensor

    References
    -   [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
    '''
    # 1x1 conv collapses channels into a per-pixel attention map in (0, 1).
    attention = Conv2D(1, (1, 1), activation='sigmoid', use_bias=False,
                       kernel_initializer='he_normal')(input)
    return Multiply()([input, attention])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def channel_spatial_squeeze_excite(input, ratio=16):
    ''' Create a concurrent channel + spatial squeeze-excite block (scSE)

    Args:
        input: input tensor
        ratio: channel reduction ratio passed to the channel-wise branch

    Returns: a keras tensor

    References
    -   [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    -   [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
    '''
    # Recalibrate along channels and along space, then merge additively.
    channel_branch = squeeze_excite_block(input, ratio)
    spatial_branch = spatial_squeeze_excite_block(input)
    return Add()([channel_branch, spatial_branch])
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def DoubleConv(filters, kernel_size, initializer='glorot_uniform'):
    """Return a layer function applying two Conv-BN-Swish stacks.

    Args:
        filters: number of output channels of each convolution.
        kernel_size: convolution kernel size.
        initializer: kernel initializer for both convolutions.
    """
    def layer(x):
        # Two identical Conv2D -> BatchNorm -> swish stages.
        for _ in range(2):
            x = Conv2D(filters, kernel_size, padding='same', kernel_initializer=initializer)(x)
            x = BatchNormalization()(x)
            x = Activation('swish')(x)
        return x

    return layer
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def UpSampling2D_block(filters, kernel_size=(3, 3), upsample_rate=(2, 2), interpolation='bilinear',
                       initializer='glorot_uniform', skip=None):
    """Decoder block: interpolated upsampling + skip concat + DoubleConv + scSE.

    Args:
        filters: output channels of the convolution stage.
        kernel_size: convolution kernel size.
        upsample_rate: spatial upsampling factor.
        interpolation: UpSampling2D interpolation mode.
        initializer: kernel initializer for the convolutions.
        skip: optional encoder feature map concatenated after upsampling.
    """
    def layer(input_tensor):
        upsampled = UpSampling2D(size=upsample_rate, interpolation=interpolation)(input_tensor)
        merged = upsampled if skip is None else Concatenate()([upsampled, skip])
        features = DoubleConv(filters, kernel_size, initializer=initializer)(merged)
        return channel_spatial_squeeze_excite(features)

    return layer
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def Conv2DTranspose_block(filters, transpose_kernel_size=(3, 3), upsample_rate=(2, 2),
                          initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
    """Decoder block: transposed-conv upsampling + skip concat + DoubleConv + scSE.

    Args:
        filters: output channels.
        transpose_kernel_size: kernel size of the transposed convolution
            (also used by the following DoubleConv).
        upsample_rate: stride of the transposed convolution.
        initializer: kernel initializer.
        skip: optional encoder feature map concatenated after upsampling.
        met_input, sat_input: unused; kept for signature compatibility with
            the sibling decoder blocks.
    """
    def layer(input_tensor):
        # FIX: `initializer` was accepted but never applied to the transposed
        # convolution (unlike the sibling blocks, where every conv uses it).
        # Backward compatible: 'glorot_uniform' is the Keras default anyway.
        x = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate, padding='same',
                            kernel_initializer=initializer)(input_tensor)
        if skip is not None:
            x = Concatenate()([x, skip])

        x = DoubleConv(filters, transpose_kernel_size, initializer=initializer)(x)
        x = channel_spatial_squeeze_excite(x)
        return x

    return layer
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def PixelShuffle_block(filters, kernel_size=(3, 3), upsample_rate=2,
                       initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
    """Decoder block using pixel shuffle (depth_to_space) for upsampling.

    A conv expands channels by upsample_rate**2, then depth_to_space rearranges
    them into an upsample_rate x upsample_rate larger spatial grid.
    NOTE(review): met_input and sat_input are accepted but never used.
    """
    def layer(input_tensor):
        # Orthogonal init is the conventional choice for sub-pixel convolutions.
        x = Conv2D(filters * (upsample_rate ** 2), kernel_size, padding="same",
                   activation="swish", kernel_initializer='Orthogonal')(input_tensor)
        x = tf.nn.depth_to_space(x, upsample_rate)
        if skip is not None:
            x = Concatenate()([x, skip])

        x = DoubleConv(filters, kernel_size, initializer=initializer)(x)
        # Concurrent spatial & channel squeeze-and-excitation refinement.
        x = channel_spatial_squeeze_excite(x)
        return x

    return layer
|
gae/utils.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy import ndimage
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
from skimage import exposure
|
| 7 |
+
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
import tensorflow_addons as tfa
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
IMAGE_SIZE = 224
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def acc(y_true, y_pred, threshold=0.5):
    """Element-wise binary accuracy after thresholding y_pred at `threshold`."""
    threshold = tf.cast(threshold, y_pred.dtype)
    y_pred = tf.cast(y_pred > threshold, y_pred.dtype)
    return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
|
| 19 |
+
|
| 20 |
+
def mae(y_true, y_pred):
    """Mean absolute error between two tensors."""
    return tf.reduce_mean(tf.abs(y_true-y_pred))
|
| 22 |
+
|
| 23 |
+
def inv_ssim(y_true, y_pred):
    """SSIM dissimilarity loss: 1 - mean SSIM (pixel values assumed in [0, 1])."""
    return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
|
| 25 |
+
|
| 26 |
+
def inv_msssim(y_true, y_pred):
    """Multi-scale SSIM dissimilarity loss; filter_size=4 keeps the smallest scale valid for small inputs."""
    return 1 - tf.reduce_mean(tf.image.ssim_multiscale(y_true, y_pred, 1.0, filter_size=4))
|
| 28 |
+
|
| 29 |
+
def inv_msssim_l1(y_true, y_pred, alpha=0.8):
    """Blend of MS-SSIM loss and L1: alpha * inv_msssim + (1 - alpha) * mae."""
    return alpha*inv_msssim(y_true, y_pred) + (1-alpha)*mae(y_true, y_pred)
|
| 31 |
+
|
| 32 |
+
def inv_msssim_gaussian_l1(y_true, y_pred, alpha=0.8):
    """MS-SSIM loss blended with a Gaussian-smoothed L1 difference.

    NOTE(review): gaussian_l1 is a per-pixel map while inv_msssim is a scalar,
    so this returns a tensor rather than a scalar loss — confirm the training
    loop reduces it afterwards.
    """
    l1_diff = tf.abs(y_true-y_pred)
    # 11x11 Gaussian window, sigma 1.5 (the standard SSIM window parameters).
    gaussian_l1 = tfa.image.gaussian_filter2d(l1_diff, filter_shape=(11, 11), sigma=1.5)
    return alpha*inv_msssim(y_true, y_pred) + (1-alpha)*gaussian_l1
|
| 36 |
+
|
| 37 |
+
def psnr(y_true, y_pred):
    """Mean peak signal-to-noise ratio, with maximum pixel value 1.0."""
    return tf.reduce_mean(tf.image.psnr(y_true, y_pred, 1.0))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def show_image(image, title='Image', cmap_type='gray'):
    """Display a single image with matplotlib, axes hidden."""
    plt.imshow(image, cmap=cmap_type)
    plt.title(title)
    plt.axis('off')
    plt.show()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# đảo màu những ảnh bị ngược màu
|
| 49 |
+
def remove_negative(img):
    """Ensure normal polarity: invert images whose intensities look negated.

    Compares the mean of the leftmost column (assumed background) with the mean
    of the centre column (assumed bone). If the background is darker, polarity
    is normal and the image is returned unchanged; otherwise it is inverted
    (assumes float intensities in [0, 1]).

    Generalized: uses the image's own width for the centre column instead of
    the module-level IMAGE_SIZE constant, so inputs of any width are handled;
    behaviour is identical for IMAGE_SIZE-wide (224) images since
    224 // 2 == int(224 / 2).
    """
    outside = np.mean(img[:, 0])
    inside = np.mean(img[:, img.shape[1] // 2])
    if outside < inside:
        return img
    return 1 - img
|
| 56 |
+
|
| 57 |
+
# lựa chọn tiền xử lý: ảnh gốc, Equalization histogram, CLAHE
|
| 58 |
+
def preprocess(img):
    """Normalize polarity then boost contrast for the GAN generator input.

    Pipeline: polarity fix -> global histogram equalization -> CLAHE (adaptive)
    -> global histogram equalization again.
    NOTE(review): the triple equalization appears deliberate (it should match
    whatever preprocessing the generator was trained with) — confirm before
    simplifying.
    """
    img = remove_negative(img)

    img = exposure.equalize_hist(img)
    img = exposure.equalize_adapthist(img)
    img = exposure.equalize_hist(img)
    return img
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# dilate contour
|
| 68 |
+
def dilate(mask_img):
    """Binary-dilate the background (zero) region of a mask.

    Uses a 45x45 all-ones structuring element (radius 22) and returns a boolean
    array that is True wherever the dilated background reaches.
    """
    radius = 22
    side = 2 * radius + 1
    structure = np.ones((side, side), dtype=np.uint8)
    background = mask_img == 0
    return ndimage.binary_dilation(background, structure=structure)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def normalize_tuple(value, n, name, allow_zero=False):
    """Coerce an int or iterable of ints into a validated tuple of n integers.

    Args:
        value: an int (broadcast to all n positions) or any iterable of ints.
        n: required tuple length.
        name: argument name used in error messages (e.g. "strides").
        allow_zero: when True, zeros are accepted; values must otherwise be
            strictly positive.

    Returns:
        A tuple of n integers.

    Raises:
        ValueError: on non-int input, wrong length, or out-of-range values.
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} "
        f"integers. Received: {value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError:
            raise ValueError(error_msg)

    if len(value_tuple) != n:
        raise ValueError(error_msg)

    # Every element must be convertible to int.
    for single_value in value_tuple:
        try:
            int(single_value)
        except (ValueError, TypeError):
            error_msg += (
                f"including element {single_value} of "
                f"type {type(single_value)}"
            )
            raise ValueError(error_msg)

    # Range check: >= 0 when zeros are allowed, > 0 otherwise.
    req_msg = ">= 0" if allow_zero else "> 0"
    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}

    if unqualified_values:
        error_msg += (
            f" including {unqualified_values}"
            f" that does not satisfy the requirement `{req_msg}`."
        )
        raise ValueError(error_msg)

    return value_tuple
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def draw_points(image, points, color=None, random_color=False, same=True, thickness=1):
    """Draw a filled circle at each (x, y) in *points* on a colour version of *image*.

    Args:
        image: grayscale or BGR image; converted to colour via to_color.
        points: iterable of (x, y) coordinates.
        color: BGR colour; defaults to green when not random.
        random_color: pick a random colour (one for all points when same=True,
            a fresh one per point otherwise).
        same: see random_color.
        thickness: circle radius.

    Fix: this module never imports `random`, so the random_color branches
    raised NameError; import it locally here.
    """
    import random  # missing from this module's top-level imports

    if color is None and not random_color:
        color = (0, 255, 0)  # default colour: green (BGR)
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    image = to_color(image)

    for point in points:
        if random_color and not same:
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        x, y = point
        image = cv2.circle(image, (x, y), thickness, color, -1)  # draw the point
    return image
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def draw_lines(image, pairs, color=None, random_color=False, same=True, thickness=1):
    """Draw a line (with enlarged endpoint dots) for each (start, end) pair.

    Works on a copy of *image*, converted to colour. Colour selection follows
    the same rules as draw_points.

    Fix: this module never imports `random`, so the random_color branches
    raised NameError; import it locally here.
    """
    import random  # missing from this module's top-level imports

    image_with_line = to_color(np.copy(image))

    if color is None and not random_color:
        color = (0, 255, 0)  # default colour: green (BGR)
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    # Draw each segment plus endpoint markers slightly larger than the line.
    for pair in pairs:

        if random_color and not same:
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        start_point = pair[0]
        end_point = pair[1]
        image_with_line = cv2.line(image_with_line, start_point, end_point, color, thickness)
        image_with_line = cv2.circle(image_with_line, start_point, thickness + 1, color, -1)
        image_with_line = cv2.circle(image_with_line, end_point, thickness + 1, color, -1)

    return image_with_line
|
| 170 |
+
|
| 171 |
+
def to_color(image):
    """Return a 3-channel version of *image*; 3-channel input passes through unchanged."""
    already_color = image.ndim == 3 and image.shape[-1] == 3
    if already_color:
        return image
    return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def to_gray(image):
    """Collapse a 3-channel BGR image to grayscale; other inputs pass through."""
    is_bgr = image.ndim == 3 and image.shape[-1] == 3
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_bgr else image
|
gae/weights/gan_efficientunet_full_augment-hist_equal_generator.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6b3f57e8b8c2b6d56f1b0ac9eaceaa693600ccd52ba637293aee23eaf40a819
|
| 3 |
+
size 230002208
|
jsw/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .cal_jsw import get_JSW, calculate_diff, calculate_jsw_info
|
jsw/cal_jsw.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
from .utils import *
|
| 6 |
+
from .contour import get_contours
|
| 7 |
+
from .distance import distance
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def pooling_array(array, n, mode='mean'):
    """Aggregate a 1-D array into n segments, pooling each with `mode`.

    Args:
        array: 1-D numpy array.
        n: number of output segments; n == 1 pools the whole array to a scalar.
        mode: 'mean', 'min' or 'sum'.

    Returns:
        Array of n pooled values; if len(array) < n the input is returned as-is.

    Raises:
        ValueError: on an unknown mode. (Fix: previously an unknown mode left
        `pool` undefined and crashed later with NameError.)
    """
    pools = {'mean': np.mean, 'min': np.min, 'sum': np.sum}
    try:
        pool = pools[mode]
    except KeyError:
        raise ValueError(f"Unsupported pooling mode: {mode!r}; expected one of {sorted(pools)}")

    if n == 1:
        return pool(array)

    array_length = len(array)
    if array_length < n:
        # Too few elements to pool down to n segments.
        return array
    segment_length = array_length // n
    remaining_elements = array_length % n

    if remaining_elements == 0:
        segments = np.split(array, n)
    else:
        # First `remaining_elements` segments get one extra element each.
        mid = remaining_elements * (segment_length + 1)
        segments = np.split(array[:mid], remaining_elements)
        segments += np.split(array[mid:], n - remaining_elements)

    return np.array([pool(segment) for segment in segments])
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def pool_links(links, dim, mode="mean"):
    '''
    Reduce a list of (upper, lower) coordinate pairs to ~dim representatives.

    `links` is a list of paired upper/lower points across the joint space, e.g.
    [(array([436, 421]), array([436, 451])), ...]. The x-coordinates of the
    upper points are pooled exactly like the distance vectors, then for each
    pooled x the first pair whose upper x equals int(x) is kept.
    '''
    xs = np.array(links)[:, 0, 0]
    pooled_x = pooling_array(xs, dim, mode)

    representatives = []
    for pooled in pooled_x:
        target = int(pooled)
        # Keep the first pair matching this pooled x, if any.
        match = next((pair for pair in links if pair[0][0] == target), None)
        if match is not None:
            representatives.append(match)
    return representatives
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_JSW(mask, dim=None, pool='mean', p=0.3, verbose=0):
    '''
    Compute joint-space-width profiles from a segmentation mask.

    input: mask (H, W) with femur is 1 and tibia is 2, or a path to such a mask
    output: left and right distance vectors (aggregated to <dim> by <pool> when
        dim is given) and links (pairs of matched contour points per side)

    Fix: when a path could not be read, the original returned only two values
    while the normal path returns three, breaking callers that unpack three;
    an empty links tuple is now returned as well.
    '''
    if isinstance(mask, str):
        mask = cv2.imread(mask, 0)
        if mask is None:
            return np.zeros(10), np.zeros(10), ([], [])
    uc, lc = get_contours(mask, verbose=verbose)
    left_distances, right_distances, links, contours = distance(mask, uc, lc, p=p, verbose=verbose)
    if verbose:
        # Debug rendering of the contours and a sparse subset of the links.
        temp = draw_points(mask * 127, contours[0], thickness=3, color=(255, 0, 0))
        temp = draw_points(temp, contours[1], thickness=5, color=(255, 0, 0))
        temp = draw_points(temp, contours[2], thickness=5, color=(0, 255, 0))
        temp = draw_points(temp, contours[3], thickness=5, color=(0, 255, 0))
        # NOTE(review): links is a (left, right) 2-tuple here, so links[::6]
        # yields (left_list,) — this draws the left side only; confirm intent.
        temp = draw_lines(temp, links[::6], color=(0, 0, 255), thickness=2)
    if dim:
        left_distances = pooling_array(left_distances, dim, pool)
        right_distances = pooling_array(right_distances, dim, pool)
        links = pool_links(links[0], dim, mode="mean"), pool_links(links[1], dim, mode="mean")
    return left_distances, right_distances, links
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def calculate_diff(left_jsw, right_jsw):
    '''
    Relative asymmetry measures between the two joint-space-width profiles.

    input: left_distances and right_distances vectors
    output: (normalized difference of side means, normalized max-min spread)
        where the min is taken from the side opposite to the global max.
    '''
    side_maxes = [np.max(left_jsw), np.max(right_jsw)]
    jsw_max = max(side_maxes)
    # Take the minimum from the side that does NOT hold the global maximum.
    if np.argmax(side_maxes) == 0:
        jsw_min = np.min(right_jsw)
    else:
        jsw_min = np.min(left_jsw)

    diff_mean = abs(np.mean(left_jsw) - np.mean(right_jsw)) / jsw_max
    diff_max_min = (jsw_max - jsw_min) / jsw_max

    return diff_mean, diff_max_min
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def calculate_jsw_info(left_jsw, right_jsw):
    '''
    Summary statistics of the two joint-space-width profiles.

    input: left_distances and right_distances vectors
    output: % diff, mean_left, mean_right, side_min (0 left / 1 right),
        index of the minimum on the narrower side, and its value
    '''
    mean_left = np.mean(left_jsw)
    mean_right = np.mean(right_jsw)

    # The narrower side (smaller mean) supplies the minimum statistics.
    if mean_left <= mean_right:
        side_min, index_min, value_min = 0, np.argmin(left_jsw), np.min(left_jsw)
    else:
        side_min, index_min, value_min = 1, np.argmin(right_jsw), np.min(right_jsw)

    diff_percentage = np.abs((mean_left - mean_right) / mean_right) * 100

    return diff_percentage, mean_left, mean_right, side_min, index_min, value_min
|
jsw/contour.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
from .utils import *
|
| 6 |
+
|
| 7 |
+
def detect_limit_points(mask, verbose=0):
    # Find the lateral limit points of the knee joint on both sides.
    '''
    input: mask (H, W) with femur is 1 and tibia is 2
    output: 4 points serve as limit points to determine the upper and lower contours from the full contours of the femur and tibia
    Order: [femur-left, femur-right, tibia-left, tibia-right], each as (x, y).
    '''
    h, w = mask.shape
    res = []
    # NOTE(review): upper_pivot/lower_pivot are computed but never used.
    upper_pivot = np.array([0, w // 2]) # r c
    lower_pivot = np.array([h, w // 2]) # r c

    # Restrict the search to the vertical middle 60% of the image,
    # split into left and right halves.
    left_slice = slice(0, w // 2)
    right_slice = slice(w // 2, None)
    center_slice = slice(int(0.2 * h), int(0.8 * h))

    left = np.zeros_like(mask)
    left[center_slice, left_slice] = mask[center_slice, left_slice]

    right = np.zeros_like(mask)
    right[center_slice, right_slice] = mask[center_slice, right_slice]

    if verbose:
        cv2_imshow([left, right])

    # Femur (label 1), left half: farthest point from the top-right corner.
    # Coordinates from argwhere are (row, col); [::-1] flips to (x, y).
    pivot = np.array([0, w])
    coords = np.argwhere(left == 1)
    distances = ((coords - pivot) ** 2).sum(axis=-1)
    point = coords[distances.argmax()][::-1]
    res.append(point)

    # Femur, right half: farthest point from the top-left corner.
    pivot = np.array([0, 0])
    coords = np.argwhere(right == 1)
    distances = ((coords - pivot) ** 2).sum(axis=-1)
    point = coords[distances.argmax()][::-1]
    res.append(point)

    # Tibia (label 2), left half: farthest point from the bottom-right corner.
    pivot = np.array([h, w])
    coords = np.argwhere(left == 2)
    distances = ((coords - pivot) ** 2).sum(axis=-1)
    point = coords[distances.argmax()][::-1]
    res.append(point)

    # Tibia, right half: farthest point from the bottom-left corner.
    pivot = np.array([h, 0])
    coords = np.argwhere(right == 2)
    distances = ((coords - pivot) ** 2).sum(axis=-1)
    point = coords[distances.argmax()][::-1]
    res.append(point)

    if verbose:
        cv2_imshow(draw_points(127 * mask, res))

    return res
|
| 59 |
+
|
| 60 |
+
def find_boundaries(mask, start, end, top=True, verbose=0):
    # If top is True, extract the topmost boundary between start and end;
    # if top is False, extract the bottommost boundary.
    '''
    input:
        mask (H, W) of femur or tibia
        start, end is limit point to extract upper_contour/lower_contour from femur/tibia full contour
    output: upper_contour/lower_contour (one point per x, sorted by x),
        plus the full original contour of the largest component
    use top = True if determine lower contour from tibia mask
    '''
    # NOTE(review): boundaries, height and width are computed but never used.
    boundaries = []
    height, width = mask.shape

    # Largest external contour of the binary mask.
    contours, _ = cv2.findContours(255 * mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    areas = np.array([cv2.contourArea(cnt) for cnt in contours])
    contour = contours[areas.argmax()]
    contour = contour.reshape(-1, 2)
    org_contour = contour.copy()

    # Clip the closed contour to the arc between the points nearest to
    # start and end (wrapping around the end of the array if needed).
    start_idx = ((start - contour) ** 2).sum(axis=-1).argmin()
    end_idx = ((end - contour) ** 2).sum(axis=-1).argmin()
    if start_idx <= end_idx:
        contour = contour[start_idx:end_idx + 1]
    else:
        contour = np.concatenate([contour[start_idx:], contour[:end_idx + 1]])

    # Sort by y so that np.unique below keeps, per x column, the topmost
    # (top=True) or bottommost (top=False) point.
    if top:
        sorted_indices = np.argsort(contour[:, 1])[::-1]
    else:
        sorted_indices = np.argsort(contour[:, 1])
    contour = contour[sorted_indices]

    # One point per x coordinate, then restore left-to-right order.
    unique_indices = sorted(np.unique(contour[:, 0], return_index=True)[1])
    contour = contour[unique_indices]
    sorted_indices = np.argsort(contour[:, 0])
    contour = contour[sorted_indices]
    if verbose:
        temp = draw_points(127 * mask.astype(np.uint8), contour, thickness=5)
        temp = draw_points(temp, [start, end], color=[155, 155], thickness=15)

    return np.array(contour), np.array(org_contour)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_contours(mask, verbose=0):
    '''
    input: mask (H, W) with femur is 1 and tibia is 2
    output: upper_contour (bottom edge of the femur), lower_contour (top edge of the tibia)
    '''
    limit_points = detect_limit_points(mask, verbose=verbose)
    # Femur: bottom boundary between its left and right limit points.
    upper_contour, full_upper = find_boundaries(mask == 1, limit_points[0], limit_points[1], top=False, verbose=verbose)
    # Tibia: top boundary, traversed right-to-left (note swapped limit points).
    lower_contour, full_lower = find_boundaries(mask == 2, limit_points[3], limit_points[2], top=True, verbose=verbose)
    if verbose:
        # Debug rendering of the full contours and the limit points.
        temp = draw_points(127 * mask, full_upper, thickness=3, color=(255, 0, 0))
        temp = draw_points(temp, full_lower, thickness=3)
        temp = draw_points(temp, limit_points, thickness=7, color=(0, 0, 255))
    if verbose:
        # Debug rendering of the cropped (per-column) contours.
        temp = draw_points(127 * mask, upper_contour, thickness=3, color=(255, 0, 0))
        temp = draw_points(temp, lower_contour, thickness=3)

    return upper_contour, lower_contour
|
jsw/distance.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.linear_model import LinearRegression
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
from .utils import *
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def distance(mask, upper_contour, lower_contour, p=0.12, verbose=0):
    '''
    Measure the vertical joint-space gap between femur and tibia contours on
    each side of the tibial spine region.

    input:
        mask (H, W) with femur is 1 and tibia is 2
        upper_contour, lower_contour: (N, 2) arrays of (x, y) points, sorted by x
        p: unused (kept for interface compatibility)
    output:
        left and right per-column distance vectors (lower y - upper y),
        links: (left_pairs, right_pairs) of matched (upper, lower) points,
        temp: the four cropped contour segments

    Fix: removed dead code — the np.where-based index slices were computed and
    immediately overwritten by the boolean-mask selections below (and the dead
    np.where(...)[0][0] lookups could raise IndexError on a degenerate contour).
    '''
    # x-range where the tibial spine sits; distances are measured outside it.
    left, right = getMiddle(mask, lower_contour, verbose=verbose)
    # Common x-range covered by both contours.
    x_min = max(lower_contour[0, 0], upper_contour[0, 0])
    x_max = min(lower_contour[-1, 0], upper_contour[-1, 0])

    left_lower_contour = lower_contour[(lower_contour[:, 0] <= left) & (lower_contour[:, 0] >= x_min)]
    right_lower_contour = lower_contour[(lower_contour[:, 0] >= right) & (lower_contour[:, 0] <= x_max)][::-1]

    left_upper_contour = upper_contour[(upper_contour[:, 0] <= left) & (upper_contour[:, 0] >= x_min)]
    right_upper_contour = upper_contour[(upper_contour[:, 0] >= right) & (upper_contour[:, 0] <= x_max)][::-1]

    if verbose == 1:
        temp = draw_points(mask * 127, left_lower_contour, color=(0, 255, 0), thickness=5)
        temp = draw_points(temp, right_lower_contour, color=(0, 255, 0), thickness=5)
        temp = draw_points(temp, left_upper_contour, color=(255, 0, 0), thickness=5)
        temp = draw_points(temp, right_upper_contour, color=(255, 0, 0), thickness=5)

    links = list(zip(left_upper_contour, left_lower_contour)), list(zip(right_upper_contour, right_lower_contour))

    temp = left_upper_contour, right_upper_contour, left_lower_contour, right_lower_contour

    return left_lower_contour[:, 1] - left_upper_contour[:, 1], right_lower_contour[:, 1] - right_upper_contour[:, 1], links, temp
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def getMiddle(mask, contour, verbose=0):
    '''
    use Linear Regression to construct a straight line to remove the bone segment from upper_contour and lower_contour
    input:
        mask (H, W) with femur is 1 and tibia is 2
        lower_contour: (N, 2) array of (x, y) points, sorted by x
    output: 2 point (left and right) define the boundary within which every point of the bone segment is located
    '''
    X = contour[:, 0].reshape(-1, 1)
    y = contour[:, 1]
    # Fit a straight line y ~ x through the whole contour as a reference.
    reg = LinearRegression().fit(X, y)
    # Start from the highest contour point (min y; image y grows downward)
    # within the central 60% of the contour.
    i_min = np.argmin(y[int(len(y) * 0.2):int(len(y) * 0.8)]) + int(len(y) * 0.2)
    left = i_min - 1
    right = i_min + 1
    left_check = False
    right_check = False
    if verbose == 1:
        cmask = draw_points(mask, contour, thickness=3, color=(255, 0, 0))
        cmask = draw_points(cmask, np.hstack([X, reg.predict(X).reshape(-1, 1).astype('int')]),thickness = 5,
                            color=(0, 255, 0))

    # Expand left/right outward until both sides fall below the regression
    # line (y > prediction means the point lies BELOW the line on screen),
    # skipping any secondary peaks that still rise above it.
    while True:
        while not left_check:
            if y[left] > reg.predict(X[left].reshape(-1, 1)):
                break
            left -= 1
        while not right_check:
            if y[right] > reg.predict(X[right].reshape(-1, 1)):
                break
            right += 1
        if verbose == 1:
            cmask = draw_points(cmask, [contour[left]], thickness=10, color=(255, 255, 0))
            cmask = draw_points(cmask, [contour[right]], thickness=10, color=(0, 255, 255))

        # Look for a remaining peak (local min of y) on each side; if it still
        # pokes above the regression line, keep expanding past it.
        left_min = np.argmin(y[int(len(y) * 0.2):left]) + int(len(y) * 0.2) if int(len(y) * 0.2) < left else left
        right_min = np.argmin(y[right:int(len(y) * 0.8)]) + right if right < int(len(y) * 0.8) else right
        if y[left_min] > reg.predict(X[left_min].reshape(-1, 1)):
            left_check = True
        if y[right_min] > reg.predict(X[right_min].reshape(-1, 1)):
            right_check = True
        if right_check and left_check:
            break
        left = left_min - 1
        right = right_min + 1
    # Return the bounding x coordinates in ascending order.
    return min(X.flatten()[left], X.flatten()[right]), max(X.flatten()[left], X.flatten()[right])
|
jsw/utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import cv2
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def draw_points(image, points, color=None, random_color=False, same=True, thickness=1):
    """Draw a filled circle at each (x, y) in *points* on a colour version of *image*.

    color defaults to green (BGR) unless random_color is set; with
    random_color and same=False each point gets its own random colour.
    """
    def _rand_color():
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    if random_color:
        color = _rand_color()
    elif color is None:
        color = (0, 255, 0)  # default colour: green (BGR)

    image = to_color(image)

    for x, y in points:
        if random_color and not same:
            color = _rand_color()
        image = cv2.circle(image, (x, y), thickness, color, -1)
    return image
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def draw_lines(image, pairs, color=None, random_color=False, same=True, thickness=1):
    """Draw a line (with enlarged endpoint dots) for each (start, end) pair.

    Works on a copy of *image*, converted to colour. color defaults to blue
    (BGR 255, 0, 0) unless random_color is set.

    Fix: this module does not import numpy, so the original `np.copy(image)`
    raised NameError; the array's own copy() method is used instead (identical
    result for the image arrays this project passes in).
    """
    image_with_line = to_color(image.copy())

    if color is None and not random_color:
        color = (255, 0, 0)  # default colour: blue (BGR)
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    # Draw each segment plus endpoint markers slightly larger than the line.
    for pair in pairs:

        if random_color and not same:
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        start_point = pair[0]
        end_point = pair[1]
        image_with_line = cv2.line(image_with_line, start_point, end_point, color, thickness)
        image_with_line = cv2.circle(image_with_line, start_point, thickness + 1, color, -1)
        image_with_line = cv2.circle(image_with_line, end_point, thickness + 1, color, -1)

    return image_with_line
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def center(contour):
    """Return the middle element of *contour* (the upper-middle one for even lengths)."""
    return contour[len(contour) // 2]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def to_color(image):
    """Return a 3-channel version of *image*; 3-channel input passes through unchanged."""
    if image.ndim == 3 and image.shape[-1] == 3:
        return image
    return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def to_gray(image):
    """Collapse a 3-channel BGR image to grayscale; other inputs pass through."""
    is_bgr = image.ndim == 3 and image.shape[-1] == 3
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_bgr else image
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def cv2_imshow(images):
    """Matplotlib-based stand-in for cv2.imshow (works in notebooks/headless).

    Accepts a single image or a list; multiple images are stacked in one
    column. BGR colour images are converted to RGB before display.
    """
    if not isinstance(images, list):
        images = [images]

    num_images = len(images)

    # Single image: draw directly with imshow.
    if num_images == 1:
        image = images[0]
        if len(image.shape) == 3 and image.shape[2] == 3:
            # Colour image: OpenCV stores BGR, matplotlib expects RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            plt.imshow(image_rgb)
        else:
            # Grayscale image.
            plt.imshow(image, cmap='gray')

        plt.axis("off")
        plt.show()
    else:
        # Multiple images: one subplot per image, stacked vertically.
        fig, ax = plt.subplots(num_images, 1, figsize=(4, 4 * num_images))

        for i in range(num_images):
            image = images[i]
            if len(image.shape) == 3 and image.shape[2] == 3:
                # Colour image: convert BGR -> RGB for display.
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                ax[i].imshow(image_rgb)
            else:
                # Grayscale image.
                ax[i].imshow(image, cmap='gray')

            ax[i].axis("off")

        plt.tight_layout()
        plt.show()
|
| 100 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.32.2
|
| 2 |
+
opencv-python==4.9.0.80
|
| 3 |
+
ultralytics
|
| 4 |
+
numpy==1.26.4
|
| 5 |
+
torch==2.2.2
|
| 6 |
+
scikit-learn==1.2.2
|
| 7 |
+
|
| 8 |
+
keras_cv~=0.5.0
|
| 9 |
+
tqdm==4.66.2
|
| 10 |
+
keras-core==0.1.7
|
| 11 |
+
tensorflow==2.14
|
| 12 |
+
tensorflow-addons==0.22.0
|
| 13 |
+
scipy==1.12.0
|
| 14 |
+
matplotlib==3.8.3
|
| 15 |
+
scikit-image==0.22.0
|
| 16 |
+
Pillow==10.2.0
|
| 17 |
+
|
| 18 |
+
xgboost==2.1.2
|
| 19 |
+
lime==0.2.0.1
|
segmentation/__init__.py
ADDED
|
File without changes
|
segmentation/model.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ultralytics import YOLO
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
BEST_WEIGHT = './segmentation/weights/oai_s_best4.pt'
|
| 6 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 7 |
+
|
| 8 |
+
def create_model(weight_path=None):
    """Instantiate a YOLO segmentation model.

    weight_path: optional checkpoint path; when omitted (or falsy) the
    bundled BEST_WEIGHT checkpoint is used instead.
    """
    return YOLO(weight_path or BEST_WEIGHT)
|
| 13 |
+
|
| 14 |
+
class Segmenter():
    """YOLO-based segmenter that produces a femur/tibia label mask from a knee X-ray."""

    def __init__(self, weight_path=None):
        # create_model falls back to the bundled BEST_WEIGHT when no path is given.
        self.model = create_model(weight_path).to(DEVICE)

    def segment(self, img):
        """
        input: image (H, W, C)
        output: mask (H, W) with femur is 1 and tibia is 2

        Raises ValueError when the model does not detect two bones.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Contrast enhancement: global histogram equalization followed by CLAHE.
        eimg = cv2.equalizeHist(img)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        eimg = clahe.apply(eimg)

        # The YOLO model expects a 3-channel input.
        eimg = cv2.cvtColor(eimg, cv2.COLOR_GRAY2RGB)

        res = self.model(eimg, verbose=False)

        # Fail loudly with a clear message instead of an obscure
        # AttributeError/IndexError when fewer than 2 instances were found.
        n_found = 0 if res[0].masks is None else len(res[0].masks.data)
        if n_found < 2:
            raise ValueError(
                f"Segmentation expected 2 detected bones (femur and tibia), got {n_found}"
            )

        # Encode class id + 1 per pixel: femur -> 1, tibia -> 2.
        # NOTE(review): pixels covered by both masks sum to 3 — presumably
        # overlap is negligible; confirm against downstream consumers.
        mask = res[0].masks.data[0] * (res[0].boxes.cls[0] + 1) + res[0].masks.data[1] * (res[0].boxes.cls[1] + 1)
        mask = mask.cpu().numpy()

        return mask
|
segmentation/weights/oai_s_best4.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0159a91db899213c34db9cc93db1b5a31576eb3b78eb61bd1d9a29a4ef92e843
|
| 3 |
+
size 6771320
|
utils.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
# from shared_queue import image_queue
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def read_image(file_bytes):
    """Decode raw encoded image bytes (e.g. an uploaded file) into a BGR image."""
    buffer = np.frombuffer(file_bytes, np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
|
| 10 |
+
|
| 11 |
+
def combine_mask(image, mask, label2color={1: (255, 255, 0), 2: (0, 255, 255)}, alpha=0.1):
    """Blend a label mask over an image as a translucent color overlay.

    image: (H, W) or (H, W, 3) array; converted to color and resized to
        the mask's spatial size.
    mask: (H, W) integer label mask.
    label2color: mapping from label value to BGR color. NOTE: mutable
        default is safe here because it is only read, never mutated.
    alpha: overlay opacity in [0, 1].
    Returns the blended (H, W, 3) image.
    """
    image = to_color(image)
    # BUG FIX: cv2.resize takes dsize as (width, height), while numpy's
    # mask.shape is (rows, cols) == (height, width). Passing mask.shape
    # directly transposed the target size for non-square masks.
    image = cv2.resize(image, (mask.shape[1], mask.shape[0]))
    mask_image = np.zeros_like(image)
    for label, color in label2color.items():
        mask_image[mask == label] = color

    # Weighted blend: mostly the original image, a hint of the mask colors.
    mask_image = cv2.addWeighted(image, 1 - alpha, mask_image, alpha, 0)
    return mask_image
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def to_color(image):
    """Return a 3-channel version of *image*; 3-channel input passes through unchanged."""
    already_color = len(image.shape) == 3 and image.shape[-1] == 3
    if already_color:
        return image
    return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def to_gray(image):
    """Return a single-channel version of *image*; non-3-channel input passes through unchanged."""
    is_color = len(image.shape) == 3 and image.shape[-1] == 3
    if is_color:
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return image
|
| 32 |
+
|
| 33 |
+
def prob_process(prob_list):
    """Collapse a 3-class probability vector into a single scalar score.

    prob_list: sequence whose first element is a length-3 probability
        vector (class scores).
    The least-likely class's mass is folded into the most likely one,
    then the score is offset by the winning class index:
      class 0 -> 1 - p, class 1 -> p, class 2 -> 1 + p,
    giving a score that grows with the predicted class.
    Returns a float.

    BUG FIX: the original mutated the caller's array in place; this
    version works on a copy so the input probabilities stay intact.
    """
    prob = np.array(prob_list[0], dtype=float)  # copy: don't clobber caller data
    max_idx = int(np.argmax(prob))
    min_idx = int(np.argmin(prob))
    # Fold the least-likely class's probability into the most likely one.
    prob[max_idx] += prob[min_idx]
    prob[min_idx] = 0

    if max_idx == 0:
        return 1 - prob[max_idx]
    elif max_idx == 1:
        return prob[max_idx]
    else:
        return 1 + prob[max_idx]
|
| 46 |
+
|
| 47 |
+
def combine_prob_jsw(prob, jsw_m, jsw_mm):
    """Build the feature vector [processed probability, JSW medial, JSW other].

    prob is first collapsed to a scalar via prob_process; the three values
    are stacked into one flat numpy array.
    """
    features = np.array([prob_process(prob), jsw_m, jsw_mm])
    # Any non-zero entry -> hstack; an all-zero vector is just flattened.
    if features.any():
        return np.hstack(features)
    return features.flatten()
|
| 56 |
+
|
| 57 |
+
def scale_coordinates(links, original_size, mask_size):
    """Rescale (x, y) points from mask coordinates to original-image coordinates.

    links: iterable of (x, y) pairs.
    original_size / mask_size: scalar sizes used to derive the scale factor.
    Returns a list of integer (x, y) tuples.
    """
    factor = original_size / mask_size
    scaled = []
    for x, y in links:
        scaled.append((int(x * factor), int(y * factor)))
    return scaled
|
| 61 |
+
|
| 62 |
+
def get_annotations(probabilites):
    """Map classifier probabilities to human-readable osteophyte / JSN labels.

    probabilites: indexable of at least 12 scores; indices 5..7 hold the
    osteophyte class scores and 8..11 the JSN class scores.
    Returns {"osteophyte": <label>, "jsn": <label>} picked by argmax.
    """
    osteophyte_levels = (
        "Definite osteophytes",
        "No osteophytes",
        "Possible osteophytes",
    )
    jsn_levels = (
        "Definite JSN",
        "Mild JSN",
        "No JSN",
        "Severe JSN",
    )

    osteophyte = osteophyte_levels[int(np.argmax(probabilites[5:8]))]
    jsn = jsn_levels[int(np.argmax(probabilites[8:12]))]
    return {"osteophyte": osteophyte, "jsn": jsn}
|