Abd-trauma-RU / app.py
aimedica's picture
Upload app.py with huggingface_hub
703296a verified
import streamlit as st
import numpy as np
import pydicom
import tensorflow as tf
from tensorflow.keras.models import load_model
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cv2
import PIL.Image
import pandas as pd
import io
import traceback
# ──────────────────────────────────────────────
# Названия меток травм
# ──────────────────────────────────────────────
injury_cols = [
'bowel_injury',
'extravasation_injury',
'kidney_healthy', 'kidney_low', 'kidney_high',
'liver_healthy', 'liver_low', 'liver_high',
'spleen_healthy', 'spleen_low', 'spleen_high',
]
LABEL_RU = {
'bowel_injury': 'Травма кишечника',
'extravasation_injury':'Экстравазация (кровотечение)',
'kidney_healthy': 'Почка — норма',
'kidney_low': 'Почка — низкий риск',
'kidney_high': 'Почка — высокий риск',
'liver_healthy': 'Печень — норма',
'liver_low': 'Печень — низкий риск',
'liver_high': 'Печень — высокий риск',
'spleen_healthy': 'Селезёнка — норма',
'spleen_low': 'Селезёнка — низкий риск',
'spleen_high': 'Селезёнка — высокий риск',
}
# ──────────────────────────────────────────────
# Предобработка изображения
# ──────────────────────────────────────────────
def load_and_preprocess_image(uploaded_file, target_size=(224, 224)):
file_extension = uploaded_file.name.split('.')[-1].lower()
uploaded_file.seek(0)
if file_extension == 'dcm':
raw_bytes = uploaded_file.read()
dicom = pydicom.dcmread(io.BytesIO(raw_bytes))
img = dicom.pixel_array.astype('float32')
max_val = img.max()
if max_val > 0:
img = img / max_val
else:
raw_bytes = uploaded_file.read()
pil_img = PIL.Image.open(io.BytesIO(raw_bytes))
if pil_img.mode in ('RGBA', 'LA'):
pil_img = pil_img.convert('RGB')
elif pil_img.mode == 'P':
pil_img = pil_img.convert('RGB')
img_arr = np.array(pil_img)
if len(img_arr.shape) == 3 and img_arr.shape[2] >= 3:
img = cv2.cvtColor(img_arr[:, :, :3], cv2.COLOR_RGB2GRAY)
else:
img = img_arr
img = img.astype('float32') / 255.0
img = cv2.resize(img, target_size)
if len(img.shape) == 2:
img = np.stack((img,) * 3, axis=-1)
elif img.shape[2] == 1:
img = np.repeat(img, 3, axis=-1)
img = np.expand_dims(img, axis=0)
return img
# ──────────────────────────────────────────────
# Основное приложение
# ──────────────────────────────────────────────
def main():
st.set_page_config(
page_title="Диагностика травм живота",
page_icon="🏥",
layout="wide",
)
st.title('🏥 Система диагностики травм живота')
st.write('Загрузите медицинское изображение (DICOM, JPG, PNG) для анализа травм')
# ── Загрузка модели ──────────────────────────
@st.cache_resource
def load_trained_model():
try:
return load_model('best_model.keras')
except Exception:
return None
model = load_trained_model()
if model is None:
# Демонстрационный режим с реалистичными предсказаниями
class DummyModel:
def predict(self, x):
# Реалистичный пример: здоровый пациент
vals = {
'bowel_injury': 0.023,
'extravasation_injury':0.031,
'kidney_healthy': 0.912,
'kidney_low': 0.054,
'kidney_high': 0.018,
'liver_healthy': 0.887,
'liver_low': 0.071,
'liver_high': 0.025,
'spleen_healthy': 0.873,
'spleen_low': 0.059,
'spleen_high': 0.041,
}
return np.array([[vals[c] for c in injury_cols]], dtype='float32')
model = DummyModel()
_model_missing = True
else:
_model_missing = False
# ── Загрузка файла ───────────────────────────
uploaded_file = st.file_uploader(
"Выберите файл изображения",
type=["dcm", "jpg", "jpeg", "png", "bmp", "tiff"],
help="Поддерживаются форматы: DICOM (.dcm), JPG, PNG, BMP, TIFF",
)
if uploaded_file is not None:
col1, col2 = st.columns(2)
with col1:
st.subheader("Загруженное изображение")
file_extension = uploaded_file.name.split('.')[-1].lower()
try:
uploaded_file.seek(0)
if file_extension == 'dcm':
raw = uploaded_file.read()
dicom_img = pydicom.dcmread(io.BytesIO(raw))
fig_img, ax = plt.subplots(figsize=(6, 6))
ax.imshow(dicom_img.pixel_array, cmap='gray')
ax.axis('off')
st.pyplot(fig_img)
plt.close(fig_img)
else:
uploaded_file.seek(0)
st.image(uploaded_file, width=700)
except Exception as e:
st.error(f"Ошибка отображения изображения: {e}")
with col2:
try:
processed_img = load_and_preprocess_image(uploaded_file)
with st.spinner("Выполняется анализ..."):
prediction = model.predict(processed_img)
if _model_missing:
st.info("ℹ️ Демонстрационный режим — модель `best_model.keras` не найдена.")
raw_probs = prediction[0].astype(float)
probs = {injury_cols[i]: float(raw_probs[i]) for i in range(len(injury_cols))}
# ── Таблица повреждений ──────────────
st.subheader('🔍 Вероятности повреждений')
injury_only_cols = [c for c in injury_cols if 'healthy' not in c]
injury_df = pd.DataFrame({
'Тип повреждения': [LABEL_RU.get(c, c) for c in injury_only_cols],
'Вероятность': [probs[c] for c in injury_only_cols],
'_raw_label': injury_only_cols,
})
injury_df = injury_df.sort_values('Вероятность', ascending=False).reset_index(drop=True)
def color_injury(row):
prob = float(row['Вероятность'])
if prob > 0.7:
style = 'background-color: #721c24; color: white; font-weight:bold'
elif prob > 0.5:
style = 'background-color: #fd7e14; color: black; font-weight:bold'
elif prob > 0.3:
style = 'background-color: #fff3cd; color: #856404'
else:
style = 'background-color: #f8f9fa; color: #6c757d'
return [style] * len(row)
styled_injury = (
injury_df
.style
.apply(color_injury, axis=1)
.format({'Вероятность': '{:.1%}'})
.hide(axis='columns', subset=['_raw_label'])
.hide(axis='index')
)
st.write(styled_injury.to_html(), unsafe_allow_html=True)
# ── Таблица состояния органов ────────
st.subheader('🫀 Состояние органов')
healthy_cols = [c for c in injury_cols if 'healthy' in c]
healthy_df = pd.DataFrame({
'Орган': [LABEL_RU.get(c, c) for c in healthy_cols],
'Норма (%)': [probs[c] for c in healthy_cols],
'_raw_label': healthy_cols,
})
healthy_df = healthy_df.sort_values('Норма (%)', ascending=False).reset_index(drop=True)
def color_healthy(row):
prob = float(row['Норма (%)'])
if prob > 0.5:
style = 'background-color: #d4edda; color: #155724'
else:
style = 'background-color: #f8d7da; color: #721c24; font-weight:bold'
return [style] * len(row)
styled_healthy = (
healthy_df
.style
.apply(color_healthy, axis=1)
.format({'Норма (%)': '{:.1%}'})
.hide(axis='columns', subset=['_raw_label'])
.hide(axis='index')
)
st.write(styled_healthy.to_html(), unsafe_allow_html=True)
st.caption("🟢 Зелёный — орган здоров (> 50 %) · 🔴 Красный — орган под угрозой (< 50 %)")
# ── График повреждений ────────────────
st.subheader('📊 График вероятностей повреждений')
def get_bar_color(label, prob):
if prob > 0.7: return '#dc3545'
elif prob > 0.5: return '#fd7e14'
elif prob > 0.3: return '#ffc107'
else: return '#adb5bd'
fig_bar, ax_bar = plt.subplots(figsize=(10, 5))
bar_labels = list(injury_df['Тип повреждения'])
bar_probs = list(injury_df['Вероятность'])
bar_raw = list(injury_df['_raw_label'])
bar_colors = [get_bar_color(lbl, p) for lbl, p in zip(bar_raw, bar_probs)]
bars = ax_bar.bar(bar_labels, bar_probs, color=bar_colors, edgecolor='white')
for bar, prob in zip(bars, bar_probs):
if prob > 0.01:
ax_bar.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height() + 0.01,
f'{prob:.1%}', ha='center', va='bottom', fontsize=8
)
ax_bar.set_title('Вероятности повреждений органов', fontsize=14)
ax_bar.set_ylabel('Вероятность')
ax_bar.set_ylim(0, 1.1)
ax_bar.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Порог 50%')
ax_bar.axhline(y=0.3, color='orange', linestyle='--', alpha=0.4, label='Порог 30%')
ax_bar.legend(fontsize=8)
plt.xticks(rotation=45, ha='right', fontsize=8)
plt.tight_layout()
st.pyplot(fig_bar)
plt.close(fig_bar)
# ── Сводка рисков ─────────────────────
st.subheader('🚨 Сводка рисков')
critical_labels = [
lbl for lbl in injury_only_cols
if (lbl.endswith('_high') or lbl in ('bowel_injury', 'extravasation_injury'))
and probs[lbl] > 0.5
]
moderate_labels = [
lbl for lbl in injury_only_cols
if lbl.endswith('_low') and probs[lbl] > 0.3
]
at_risk_organs = [
lbl for lbl in healthy_cols if probs[lbl] < 0.5
]
any_risk = False
if critical_labels:
any_risk = True
st.error("🔴 КРИТИЧЕСКИЙ РИСК — требуется срочное вмешательство:")
for lbl in sorted(critical_labels, key=lambda l: probs[l], reverse=True):
st.markdown(
f"<div style='background:#f8d7da;border-radius:6px;padding:10px 16px;"
f"margin:4px 0;border-left:4px solid #dc3545'>"
f"🔴 <b>{LABEL_RU[lbl]}</b> — вероятность: <b>{probs[lbl]:.1%}</b>"
f"</div>",
unsafe_allow_html=True,
)
if at_risk_organs:
any_risk = True
st.warning("⚠️ ОРГАНЫ ПОД УГРОЗОЙ (вероятность нормы < 50 %):")
for lbl in sorted(at_risk_organs, key=lambda l: probs[l]):
organ_name = LABEL_RU[lbl].replace(' — норма', '')
st.markdown(
f"<div style='background:#fff3cd;border-radius:6px;padding:10px 16px;"
f"margin:4px 0;border-left:4px solid #fd7e14'>"
f"⚠️ <b>{organ_name}</b> — вероятность нормы: <b>{probs[lbl]:.1%}</b>"
f"</div>",
unsafe_allow_html=True,
)
if moderate_labels:
any_risk = True
st.info("🟡 УМЕРЕННЫЙ РИСК — рекомендуется дообследование:")
for lbl in sorted(moderate_labels, key=lambda l: probs[l], reverse=True):
st.markdown(
f"<div style='background:#fff3cd;border-radius:6px;padding:10px 16px;"
f"margin:4px 0;border-left:4px solid #ffc107'>"
f"🟡 <b>{LABEL_RU[lbl]}</b> — вероятность: <b>{probs[lbl]:.1%}</b>"
f"</div>",
unsafe_allow_html=True,
)
if not any_risk:
st.success("✅ Патологий не выявлено. Все органы в норме.")
st.markdown("---")
st.caption(
"**Шкала:** 🔴 > 70 % — критический · "
"🟠 50–70 % — высокий · "
"🟡 30–50 % — умеренный · "
"⚪ < 30 % — низкий"
)
except Exception:
pass
# ── Боковая панель ───────────────────────────
st.sidebar.title("О модели")
st.sidebar.info(
"""
**Модель диагностики травм живота**
- Многоклассовая детекция повреждений
- Обучена на КТ-снимках живота
- Архитектура: ResNet50
- Определяемые органы: кишечник, почки, печень, селезёнка
- Выявляет экстравазацию (внутреннее кровотечение)
"""
)
st.sidebar.title("Легенда цветов")
st.sidebar.markdown(
"""
🔴 **Красный** — критический риск (> 70 %)
🟠 **Оранжевый** — высокий риск (50–70 %)
🟡 **Жёлтый** — умеренный риск (30–50 %)
🟢 **Зелёный** — норма (орган здоров)
"""
)
st.sidebar.title("Поддерживаемые форматы")
st.sidebar.markdown(
"""
- **DICOM** (.dcm) — медицинские снимки
- **JPEG** (.jpg, .jpeg)
- **PNG** (.png)
- **BMP** (.bmp)
- **TIFF** (.tiff)
"""
)
st.sidebar.warning(
"⚠️ **Внимание:** данный инструмент предназначен "
"исключительно для исследовательских целей и не является "
"медицинским диагностическим устройством."
)
if __name__ == '__main__':
main()