# Author: GarimaSharma75
# Upload 11 files (a1ad22e, verified)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import streamlit as st
import numpy as np
import os
import pickle
import zipfile
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# --- CONFIG ---
# Browser-tab title and a centered page layout.
st.set_page_config(page_title="Leukemia Subtype Detector", layout="centered")

# Purple title banner describing the ensemble and the four predicted classes.
_HEADER_HTML = """
<div style='background-color:#57068c;padding:20px;border-radius:10px'>
<h2 style='color:white;text-align:center'>🧬 Leukemia Subtype Detection</h2>
<p style='color:white;text-align:center'>
Uses an ensemble of 4 models: <b>DenseNet121</b>, <b>MobileNetV2</b>, <b>VGG16</b>, and <b>Custom CNN</b>.<br>
Predicts: <i>Benign</i>, <i>Pre</i>, <i>Pro</i>, <i>Early</i>.
</p>
</div>
"""
# The banner is raw HTML, so Streamlit must be told not to escape it.
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# --- CONSTANTS ---
# Spatial size (height, width) every uploaded image is resized to.
IMG_HEIGHT, IMG_WIDTH = 224, 224
# Maps a model output index to its label.
# NOTE(review): this order must match the label indices used at training time
# (e.g. the directory order of the training set) — confirm against training code.
CLASS_NAMES = ['Benign', 'Pre', 'Pro', 'Early']
# Directory holding the zipped models and their pickled training histories.
SAVE_DIR = 'saved_leukemia_ensemble'

# (model name, ensemble weight) pairs; their order is the iteration order of
# every per-model dict below.
_MODEL_SPECS = (
    ("DenseNet121", 0.28),
    ("MobileNetV2", 0.30),
    ("VGG16", 0.22),
    ("CustomCNN", 0.20),
)

# Archive file name (relative to SAVE_DIR) for each model.
MODEL_ZIPS = {model_name: f"{model_name}_model.zip" for model_name, _ in _MODEL_SPECS}

# Weight applied to each model's per-class scores in the ensemble sum.
ENSEMBLE_WEIGHTS = dict(_MODEL_SPECS)

# Pickled training-history file for each model.
HISTORY_PATHS = {
    model_name: os.path.join(SAVE_DIR, f"{model_name}_history.pkl")
    for model_name in MODEL_ZIPS
}
# --- UTIL FUNCTION ---
def extract_model_if_needed(zip_path, output_path):
    """Unpack *zip_path* beside *output_path* unless *output_path* already exists.

    The whole archive is extracted into the directory that contains
    *output_path*. The archive is assumed (not checked here) to contain a
    member that lands exactly at *output_path*.

    If *zip_path* does not exist, this is a no-op. The caller can then see that
    *output_path* is still missing and report it, instead of the app crashing
    with FileNotFoundError.

    Args:
        zip_path: Path to the zip archive holding the saved model.
        output_path: Path the extracted model file is expected to occupy.

    Raises:
        zipfile.BadZipFile: If *zip_path* exists but is not a valid zip file.
    """
    if os.path.exists(output_path) or not os.path.exists(zip_path):
        return
    # dirname() returns "" for a bare file name; be explicit about using the CWD.
    target_dir = os.path.dirname(output_path) or "."
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(target_dir)
# --- LOAD MODELS ---
@st.cache_resource
def load_all_models():
    """Extract (if needed) and load every model listed in MODEL_ZIPS.

    Each model is expected at ``SAVE_DIR/<zip name with .keras extension>``.
    If that file is absent, it is first unpacked from the matching zip.
    Models that cannot be extracted or found are skipped with a warning, so
    the returned dict may hold fewer entries than MODEL_ZIPS.

    The result is cached with ``st.cache_resource``, so loading happens once
    rather than on every script rerun.

    Returns:
        dict mapping model name -> loaded Keras model (loaded with compile=False).
    """
    models = {}
    for name, zip_file in MODEL_ZIPS.items():
        zip_path = os.path.join(SAVE_DIR, zip_file)
        # splitext swaps only the final extension; str.replace would also
        # rewrite any ".zip" appearing earlier in the path.
        keras_path = os.path.splitext(zip_path)[0] + ".keras"
        try:
            extract_model_if_needed(zip_path, keras_path)
        except (OSError, zipfile.BadZipFile) as err:
            # A missing/corrupt archive should disable only this model, not the app.
            st.warning(f"❌ Could not extract {zip_path}: {err}")
            continue
        if os.path.exists(keras_path):
            models[name] = load_model(keras_path, compile=False)
        else:
            st.warning(f"❌ Model not found: {keras_path}")
    return models


models = load_all_models()
# --- UPLOAD IMAGE ---
# Upload an image, run every loaded model on it, and show the per-model and
# weighted-ensemble predictions.
uploaded_file = st.file_uploader("📁 Upload a blood smear image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
    if st.button("🔍 Enter"):
        if not models:
            # With no models the weighted sum below degenerates to the int 0, and
            # argmax(0) would report "Benign" — never display a fabricated result.
            st.error("⚠️ No models are loaded, so no prediction can be made.")
        else:
            with st.spinner("⏳ Please wait while results are being computed..."):
                try:
                    img = image.load_img(uploaded_file, target_size=(IMG_HEIGHT, IMG_WIDTH))
                    # Rescale pixels to [0, 1]; assumes the models were trained with
                    # the same rescaling — TODO confirm against the training code.
                    img_array = image.img_to_array(img) / 255.0
                    # Prepend a batch axis of size 1 for model.predict.
                    img_array = np.expand_dims(img_array, axis=0)

                    st.markdown("### 🧪 Model Predictions:")
                    col1, col2 = st.columns(2)
                    individual_preds = {}
                    for i, (name, model) in enumerate(models.items()):
                        # verbose=0 keeps Keras' progress bar out of the server log.
                        pred = model.predict(img_array, verbose=0)
                        individual_preds[name] = pred
                        top = int(np.argmax(pred))
                        with (col1, col2)[i % 2]:  # alternate between the two columns
                            st.info(f"**{name}** ➜ `{CLASS_NAMES[top]}` ({pred[0][top]:.2%})")

                    # Divide by the weights actually used, so the ensemble remains a
                    # weighted *average* even when some models failed to load.
                    total_weight = sum(ENSEMBLE_WEIGHTS[name] for name in individual_preds)
                    ensemble_pred = sum(
                        ENSEMBLE_WEIGHTS[name] * pred for name, pred in individual_preds.items()
                    ) / total_weight
                    final_class = CLASS_NAMES[np.argmax(ensemble_pred)]
                    final_conf = float(np.max(ensemble_pred))
                    # HTML lines stay unindented: leading spaces would turn them into
                    # a Markdown code block.
                    st.markdown(f"""
<div style="background-color:#c6f6d5;padding:15px;border-radius:10px">
<h4 style="color:#2f855a">✅ Ensemble Prediction: <b>{final_class}</b></h4>
<p style="font-size:16px;color:#22543d">Confidence: <b>{final_conf:.2%}</b></p>
</div>
""", unsafe_allow_html=True)
                    st.bar_chart(
                        {label: float(ensemble_pred[0][idx]) for idx, label in enumerate(CLASS_NAMES)}
                    )
                except Exception as e:
                    # UI boundary: surface any preprocessing/inference failure to the user.
                    st.error(f"⚠️ Error: {e}")
else:
    st.info("👈 Please upload an image to begin.")
# --- OPTIONAL TRAINING VISUALIZATION ---
# On request, plot train/validation accuracy and loss curves for each model
# from its pickled history file.
st.markdown("---")
st.subheader("📈 Model Training History")
if st.checkbox("Show training curves"):
    for name, path in HISTORY_PATHS.items():
        if not os.path.exists(path):
            st.warning(f"No training history found for {name}")
            continue
        # NOTE: pickle.load can execute arbitrary code — only load history files
        # shipped with the app, never user-supplied ones.
        with open(path, "rb") as f:
            hist = pickle.load(f)
        try:
            # Assumes a dict-like with these keys (e.g. a saved History.history)
            # — TODO confirm against the training code.
            acc, val_acc = hist['accuracy'], hist['val_accuracy']
            loss, val_loss = hist['loss'], hist['val_loss']
        except (KeyError, TypeError) as err:
            st.warning(f"Training history for {name} is missing expected data: {err}")
            continue

        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
        ax_acc.plot(acc, label='Train Acc')
        ax_acc.plot(val_acc, label='Val Acc')
        ax_acc.set_title(f'{name} Accuracy')
        ax_acc.legend()
        ax_loss.plot(loss, label='Train Loss')
        ax_loss.plot(val_loss, label='Val Loss')
        ax_loss.set_title(f'{name} Loss')
        ax_loss.legend()
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this, each
        # rerun of the script leaks one figure per model.
        plt.close(fig)