import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import streamlit as st
import numpy as np
import os
import pickle
import zipfile
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# --- CONFIG ---
# Page-level Streamlit settings; must run before any other st.* call.
st.set_page_config(page_title="Leukemia Subtype Detector", layout="centered")
# NOTE(review): the leading characters below look mojibake-encoded (likely an
# emoji that lost its UTF-8 encoding) — confirm and re-save the file as UTF-8.
st.markdown("""
๐งฌ Leukemia Subtype Detection
Uses an ensemble of 4 models: DenseNet121, MobileNetV2, VGG16, and Custom CNN.
Predicts: Benign, Pre, Pro, Early.
""", unsafe_allow_html=True)
# --- CONSTANTS ---
# Input resolution fed to every model; presumably the size all four were
# trained on — TODO confirm against the training pipeline.
IMG_HEIGHT, IMG_WIDTH = 224, 224
# Class labels indexed by model-output position; assumed to match the label
# order used at training time — verify against the training generator.
CLASS_NAMES = ['Benign', 'Pre', 'Pro', 'Early']
# Directory holding the zipped model artifacts and training histories.
SAVE_DIR = 'saved_leukemia_ensemble'
# Model name -> zip archive filename (relative to SAVE_DIR).
MODEL_ZIPS = {
"DenseNet121": "DenseNet121_model.zip",
"MobileNetV2": "MobileNetV2_model.zip",
"VGG16": "VGG16_model.zip",
"CustomCNN": "CustomCNN_model.zip"
}
# Soft-voting weights for the ensemble; they sum to 1.00, so the weighted
# sum of per-model softmax outputs remains a probability distribution.
ENSEMBLE_WEIGHTS = {
"DenseNet121": 0.28,
"MobileNetV2": 0.30,
"VGG16": 0.22,
"CustomCNN": 0.20
}
# Model name -> path of its pickled Keras training history.
HISTORY_PATHS = {
name: os.path.join(SAVE_DIR, f"{name}_history.pkl") for name in MODEL_ZIPS
}
# --- UTIL FUNCTION ---
def extract_model_if_needed(zip_path, output_path):
    """Extract *zip_path* into *output_path*'s directory, once.

    Parameters
    ----------
    zip_path : str
        Path to the .zip archive holding the saved model artifact.
    output_path : str
        Expected path of the extracted artifact; extraction is skipped
        when this file already exists.

    Notes
    -----
    Silently returns when the archive itself is missing instead of raising
    FileNotFoundError, so the caller's existing "model not found" warning
    path handles the absence gracefully.
    """
    if os.path.exists(output_path):
        return  # already extracted on a previous run
    if not os.path.exists(zip_path):
        return  # nothing to extract; caller will warn about the missing model
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.dirname(output_path))
# --- LOAD MODELS ---
@st.cache_resource
def load_all_models():
    """Extract (when needed) and load every ensemble member.

    Returns a dict mapping model name -> loaded Keras model. Members whose
    extracted .keras file cannot be found are skipped with a UI warning
    rather than aborting the whole app.
    """
    loaded = {}
    for model_name, archive_name in MODEL_ZIPS.items():
        archive_path = os.path.join(SAVE_DIR, archive_name)
        keras_path = archive_path.replace(".zip", ".keras")
        extract_model_if_needed(archive_path, keras_path)
        if not os.path.exists(keras_path):
            st.warning(f"โ Model not found: {keras_path}")
            continue
        # compile=False: inference only, no optimizer state required.
        loaded[model_name] = load_model(keras_path, compile=False)
    return loaded
models = load_all_models()
# --- UPLOAD IMAGE ---
uploaded_file = st.file_uploader("๐ Upload a blood smear image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
    if st.button("๐ Enter"):
        with st.spinner("โณ Please wait while results are being computed..."):
            try:
                # st.image above may have consumed the UploadedFile buffer;
                # rewind so Keras reads the image bytes from the start.
                uploaded_file.seek(0)
                img = image.load_img(uploaded_file, target_size=(IMG_HEIGHT, IMG_WIDTH))
                # Scale pixels to [0, 1] and add a batch axis -> (1, H, W, 3).
                img_array = image.img_to_array(img) / 255.0
                img_array = np.expand_dims(img_array, axis=0)
                st.markdown("### ๐งช Model Predictions:")
                col1, col2 = st.columns(2)
                individual_preds = {}
                for i, (name, model) in enumerate(models.items()):
                    pred = model.predict(img_array)
                    individual_preds[name] = pred
                    # pred has shape (1, num_classes); with batch size 1 the
                    # flat argmax equals the class index. Compute it once.
                    top = int(np.argmax(pred))
                    cls = CLASS_NAMES[top]
                    conf = pred[0][top]
                    # Alternate the per-model cards between the two columns.
                    with [col1, col2][i % 2]:
                        st.info(f"**{name}** โ `{cls}` ({conf:.2%})")
                # Weighted soft vote; ENSEMBLE_WEIGHTS sum to 1.0, so the
                # result is still a probability distribution over classes.
                ensemble_pred = sum(ENSEMBLE_WEIGHTS[name] * pred for name, pred in individual_preds.items())
                final_class = CLASS_NAMES[np.argmax(ensemble_pred)]
                final_conf = float(np.max(ensemble_pred))
                st.markdown(f"""
โ
Ensemble Prediction: {final_class}
Confidence: {final_conf:.2%}
""", unsafe_allow_html=True)
                # Derive the class count from CLASS_NAMES instead of a
                # hard-coded 4 so the chart stays correct if labels change.
                st.bar_chart({CLASS_NAMES[i]: float(ensemble_pred[0][i]) for i in range(len(CLASS_NAMES))})
            except Exception as e:
                st.error(f"โ ๏ธ Error: {e}")
else:
    st.info("๐ Please upload an image to begin.")
# --- OPTIONAL TRAINING VISUALIZATION ---
st.markdown("---")
st.subheader("๐ Model Training History")
if st.checkbox("Show training curves"):
    for name, path in HISTORY_PATHS.items():
        if not os.path.exists(path):
            st.warning(f"No training history found for {name}")
            continue
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only load history files produced by the trusted training pipeline.
        with open(path, "rb") as f:
            hist = pickle.load(f)
        # Keys follow the Keras History.history dict layout; presumably the
        # training script pickled exactly that dict — TODO confirm.
        acc = hist['accuracy']
        val_acc = hist['val_accuracy']
        loss = hist['loss']
        val_loss = hist['val_loss']
        # Side-by-side accuracy and loss curves for this model.
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        ax[0].plot(acc, label='Train Acc')
        ax[0].plot(val_acc, label='Val Acc')
        ax[0].set_title(f'{name} Accuracy')
        ax[0].legend()
        ax[1].plot(loss, label='Train Loss')
        ax[1].plot(val_loss, label='Val Loss')
        ax[1].set_title(f'{name} Loss')
        ax[1].legend()
        st.pyplot(fig)
        # Close the figure so repeated Streamlit reruns don't accumulate
        # open matplotlib figures and leak memory.
        plt.close(fig)