File size: 9,991 Bytes
036e7c4 8531190 036e7c4 68847f4 036e7c4 19bd631 036e7c4 19bd631 036e7c4 19bd631 036e7c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import streamlit as st
import cv2
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from PIL import Image
import numpy as np
import os
import streamlit.components.v1 as components
import pandas as pd
from utils import read_image, combine_mask, combine_prob_jsw, scale_coordinates, get_annotations
from segmentation.model import Segmenter
from jsw import get_JSW, calculate_diff, calculate_jsw_info
from classification.model import Classifier
from ML_model.model import MLModel
from gae import AnomalyExtractor
import os
# Raise TensorFlow's native (C++) logging threshold before the page is built.
# NOTE(review): TF reads TF_CPP_MIN_LOG_LEVEL when it is first imported; the
# model modules imported above may already have pulled TensorFlow in, in which
# case this assignment has no effect — confirm and move it above those imports.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
# Wide layout plus browser-tab title/icon for the app.
st.set_page_config(page_title="xDesCO", page_icon="🏥", layout="wide")
@st.cache_resource
def load_models():
    """Construct the inference components used by the app.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects across reruns instead of rebuilding them each time.

    Returns:
        tuple: ``(Segmenter, Classifier('convnet'), MLModel, AnomalyExtractor)``
        instances, in that order (constructed left to right).
    """
    return (
        Segmenter(),
        Classifier('convnet'),
        MLModel(),
        AnomalyExtractor(),
    )
# Models are built by load_models(), which is cached by Streamlit.
seg_model, classif_model, ml_model, anomaly_extractor = load_models()

st.sidebar.title("xDesCO: Explainable AI for Knee Osteoarthritis Diagnosis")
uploaded_file = st.sidebar.file_uploader(
    "Choose a knee joint X-ray image. Download the OAI dataset [here](https://www.kaggle.com/datasets/shashwatwork/knee-osteoarthritis-dataset-with-severity).",
    type=["jpg", "jpeg", "png"],
)
st.sidebar.markdown("---")
st.sidebar.caption(":bulb: Diagnose knee osteoarthritis with explainable AI insights, enabling users to upload X-ray images for knee osteoarthritis diagnosis with detailed visual explanations of abnormalities and disease severity.")

# Sidebar footer: the text is injected through a CSS ::after pseudo-element that
# is absolutely positioned near the bottom-left corner of the sidebar, so no
# extra Streamlit element is needed.
st.markdown(
    """
<style>
[data-testid="stSidebar"]::after {
content: "Powered by the xDesCO framework.";
position: absolute;
bottom: 20px;
left: 20px;
font-size: 15px;
color: gray;
}
</style>
""",
    unsafe_allow_html=True,
)
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer's current
    # position, whereas read() yields b"" once the stream has been consumed.
    file_bytes = uploaded_file.getvalue()
    file_name = uploaded_file.name
    img = read_image(file_bytes)

    # Segmentation of the uploaded image.
    mask = seg_model.segment(img)
    # NOTE(review): mask_image is not displayed anywhere in this script (its
    # st.image call is commented out); the combine_mask call is kept in case it
    # has side effects — confirm and remove if it does not.
    mask_image = combine_mask(img, mask)

    # Two side-by-side columns used to render the results further down.
    col1, col2 = st.columns(2)

    # Joint-space-width (JSW) measurements derived from the segmentation mask.
    left_distances, right_distances, links = get_JSW(mask, dim=10, verbose=0)
    left_links, right_links = links
    jsw_m, jsw_mm = calculate_diff(left_distances, right_distances)
    diff_percentage, mean_left, mean_right, side_min, index_min, value_min = calculate_jsw_info(left_distances, right_distances)

    def _scale_link_pairs(link_pairs):
        """Rescale (start, end) point pairs with scale_coordinates and re-pair them.

        The pairs are flattened into one point list, passed through
        ``scale_coordinates(original_size=224, mask_size=640)``, then regrouped
        two at a time into ``(start, end)`` tuples.
        """
        flat_points = [point for pair in link_pairs for point in pair]
        scaled = scale_coordinates(flat_points, original_size=224, mask_size=640)
        return list(zip(scaled[::2], scaled[1::2]))

    left_pairs = _scale_link_pairs(left_links)
    right_pairs = _scale_link_pairs(right_links)

    # Classifier output for the single uploaded image (first row of the batch).
    probabilities = classif_model.predict(img)[0]
    annotations = get_annotations(probabilities)
    # Final prediction plus its explanation object, from class probabilities and JSW features.
    predicted, exp = ml_model.predict_explain(probabilities, [jsw_m, jsw_mm], filename=file_name)
    processed_img, anomaly = anomaly_extractor.extract(mask, img, verbose=0)
def plot_anomaly_with_clues(processed_img, anomaly, diff_percentage, mean_left, mean_right, side_min, index_min, value_min,
                            left_pairs, right_pairs, color='r', thickness=1):
    """Draw the anomaly map over the processed X-ray with diagnostic annotations.

    The anomaly map is overlaid (``turbo`` colormap, alpha 0.3) on the grey-scale
    ``processed_img``. A red box is drawn around the narrowest joint-space
    measurement when ``diff_percentage > 0`` or when the mean width on the side
    holding that minimum is more than twice the minimum. The osteophyte and JSN
    labels and the JSW_Mean / JSW_MM values are written on the image.

    Note: ``annotations``, ``jsw_m`` and ``jsw_mm`` are not parameters; they are
    looked up in the surrounding script namespace at call time.

    Args:
        processed_img: image drawn as the grey-scale background.
        anomaly: anomaly map overlaid on top; ``anomaly.shape[0]`` and
            ``anomaly.shape[1]`` position the text labels.
        diff_percentage: percentage difference between the left and right
            joint-space distances.
        mean_left: mean joint-space distance on the left side.
        mean_right: mean joint-space distance on the right side.
        side_min: 0 when the narrowest measurement is on the left, otherwise right.
        index_min: index of the narrowest measurement within that side's pairs.
        value_min: the narrowest joint-space distance.
        left_pairs: sequence of ``(start_point, end_point)``; points are ``(x, y)``.
        right_pairs: same as ``left_pairs`` for the right side.
        color: currently unused; kept for backward compatibility.
        thickness: currently unused; kept for backward compatibility.

    Returns:
        matplotlib.figure.Figure: the annotated figure (axes hidden).
    """
    # A Figure created directly is not registered with pyplot, so it is freed
    # with its last reference instead of accumulating across Streamlit reruns
    # (pyplot figures stay alive until plt.close()).
    from matplotlib.figure import Figure

    def _outline(width):
        """Return a fresh black-outline path-effect list with the given stroke width."""
        return [path_effects.Stroke(linewidth=width, foreground='black'), path_effects.Normal()]

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(processed_img, cmap='gray')
    ax.imshow(anomaly, cmap='turbo', alpha=0.3)

    # Highlight the narrowest measurement when the sides differ, or when it is
    # less than half its own side's mean. Written as ``mean > 2 * min`` instead
    # of ``mean / min > 2`` so a zero minimum (fully closed gap) cannot raise
    # ZeroDivisionError; for a positive minimum the two forms are equivalent.
    side_mean = mean_left if side_min == 0 else mean_right
    if diff_percentage > 0 or side_mean > 2 * value_min:
        min_pairs = left_pairs if side_min == 0 else right_pairs
        start_point, end_point = min_pairs[index_min]
        # Padding (in plot coordinates) around the measurement line.
        padding_x = 23
        padding_y = 13
        min_x = start_point[0] - padding_x
        max_x = start_point[0] + padding_x
        min_y = min(start_point[1], end_point[1]) - padding_y
        max_y = max(start_point[1], end_point[1]) + padding_y
        rect = plt.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                             linewidth=3, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

    rows, cols = anomaly.shape[0], anomaly.shape[1]

    # Osteophyte and joint-space-narrowing labels.
    ax.text(cols // 2, rows // 3.2, f'{annotations["osteophyte"]}', color='yellow', fontsize=19,
            ha='center', va='center', path_effects=_outline(1.5))
    ax.text(cols // 2, rows // 1.3, f'{annotations["jsn"]}', color='yellow', fontsize=19,
            ha='center', va='center', path_effects=_outline(1.5))

    # JSW_Mean and JSW_MM values in the lower-left corner.
    ax.text(cols // 40, rows // 1.05, f'$JSW_{{Mean}}$: {jsw_m:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))
    ax.text(cols // 40, rows // 1.1, f'$JSW_{{MM}}$: {jsw_mm:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))

    ax.axis('off')
    return fig
def _centered_heading(caption):
    """Render ``caption`` as a centred <h3> heading via raw HTML in st.markdown."""
    st.markdown(
        f"<h3 style='text-align: center;'>{caption}</h3>",
        unsafe_allow_html=True
    )


# Left column: the uploaded image as-is.
with col1:
    caption = "Uploaded Image"
    _centered_heading(caption)
    # channels="BGR": the array is rendered as BGR-ordered — presumably because
    # read_image decodes with OpenCV; confirm.
    st.image(img, channels="BGR", use_column_width=True)

# Right column: the annotated anomaly map.
with col2:
    fig = plot_anomaly_with_clues(
        processed_img,
        anomaly,
        diff_percentage,
        mean_left,
        mean_right,
        side_min,
        index_min,
        value_min,
        left_pairs,
        right_pairs,
        color='r',
        thickness=1,
    )
    caption = "Annotated Anomaly Map"
    _centered_heading(caption)
    st.pyplot(fig, bbox_inches='tight', pad_inches=0)
    # st.pyplot has already rasterised the figure; release it from pyplot's
    # registry so figures do not pile up across reruns (no-op if unmanaged).
    plt.close(fig)
# CSS appended to LIME's HTML: hides the element whose class is exactly
# "lime explanation" and hides horizontal overflow of LIME's table div.
lime_css = """
<style>
div[class="lime explanation"] { display: none; } /* This targets the tree section */
.lime.table_div {overflow-x: hidden}
</style>
"""
exp_html = exp.as_html()
full_html = exp_html + lime_css
# Embed as a standalone HTML component (fixed height, scrollable).
components.html(full_html, height=320, width=None, scrolling=True)
def extract_explain(exp, label):
    """Flatten a LIME explanation for one label into table rows.

    Args:
        exp: LIME explanation object exposing ``local_exp`` (label -> list of
            ``(feature_index, weight)``) and ``domain_mapper``.
        label: key into ``exp.local_exp`` selecting which explanation to use.

    Returns:
        list of ``(feature_name, feature_value, discretized_name, weight)``
        tuples in the order stored in ``exp.local_exp[label]``; ``weight`` is
        converted to a Python ``float``.

    Raises:
        KeyError: if ``label`` is not present in ``exp.local_exp``.
    """
    mapper = exp.domain_mapper
    # LIME's tabular domain mapper leaves discretized_feature_names as None when
    # the explainer runs without a discretizer; fall back to the raw feature
    # names instead of failing on ``None[...]``.
    discretized_names = getattr(mapper, "discretized_feature_names", None)
    if discretized_names is None:
        discretized_names = mapper.feature_names
    return [
        (
            mapper.feature_names[idx],
            mapper.feature_values[idx],
            discretized_names[idx],
            float(weight),
        )
        for idx, weight in exp.local_exp[label]
    ]
# Tabulate the per-feature explanation for the predicted label.
explanation_columns = ['Feature', 'Value', 'Explain', 'Weight']
explanation_list = extract_explain(exp, predicted)
explanation_df = pd.DataFrame(explanation_list, columns=explanation_columns)
st.table(explanation_df)
|