|
|
import streamlit as st |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patheffects as path_effects |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import os |
|
|
import streamlit.components.v1 as components |
|
|
import pandas as pd |
|
|
|
|
|
from utils import read_image, combine_mask, combine_prob_jsw, scale_coordinates, get_annotations |
|
|
from segmentation.model import Segmenter |
|
|
from jsw import get_JSW, calculate_diff, calculate_jsw_info |
|
|
from classification.model import Classifier |
|
|
from ML_model.model import MLModel |
|
|
from gae import AnomalyExtractor |
|
|
|
|
|
import os |
|
|
# Quiet TensorFlow's native (C++) logging; '3' is the most restrictive level.
# NOTE(review): TensorFlow reads this variable when it is first imported. If one of the
# project model modules imported above already pulls TensorFlow in, this assignment
# happens too late to have any effect -- consider moving it above those imports.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Global page configuration, kept as the first Streamlit call in the script
# (set_page_config has traditionally been required to precede other st.* commands).
st.set_page_config(page_title="xDesCO", page_icon="🏥", layout="wide")
|
|
|
|
|
@st.cache_resource
def load_models():
    """Construct every model the app uses and return them together.

    Wrapped in ``st.cache_resource`` so the constructors run only once; later
    script reruns receive the same cached instances instead of rebuilding them.

    Returns:
        tuple: ``(segmenter, classifier, ml_model, anomaly_extractor)``, in that order.
    """
    # Tuple elements are evaluated left to right, preserving the construction order.
    return (
        Segmenter(),
        Classifier('convnet'),
        MLModel(),
        AnomalyExtractor(),
    )
|
|
|
|
|
|
|
|
# Models are cached by st.cache_resource inside load_models, so reruns reuse them.
seg_model, classif_model, ml_model, anomaly_extractor = load_models()

# --- Sidebar: title, image upload and short description ---------------------
st.sidebar.title("xDesCO: Explainable AI for Knee Osteoarthritis Diagnosis")

# The uploader label is rendered as Markdown, so the dataset link is clickable.
uploaded_file = st.sidebar.file_uploader(
    "Choose a knee joint X-ray image. Download the OAI dataset "
    "[here](https://www.kaggle.com/datasets/shashwatwork/knee-osteoarthritis-dataset-with-severity).",
    type=["jpg", "jpeg", "png"],
)

st.sidebar.markdown("---")
st.sidebar.caption(":bulb: Diagnose knee osteoarthritis with explainable AI insights, enabling users to upload X-ray images for knee osteoarthritis diagnosis with detailed visual explanations of abnormalities and disease severity.")

# Pin a "Powered by" footer to the bottom-left of the sidebar using a CSS
# ::after pseudo-element on the sidebar container.
st.markdown(
    """
<style>
[data-testid="stSidebar"]::after {
    content: "Powered by the xDesCO framework.";
    position: absolute;
    bottom: 20px;
    left: 20px;
    font-size: 15px;
    color: gray;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
|
|
|
def _scale_link_pairs(links, original_size=224, mask_size=640):
    """Rescale JSW link endpoints and regroup them into ``(start, end)`` pairs.

    ``links`` is one side's link list from ``get_JSW``; each link presumably holds two
    endpoints (they are re-paired two at a time below). All endpoints are flattened,
    passed through ``scale_coordinates`` in a single call, then zipped back into pairs.
    """
    flat_points = [point for link in links for point in link]
    scaled = scale_coordinates(flat_points, original_size=original_size, mask_size=mask_size)
    return list(zip(scaled[::2], scaled[1::2]))


def _outline(width):
    """Return a fresh black-stroke path-effect list so text stays legible over the heat map."""
    return [path_effects.Stroke(linewidth=width, foreground='black'), path_effects.Normal()]


def _jsw_narrowing_detected(diff_percentage, mean_left, mean_right, side_min, value_min, ratio=2.0):
    """Decide whether the narrowest joint-space link should be boxed.

    True when ``diff_percentage`` is positive, or when the mean of the side holding
    the minimum (``side_min == 0`` -> left, otherwise right) exceeds ``ratio`` times
    ``value_min``.
    """
    if diff_percentage > 0:
        return True
    side_mean = mean_left if side_min == 0 else mean_right
    if value_min <= 0:
        # Dividing by a zero minimum raises ZeroDivisionError for Python floats (and
        # yields inf/nan with a warning for NumPy scalars). Treat the ratio as unbounded:
        # narrowed whenever the side has any width at all.
        return side_mean > 0
    return side_mean / value_min > ratio


def plot_anomaly_with_clues(processed_img, anomaly, diff_percentage, mean_left, mean_right,
                            side_min, index_min, value_min, left_pairs, right_pairs,
                            color='r', thickness=3, *, annotations, jsw_m, jsw_mm):
    """Draw the anomaly map over the processed image and annotate it with diagnostic clues.

    Args:
        processed_img: image drawn first, with a grayscale colormap.
        anomaly: map overlaid with the ``turbo`` colormap at 30% opacity; its
            ``shape[0]``/``shape[1]`` are used to position every text label.
        diff_percentage, mean_left, mean_right, side_min, value_min: statistics from
            ``calculate_jsw_info`` that decide whether the narrowest link is boxed.
        index_min: index of the narrowest link within the side selected by ``side_min``.
        left_pairs, right_pairs: scaled ``(start, end)`` link endpoints for each side.
        color: edge color of the narrowing box.
        thickness: edge line width of the narrowing box.
        annotations: mapping with ``"osteophyte"`` and ``"jsn"`` entries, drawn as labels.
        jsw_m, jsw_mm: values printed in the lower-left corner as JSW_Mean / JSW_MM.

    Returns:
        matplotlib.figure.Figure: the new figure. The caller must close it
        (``plt.close(fig)``) once rendered, or pyplot keeps it alive.
    """
    fig, ax = plt.subplots()
    ax.imshow(processed_img, cmap='gray')
    ax.imshow(anomaly, cmap='turbo', alpha=0.3)

    if _jsw_narrowing_detected(diff_percentage, mean_left, mean_right, side_min, value_min):
        min_pairs = left_pairs if side_min == 0 else right_pairs
        # Skip the box rather than crash if there is no link at index_min
        # (e.g. no links were found on that side).
        if 0 <= index_min < len(min_pairs):
            start_point, end_point = min_pairs[index_min]
            padding_x, padding_y = 23, 13
            # Box is centred horizontally on the start point's x and spans the
            # link vertically, with fixed padding on every side.
            min_x = start_point[0] - padding_x
            max_x = start_point[0] + padding_x
            min_y = min(start_point[1], end_point[1]) - padding_y
            max_y = max(start_point[1], end_point[1]) + padding_y
            ax.add_patch(plt.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                                       linewidth=thickness, edgecolor=color, facecolor='none'))

    height, width = anomaly.shape[0], anomaly.shape[1]
    # Large finding labels, horizontally centred: osteophyte higher up, jsn lower down.
    for y_divisor, key in ((3.2, "osteophyte"), (1.3, "jsn")):
        ax.text(width // 2, height // y_divisor, f'{annotations[key]}', color='yellow', fontsize=19,
                ha='center', va='center', path_effects=_outline(1.5))
    # Small JSW read-outs in the lower-left corner.
    ax.text(width // 40, height // 1.05, f'$JSW_{{Mean}}$: {jsw_m:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))
    ax.text(width // 40, height // 1.1, f'$JSW_{{MM}}$: {jsw_mm:.2f}', color='white', fontsize=10,
            ha='left', va='center', path_effects=_outline(1))
    ax.axis('off')
    return fig


def extract_explain(exp, label):
    """Flatten a LIME explanation for ``label`` into ``(feature, value, rule, weight)`` rows.

    Iterates ``exp.local_exp[label]`` (``(feature_index, weight)`` items) and looks each
    index up in ``exp.domain_mapper``.
    """
    mapper = exp.domain_mapper
    # LIME leaves discretized_feature_names as None when features are not discretized;
    # fall back to the plain feature names instead of failing on None[...].
    rule_names = mapper.discretized_feature_names
    if rule_names is None:
        rule_names = mapper.feature_names
    return [
        (mapper.feature_names[idx], mapper.feature_values[idx], rule_names[idx], float(weight))
        for idx, weight in exp.local_exp[label]
    ]


if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer's read position,
    # unlike read(), which only returns what is left after any earlier read.
    file_bytes = uploaded_file.getvalue()
    file_name = uploaded_file.name
    img = read_image(file_bytes)

    mask = seg_model.segment(img)
    # NOTE(review): mask_image is never used below; kept in case combine_mask has
    # side effects -- confirm and drop if it does not.
    mask_image = combine_mask(img, mask)

    col1, col2 = st.columns(2)

    # Joint-space width (JSW) measurements derived from the segmentation mask.
    left_distances, right_distances, links = get_JSW(mask, dim=10, verbose=0)
    left_links, right_links = links
    jsw_m, jsw_mm = calculate_diff(left_distances, right_distances)
    diff_percentage, mean_left, mean_right, side_min, index_min, value_min = calculate_jsw_info(
        left_distances, right_distances)

    left_pairs = _scale_link_pairs(left_links)
    right_pairs = _scale_link_pairs(right_links)

    # Severity classification, text annotations and the explainable ML prediction.
    probabilities = classif_model.predict(img)
    annotations = get_annotations(probabilities[0])
    predicted, exp = ml_model.predict_explain(probabilities[0], [jsw_m, jsw_mm], filename=file_name)

    processed_img, anomaly = anomaly_extractor.extract(mask, img, verbose=0)

    with col1:
        caption = "Uploaded Image"
        st.markdown(f"<h3 style='text-align: center;'>{caption}</h3>", unsafe_allow_html=True)
        st.image(img, channels="BGR", use_column_width=True)

    with col2:
        fig = plot_anomaly_with_clues(
            processed_img,
            anomaly,
            diff_percentage,
            mean_left,
            mean_right,
            side_min,
            index_min,
            value_min,
            left_pairs,
            right_pairs,
            color='r',
            thickness=3,
            annotations=annotations,
            jsw_m=jsw_m,
            jsw_mm=jsw_mm,
        )
        caption = "Annotated Anomaly Map"
        st.markdown(f"<h3 style='text-align: center;'>{caption}</h3>", unsafe_allow_html=True)
        st.pyplot(fig, bbox_inches='tight', pad_inches=0)
        # Release the figure: pyplot keeps every figure from plt.subplots() alive
        # until closed, which would otherwise leak one figure per upload.
        plt.close(fig)

        # LIME's HTML report with CSS overrides: hide the element whose class is
        # exactly "lime explanation" and suppress horizontal scrolling of the table.
        full_html = exp.as_html() + """
<style>
div[class="lime explanation"] { display: none; } /* This targets the tree section */
.lime.table_div {overflow-x: hidden}
</style>
"""
        components.html(full_html, height=320, width=None, scrolling=True)

        explanation_df = pd.DataFrame(extract_explain(exp, predicted),
                                      columns=['Feature', 'Value', 'Explain', 'Weight'])
        st.table(explanation_df)
|
|
|
|
|
|