# xDesCO / app.py
# Nguyễn Thành Đạt
# update GUI
# 68847f4
import streamlit as st
import cv2
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from PIL import Image
import numpy as np
import os
import streamlit.components.v1 as components
import pandas as pd
from utils import read_image, combine_mask, combine_prob_jsw, scale_coordinates, get_annotations
from segmentation.model import Segmenter
from jsw import get_JSW, calculate_diff, calculate_jsw_info
from classification.model import Classifier
from ML_model.model import MLModel
from gae import AnomalyExtractor
import os
# Quiet TensorFlow's native (C++) logging.
# NOTE(review): TensorFlow reads this variable when it is first imported. If
# segmentation.model / classification.model / gae (imported above) pull in
# TensorFlow at import time, this assignment comes too late to take effect --
# consider moving it above those imports.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Page configuration is kept ahead of every other st.* call, which older
# Streamlit versions require.
st.set_page_config(page_title="xDesCO", page_icon="🏥", layout="wide")
@st.cache_resource
def load_models():
    """Construct all models used by the app.

    ``st.cache_resource`` caches the returned objects, so Streamlit reruns reuse
    them instead of rebuilding every model on each interaction.

    Returns:
        tuple: ``(Segmenter(), Classifier('convnet'), MLModel(), AnomalyExtractor())``
        in that order (constructed left to right).
    """
    models = (
        Segmenter(),
        Classifier('convnet'),
        MLModel(),
        AnomalyExtractor(),
    )
    return models


seg_model, classif_model, ml_model, anomaly_extractor = load_models()
# ---- Sidebar: title, image upload and description ----
st.sidebar.title("xDesCO: Explainable AI for Knee Osteoarthritis Diagnosis")
uploaded_file = st.sidebar.file_uploader(
    "Choose a knee joint X-ray image. Download the OAI dataset "
    "[here](https://www.kaggle.com/datasets/shashwatwork/knee-osteoarthritis-dataset-with-severity).",
    type=["jpg", "jpeg", "png"],
)
st.sidebar.markdown("---")
st.sidebar.caption(":bulb: Diagnose knee osteoarthritis with explainable AI insights, enabling users to upload X-ray images for knee osteoarthritis diagnosis with detailed visual explanations of abnormalities and disease severity.")

# Sidebar footer: a CSS ::after pseudo-element on the sidebar container,
# absolutely positioned 20px from its bottom-left corner.
st.markdown(
    """
    <style>
    [data-testid="stSidebar"]::after {
        content: "Powered by the xDesCO framework.";
        position: absolute;
        bottom: 20px;
        left: 20px;
        font-size: 15px;
        color: gray;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# CSS appended to the LIME HTML: hides the div whose class is exactly
# "lime explanation" and suppresses horizontal overflow of the table div.
_LIME_HTML_STYLE = """
<style>
div[class="lime explanation"] { display: none; } /* This targets the tree section */
.lime.table_div {overflow-x: hidden}
</style>
"""

# Padding (in plot pixels) added around the narrowest joint-space link box.
_BOX_PAD_X = 23
_BOX_PAD_Y = 13


def _centered_heading(text):
    """Render ``text`` as a centred <h3> heading (the text is not HTML-escaped)."""
    st.markdown(
        f"<h3 style='text-align: center;'>{text}</h3>",
        unsafe_allow_html=True
    )


def _scale_link_pairs(links):
    """Rescale JSW link endpoints and regroup them into (start, end) pairs.

    ``links`` is a sequence of two-point links as returned by ``get_JSW``. The
    endpoints are flattened, passed through ``scale_coordinates`` and zipped
    back into consecutive (start, end) pairs.
    """
    # NOTE(review): presumably maps 640 px mask coordinates onto the 224 px
    # image the anomaly map is drawn on -- confirm in utils.scale_coordinates.
    points = [point for pair in links for point in pair]
    scaled = scale_coordinates(points, original_size=224, mask_size=640)
    return list(zip(scaled[::2], scaled[1::2]))


def _should_highlight_min_jsw(diff_percentage, mean_left, mean_right, side_min, value_min):
    """Return whether the narrowest JSW link should be boxed on the plot.

    True when the left/right widths differ at all (``diff_percentage > 0``), or
    when the mean width on the side holding the minimum (``side_min`` 0 = left,
    otherwise right) is more than twice that minimum ``value_min``.
    """
    if diff_percentage > 0:
        return True
    side_mean = mean_left if side_min == 0 else mean_right
    if value_min == 0:
        # mean / 0 raises ZeroDivisionError for Python floats. Match NumPy's
        # float semantics instead: x / 0 -> inf (> 2) for x > 0, 0 / 0 -> nan.
        return side_mean > 0
    return side_mean / value_min > 2


def _outline(linewidth):
    """Black stroke path effect that keeps overlay text legible on any background."""
    return [path_effects.Stroke(linewidth=linewidth, foreground='black'), path_effects.Normal()]


def plot_anomaly_with_clues(processed_img, anomaly, diff_percentage, mean_left, mean_right,
                            side_min, index_min, value_min, left_pairs, right_pairs,
                            *, annotations, jsw_m, jsw_mm):
    """Draw the anomaly map over the processed X-ray with diagnostic clues.

    Args:
        processed_img: Image shown in grayscale as the background.
        anomaly: Anomaly map overlaid with the 'turbo' colormap at 30% opacity.
            Its shape also positions the text labels.
        diff_percentage, mean_left, mean_right, side_min, value_min: JSW
            statistics from ``calculate_jsw_info``; they decide whether the
            narrowest link is boxed.
        index_min: Index of the narrowest link within the selected side's pairs.
        left_pairs, right_pairs: Scaled (start, end) link endpoints per side.
        annotations: Mapping with "osteophyte" and "jsn" label texts.
        jsw_m, jsw_mm: Values printed as JSW_Mean and JSW_MM.

    Returns:
        matplotlib.figure.Figure: the annotated figure. The caller must close
        it (``plt.close``) once rendered, because pyplot keeps it registered.
    """
    fig, ax = plt.subplots()
    ax.imshow(processed_img, cmap='gray')
    ax.imshow(anomaly, cmap='turbo', alpha=0.3)
    height, width = anomaly.shape[0], anomaly.shape[1]

    if _should_highlight_min_jsw(diff_percentage, mean_left, mean_right, side_min, value_min):
        # Box the link with the smallest joint-space width, padded on all sides.
        pairs = left_pairs if side_min == 0 else right_pairs
        start_point, end_point = pairs[index_min]
        min_x = start_point[0] - _BOX_PAD_X
        min_y = min(start_point[1], end_point[1]) - _BOX_PAD_Y
        max_y = max(start_point[1], end_point[1]) + _BOX_PAD_Y
        ax.add_patch(plt.Rectangle((min_x, min_y), 2 * _BOX_PAD_X, max_y - min_y,
                                   linewidth=3, edgecolor='red', facecolor='none'))

    # Osteophyte and joint-space-narrowing labels from get_annotations().
    ax.text(width // 2, height // 3.2, f'{annotations["osteophyte"]}', color='yellow', fontsize=19,
            ha='center', va='center', path_effects=_outline(1.5))
    ax.text(width // 2, height // 1.3, f'{annotations["jsn"]}', color='yellow', fontsize=19,
            ha='center', va='center', path_effects=_outline(1.5))
    # JSW summary values near the bottom-left corner (imshow puts y=0 at the top).
    ax.text(width // 40, height // 1.05, f'$JSW_{{Mean}}$: {jsw_m:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))
    ax.text(width // 40, height // 1.1, f'$JSW_{{MM}}$: {jsw_mm:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))
    ax.axis('off')
    return fig


def extract_explain(exp, label):
    """Return LIME feature contributions for ``label`` as table rows.

    Each row is ``(feature name, feature value, discretized description, weight)``
    for every ``(feature index, weight)`` entry in ``exp.local_exp[label]``.
    When the explainer was built without discretization
    (``discretized_feature_names is None``), the plain feature names are used
    as the description instead of failing on ``None[...]``.
    """
    mapper = exp.domain_mapper
    descriptions = mapper.discretized_feature_names
    if descriptions is None:
        descriptions = mapper.feature_names
    return [
        (mapper.feature_names[idx], mapper.feature_values[idx], descriptions[idx], float(weight))
        for idx, weight in exp.local_exp[label]
    ]


if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer's read position.
    file_bytes = uploaded_file.getvalue()
    file_name = uploaded_file.name
    img = read_image(file_bytes)

    # Segmentation mask -> joint-space-width (JSW) measurements per side.
    mask = seg_model.segment(img)
    left_distances, right_distances, links = get_JSW(mask, dim=10, verbose=0)
    left_links, right_links = links
    jsw_m, jsw_mm = calculate_diff(left_distances, right_distances)
    diff_percentage, mean_left, mean_right, side_min, index_min, value_min = calculate_jsw_info(
        left_distances, right_distances)
    left_pairs = _scale_link_pairs(left_links)
    right_pairs = _scale_link_pairs(right_links)

    # Image classifier probabilities + JSW features -> prediction and LIME explanation.
    probabilities = classif_model.predict(img)
    annotations = get_annotations(probabilities[0])
    predicted, exp = ml_model.predict_explain(probabilities[0], [jsw_m, jsw_mm], filename=file_name)
    processed_img, anomaly = anomaly_extractor.extract(mask, img, verbose=0)

    col1, col2 = st.columns(2)
    with col1:
        _centered_heading("Uploaded Image")
        st.image(img, channels="BGR", use_column_width=True)
    with col2:
        fig = plot_anomaly_with_clues(
            processed_img,
            anomaly,
            diff_percentage,
            mean_left,
            mean_right,
            side_min,
            index_min,
            value_min,
            left_pairs,
            right_pairs,
            annotations=annotations,
            jsw_m=jsw_m,
            jsw_mm=jsw_mm,
        )
        _centered_heading("Annotated Anomaly Map")
        st.pyplot(fig, bbox_inches='tight', pad_inches=0)
        # st.pyplot has already rendered the figure; release it so pyplot does
        # not accumulate one open figure per rerun.
        plt.close(fig)

    # LIME explanation, rendered full-width below the two image columns.
    components.html(exp.as_html() + _LIME_HTML_STYLE, height=320, width=None, scrolling=True)

    explanation_df = pd.DataFrame(
        extract_explain(exp, predicted),
        columns=['Feature', 'Value', 'Explain', 'Weight'],
    )
    st.table(explanation_df)