File size: 5,220 Bytes
a1ad22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import streamlit as st
import numpy as np
import os
import pickle
import zipfile
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

# --- CONFIG ---
# Keep this before any other st.* call (older Streamlit releases require it).
st.set_page_config(page_title="Leukemia Subtype Detector", layout="centered")

# Header banner. The HTML deliberately contains no blank lines: Markdown ends a
# raw HTML block at the first blank line, and a following line indented by 4+
# spaces is then rendered as a literal code block (tags and all) instead of HTML.
st.markdown("""
    <div style='background-color:#57068c;padding:20px;border-radius:10px'>
    <h2 style='color:white;text-align:center'>🧬 Leukemia Subtype Detection</h2>
    <p style='color:white;text-align:center'>
        Uses an ensemble of 4 models: <b>DenseNet121</b>, <b>MobileNetV2</b>, <b>VGG16</b>, and <b>Custom CNN</b>.<br>
        Predicts: <i>Benign</i>, <i>Pre</i>, <i>Pro</i>, <i>Early</i>.
    </p>
    </div>
""", unsafe_allow_html=True)

# --- CONSTANTS ---
# Input resolution every model is fed (pixels).
IMG_HEIGHT = 224
IMG_WIDTH = 224

# Label for each output index: index i of a model's prediction is shown as
# CLASS_NAMES[i].
# NOTE(review): this order must match the label encoding used at training time.
# Keras directory loaders sort class folders alphabetically, which would give
# ['Benign', 'Early', 'Pre', 'Pro'] — confirm against the training code.
CLASS_NAMES = ['Benign', 'Pre', 'Pro', 'Early']

# Folder holding the zipped models and their pickled training histories.
SAVE_DIR = 'saved_leukemia_ensemble'

# Archive file name (relative to SAVE_DIR) for each ensemble member.
MODEL_ZIPS = dict(
    DenseNet121="DenseNet121_model.zip",
    MobileNetV2="MobileNetV2_model.zip",
    VGG16="VGG16_model.zip",
    CustomCNN="CustomCNN_model.zip",
)

# Weight applied to each member's prediction in the ensemble sum
# (0.28 + 0.30 + 0.22 + 0.20 = 1.00).
ENSEMBLE_WEIGHTS = dict(
    DenseNet121=0.28,
    MobileNetV2=0.30,
    VGG16=0.22,
    CustomCNN=0.20,
)

# Pickled training-history file for each model, keyed like MODEL_ZIPS.
HISTORY_PATHS = {
    model_name: os.path.join(SAVE_DIR, model_name + "_history.pkl")
    for model_name in MODEL_ZIPS.keys()
}

# --- UTIL FUNCTION ---
def extract_model_if_needed(zip_path, output_path):
    """Unpack *zip_path* into the directory of *output_path* unless it already exists.

    The archive is assumed to contain a member that lands exactly at
    *output_path*; its presence is what marks extraction as done.

    Args:
        zip_path: Path of the ``.zip`` archive to unpack.
        output_path: File whose existence means no extraction is needed.

    Returns:
        True if *output_path* exists afterwards, False otherwise. A missing
        archive is reported this way instead of raising FileNotFoundError, so
        the caller can warn about the absent model and carry on.

    Raises:
        zipfile.BadZipFile: if *zip_path* exists but is not a valid archive.
    """
    if os.path.exists(output_path):
        return True
    if not os.path.isfile(zip_path):
        return False
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(os.path.dirname(output_path))
    return os.path.exists(output_path)

# --- LOAD MODELS ---
@st.cache_resource
def load_all_models():
    """Load every ensemble member listed in MODEL_ZIPS, unpacking archives first.

    Each archive in MODEL_ZIPS (relative to SAVE_DIR) is extracted, if needed,
    to the same path with a ``.keras`` extension, and that file is loaded with
    ``compile=False`` (this app only calls ``predict``). Models whose files
    cannot be found are skipped with a warning instead of aborting the app.
    ``st.cache_resource`` keeps the loaded models across reruns.

    Returns:
        dict mapping model name to loaded Keras model, in MODEL_ZIPS order;
        models that could not be found are absent from the dict.
    """
    models = {}
    for name, zip_file in MODEL_ZIPS.items():
        zip_path = os.path.join(SAVE_DIR, zip_file)
        # Swap only the extension; str.replace would also rewrite any ".zip"
        # appearing elsewhere in the path.
        keras_path = os.path.splitext(zip_path)[0] + ".keras"

        # Only unpack an archive that exists: a missing zip should end in the
        # warning below, not in a FileNotFoundError that takes down the app.
        if os.path.isfile(zip_path):
            extract_model_if_needed(zip_path, keras_path)

        if os.path.exists(keras_path):
            models[name] = load_model(keras_path, compile=False)
        else:
            st.warning(f"❌ Model not found: {keras_path}")
    return models

models = load_all_models()

# --- UPLOAD IMAGE ---
uploaded_file = st.file_uploader("📁 Upload a blood smear image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    if st.button("🔍 Enter"):
        if not models:
            # With no models the weighted sum below would be the int 0, whose
            # argmax is 0, so the app would announce "Benign" before failing.
            st.error("⚠️ Error: no models are loaded, so no prediction can be made.")
        else:
            with st.spinner("⏳ Please wait while results are being computed..."):
                try:
                    # Resize to the networks' input size, scale pixels to [0, 1]
                    # and add a leading batch axis of size 1.
                    img = image.load_img(uploaded_file, target_size=(IMG_HEIGHT, IMG_WIDTH))
                    img_array = image.img_to_array(img) / 255.0
                    img_array = np.expand_dims(img_array, axis=0)

                    st.markdown("### 🧪 Model Predictions:")
                    columns = st.columns(2)
                    individual_preds = {}

                    for i, (name, model) in enumerate(models.items()):
                        # verbose=0 keeps Keras' progress bar out of the server log.
                        pred = model.predict(img_array, verbose=0)
                        individual_preds[name] = pred
                        top = int(np.argmax(pred[0]))
                        # Alternate the per-model cards between the two columns.
                        with columns[i % 2]:
                            st.info(f"**{name}** ➜ `{CLASS_NAMES[top]}` ({pred[0][top]:.2%})")

                    # Weighted average over the models that actually loaded.
                    # Dividing by the weights in use (instead of assuming they
                    # total 1) keeps the confidence on the same scale as a single
                    # model's output when some model files were missing.
                    total_weight = sum(ENSEMBLE_WEIGHTS[name] for name in individual_preds)
                    ensemble_pred = sum(
                        ENSEMBLE_WEIGHTS[name] * pred for name, pred in individual_preds.items()
                    ) / total_weight
                    final_class = CLASS_NAMES[np.argmax(ensemble_pred)]
                    final_conf = float(np.max(ensemble_pred))

                    st.markdown(f"""
                        <div style="background-color:#c6f6d5;padding:15px;border-radius:10px">
                        <h4 style="color:#2f855a">✅ Ensemble Prediction: <b>{final_class}</b></h4>
                        <p style="font-size:16px;color:#22543d">Confidence: <b>{final_conf:.2%}</b></p>
                        </div>
                    """, unsafe_allow_html=True)

                    st.bar_chart({label: float(ensemble_pred[0][i]) for i, label in enumerate(CLASS_NAMES)})

                except Exception as e:
                    # UI boundary: show any failure (unreadable image, model
                    # error) in the page instead of crashing the script.
                    st.error(f"⚠️ Error: {e}")
else:
    st.info("👈 Please upload an image to begin.")

# --- OPTIONAL TRAINING VISUALIZATION ---
st.markdown("---")
st.subheader("📈 Model Training History")

if st.checkbox("Show training curves"):
    for name, path in HISTORY_PATHS.items():
        if not os.path.exists(path):
            st.warning(f"No training history found for {name}")
            continue

        # SECURITY NOTE: pickle.load can execute arbitrary code. This is only
        # acceptable because the history files ship with the app; never point
        # it at user-supplied files.
        with open(path, "rb") as f:
            hist = pickle.load(f)

        # The history is indexed like a dict of per-epoch metric sequences
        # (presumably Keras History.history — confirm).
        acc = hist['accuracy']
        val_acc = hist['val_accuracy']
        loss = hist['loss']
        val_loss = hist['val_loss']

        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
        ax_acc.plot(acc, label='Train Acc')
        ax_acc.plot(val_acc, label='Val Acc')
        ax_acc.set_title(f'{name} Accuracy')
        ax_acc.legend()

        ax_loss.plot(loss, label='Train Loss')
        ax_loss.plot(val_loss, label='Val Loss')
        ax_loss.set_title(f'{name} Loss')
        ax_loss.legend()

        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this,
        # each rerun with the box ticked leaks one figure per model.
        plt.close(fig)