import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import streamlit as st
import numpy as np
import os
import pickle
import zipfile
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# --- CONFIG ---
# Page setup must be the first Streamlit call in the script.
st.set_page_config(page_title="Leukemia Subtype Detector", layout="centered")
# Purple banner header rendered as raw HTML (hence unsafe_allow_html).
# The HTML lines stay flush-left: indenting them 4+ spaces would make
# Markdown treat the block as a code fence instead of rendering it.
st.markdown("""
<div style='background-color:#57068c;padding:20px;border-radius:10px'>
<h2 style='color:white;text-align:center'>𧬠Leukemia Subtype Detection</h2>
<p style='color:white;text-align:center'>
Uses an ensemble of 4 models: <b>DenseNet121</b>, <b>MobileNetV2</b>, <b>VGG16</b>, and <b>Custom CNN</b>.<br>
Predicts: <i>Benign</i>, <i>Pre</i>, <i>Pro</i>, <i>Early</i>.
</p>
</div>
""", unsafe_allow_html=True)
# --- CONSTANTS ---
# Input resolution expected by all four models.
IMG_HEIGHT, IMG_WIDTH = 224, 224
# Index order must match the label order used at training time —
# TODO(review): confirm this matches the training generator's class_indices.
CLASS_NAMES = ['Benign', 'Pre', 'Pro', 'Early']
# Directory holding the zipped models and pickled training histories.
SAVE_DIR = 'saved_leukemia_ensemble'
# Zip archive filename for each ensemble member (relative to SAVE_DIR).
MODEL_ZIPS = {
    "DenseNet121": "DenseNet121_model.zip",
    "MobileNetV2": "MobileNetV2_model.zip",
    "VGG16": "VGG16_model.zip",
    "CustomCNN": "CustomCNN_model.zip"
}
# Soft-voting weights for the weighted-average ensemble; they sum to 1.0.
ENSEMBLE_WEIGHTS = {
    "DenseNet121": 0.28,
    "MobileNetV2": 0.30,
    "VGG16": 0.22,
    "CustomCNN": 0.20
}
# Per-model pickle of training curves, e.g. saved_leukemia_ensemble/VGG16_history.pkl
HISTORY_PATHS = {
    name: os.path.join(SAVE_DIR, f"{name}_history.pkl") for name in MODEL_ZIPS
}
# --- UTIL FUNCTION ---
def extract_model_if_needed(zip_path, output_path):
    """Extract *zip_path* into output_path's directory unless already done.

    Parameters
    ----------
    zip_path : str
        Path to a .zip archive containing a saved model file.
    output_path : str
        Expected path of the extracted model; its existence is the
        "already extracted" marker, making repeated calls cheap no-ops.
    """
    if os.path.exists(output_path):
        return  # already extracted on a previous run
    if not os.path.exists(zip_path):
        # Missing archive: return quietly so the caller can detect the
        # absent model file and warn, instead of crashing the whole app
        # with an opaque FileNotFoundError from ZipFile.
        return
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.dirname(output_path))
# --- LOAD MODELS ---
@st.cache_resource
def load_all_models():
    """Load every available ensemble member from SAVE_DIR.

    Extracts each model zip on first use, then loads the .keras file.
    Members whose files are missing are skipped with a UI warning, so
    the app degrades gracefully instead of crashing at startup.

    Returns
    -------
    dict[str, keras.Model]
        Mapping of model name -> loaded model (may be empty).
    Cached per process by @st.cache_resource, so the expensive loads
    run once, not on every Streamlit rerun.
    """
    models = {}
    for name, zip_file in MODEL_ZIPS.items():
        zip_path = os.path.join(SAVE_DIR, zip_file)
        keras_path = zip_path.replace(".zip", ".keras")
        # Guard: the original called this unconditionally, so a missing
        # zip raised FileNotFoundError before the warning below could fire.
        if os.path.exists(zip_path):
            extract_model_if_needed(zip_path, keras_path)
        if os.path.exists(keras_path):
            # compile=False: inference only, no optimizer/loss restoration.
            models[name] = load_model(keras_path, compile=False)
        else:
            st.warning(f"β Model not found: {keras_path}")
    return models
models = load_all_models()  # loaded once per process (st.cache_resource)
# --- UPLOAD IMAGE ---
# Main interaction: upload a blood smear, run all models, show the
# per-model predictions and the weighted ensemble verdict.
uploaded_file = st.file_uploader("π Upload a blood smear image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
    if st.button("π Enter"):
        with st.spinner("β³ Please wait while results are being computed..."):
            try:
                # NOTE(review): uploaded_file was already read by st.image above;
                # PIL generally copes with the shared buffer, but confirm that
                # uploaded_file.seek(0) is not needed before load_img here.
                img = image.load_img(uploaded_file, target_size=(IMG_HEIGHT, IMG_WIDTH))
                # Scale pixels to [0, 1] and add a batch axis -> (1, H, W, C).
                img_array = image.img_to_array(img) / 255.0
                img_array = np.expand_dims(img_array, axis=0)
                st.markdown("### π§ͺ Model Predictions:")
                # Two columns so per-model results alternate left/right.
                col1, col2 = st.columns(2)
                individual_preds = {}
                for i, (name, model) in enumerate(models.items()):
                    pred = model.predict(img_array)  # shape (1, 4) probability row
                    individual_preds[name] = pred
                    # np.argmax flattens (1, 4) -> index 0..3 into CLASS_NAMES.
                    cls = CLASS_NAMES[np.argmax(pred)]
                    conf = pred[0][np.argmax(pred)]
                    with [col1, col2][i % 2]:
                        st.info(f"**{name}** β `{cls}` ({conf:.2%})")
                # Weighted soft-voting: sum of weight * probability row per model.
                ensemble_pred = sum(ENSEMBLE_WEIGHTS[name] * pred for name, pred in individual_preds.items())
                final_class = CLASS_NAMES[np.argmax(ensemble_pred)]
                final_conf = float(np.max(ensemble_pred))
                # Result card as raw HTML; content lines stay flush-left so
                # Markdown does not render them as an indented code block.
                st.markdown(f"""
<div style="background-color:#c6f6d5;padding:15px;border-radius:10px">
<h4 style="color:#2f855a">β
Ensemble Prediction: <b>{final_class}</b></h4>
<p style="font-size:16px;color:#22543d">Confidence: <b>{final_conf:.2%}</b></p>
</div>
""", unsafe_allow_html=True)
                # Bar chart of the 4 ensemble class probabilities.
                st.bar_chart({CLASS_NAMES[i]: float(ensemble_pred[0][i]) for i in range(4)})
            except Exception as e:
                # Broad catch keeps the UI alive and surfaces the message to the user.
                st.error(f"β οΈ Error: {e}")
else:
    st.info("π Please upload an image to begin.")
# --- OPTIONAL TRAINING VISUALIZATION ---
# Opt-in section: plot accuracy/loss curves from each model's pickled history.
st.markdown("---")
st.subheader("π Model Training History")
if st.checkbox("Show training curves"):
    for name, path in HISTORY_PATHS.items():
        if os.path.exists(path):
            # NOTE(review): pickle.load must only be used on our own
            # training artifacts — never point this at untrusted files.
            with open(path, "rb") as f:
                hist = pickle.load(f)
            # hist is expected to be a Keras History.history dict —
            # TODO confirm the training script pickles the dict, not the object.
            acc = hist['accuracy']
            val_acc = hist['val_accuracy']
            loss = hist['loss']
            val_loss = hist['val_loss']
            fig, ax = plt.subplots(1, 2, figsize=(12, 4))
            ax[0].plot(acc, label='Train Acc')
            ax[0].plot(val_acc, label='Val Acc')
            ax[0].set_title(f'{name} Accuracy')
            ax[0].legend()
            ax[1].plot(loss, label='Train Loss')
            ax[1].plot(val_loss, label='Val Loss')
            ax[1].set_title(f'{name} Loss')
            ax[1].legend()
            st.pyplot(fig)
            # Fix: the original never closed figures, so matplotlib kept
            # every one alive and memory grew on each Streamlit rerun.
            plt.close(fig)
        else:
            st.warning(f"No training history found for {name}")