MohammedAH commited on
Commit
10f9b1b
·
verified ·
1 Parent(s): 419b61e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -206
app.py CHANGED
@@ -1,218 +1,240 @@
1
- # app.py
2
  import streamlit as st
3
- import numpy as np
4
- import os
5
  import tensorflow as tf
6
- import logging
 
 
7
  from PIL import Image
 
 
8
 
9
- # Configure logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
 
13
- # Set page configuration
14
  st.set_page_config(
15
- page_title="Breast Cancer Prediction",
16
- page_icon="🩺",
17
- layout="wide",
18
- initial_sidebar_state="expanded"
19
  )
20
 
21
- # Disable GPU to save memory
22
- tf.config.set_visible_devices([], 'GPU')
23
- logger.info("TensorFlow configured for CPU-only")
24
-
25
- # ===== Model Loading =====
26
- MODEL_FILE = "final_combined_model.keras"
27
-
28
- @st.cache_resource(show_spinner=False)
29
- def load_model():
30
- """Load TensorFlow model from local file with caching"""
31
- try:
32
- # Verify file exists
33
- if not os.path.exists(MODEL_FILE):
34
- logger.error(f"❌ Model file not found: {MODEL_FILE}")
35
- return None
36
-
37
- logger.info(f"⏳ Loading model from local file: {MODEL_FILE}")
38
-
39
- # Load model with memory optimization
40
- model = tf.keras.models.load_model(MODEL_FILE, compile=False)
41
-
42
- # Test prediction to verify loading
43
- test_input = np.random.rand(1, 224, 224, 1).astype(np.float32)
44
- test_pred = model.predict(test_input, verbose=0)
45
- logger.info(f"🧪 Test prediction: {test_pred[0][0]:.4f}")
46
-
47
- logger.info("✅ Model loaded successfully")
48
- return model
49
- except Exception as e:
50
- logger.error(f"❌ Error loading model: {e}")
51
- # Print detailed traceback
52
- import traceback
53
- logger.error(traceback.format_exc())
54
- return None
55
-
56
- # Load model at startup
57
- model = load_model()
58
-
59
- # ===== Image Preprocessing =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def preprocess_image(image):
61
- """Preprocess image for model prediction"""
62
- try:
63
- # Convert to PIL Image
64
- if isinstance(image, np.ndarray):
65
- img = Image.fromarray(image.astype('uint8'))
66
- else:
67
- img = image
68
-
69
- # Processing pipeline
70
- img = img.convert('L') # Grayscale
71
- img = img.resize((224, 224)) # Resize
72
- img_array = np.array(img) / 255.0 # Normalize
73
-
74
- # Add batch and channel dimensions
75
- return img_array[np.newaxis, ..., np.newaxis]
76
- except Exception as e:
77
- logger.error(f"🖼️ Image preprocessing error: {e}")
78
- return None
79
-
80
- # ===== Prediction Function =====
81
- def predict(image):
82
- """Make prediction using the loaded model"""
83
- if model is None:
84
- return "Model failed to load", "Check logs", None
85
-
86
- try:
87
- # Preprocess image
88
- processed_image = preprocess_image(image)
89
- if processed_image is None:
90
- return "Invalid image", "Try another", image
91
-
92
- # Make prediction
93
- prediction = model.predict(processed_image, verbose=0)[0][0]
94
-
95
- # Format results
96
- confidence = abs(prediction - 0.5) + 0.5 # Convert to 0.5-1.0 scale
97
- result = "Malignant" if prediction > 0.5 else "Benign"
98
-
99
- return result, f"{confidence*100:.2f}%", image
100
- except Exception as e:
101
- error_msg = f"Prediction error: {str(e)}"
102
- logger.error(error_msg)
103
- return error_msg, "Try again", image
104
-
105
- # ===== Streamlit UI =====
106
-
107
- # Custom CSS for styling
108
- st.markdown("""
109
- <style>
110
- .stApp {
111
- background-color: #f0f2f6;
112
- }
113
- .header {
114
- color: #2c3e50;
115
- text-align: center;
116
- padding: 1rem;
117
- }
118
- .result-box {
119
- border-radius: 10px;
120
- padding: 1.5rem;
121
- margin: 1rem 0;
122
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
123
- }
124
- .malignant {
125
- background-color: #ffcccc;
126
- border-left: 5px solid #e74c3c;
127
- }
128
- .benign {
129
- background-color: #ccffcc;
130
- border-left: 5px solid #2ecc71;
131
- }
132
- .stButton>button {
133
- background-color: #3498db;
134
- color: white;
135
- border-radius: 5px;
136
- padding: 0.5rem 1rem;
137
- width: 100%;
138
- }
139
- .stButton>button:hover {
140
- background-color: #2980b9;
141
- }
142
- </style>
143
- """, unsafe_allow_html=True)
144
-
145
- # Header
146
- st.markdown("<h1 class='header'>🩺 Breast Cancer Prediction</h1>", unsafe_allow_html=True)
147
- st.markdown("Upload a breast medical image for cancer prediction")
148
-
149
- # Status indicator
150
- status = "✅ Model loaded successfully" if model else "❌ Model failed to load"
151
- st.info(status)
152
-
153
- # Create two columns for layout
154
- col1, col2 = st.columns([1, 1])
155
-
156
- # Input column
157
- with col1:
158
- st.subheader("Patient Information")
159
-
160
- # Input fields
161
- age = st.number_input("Patient Age", min_value=18, max_value=100, value=45)
162
- tumor_size = st.number_input("Tumor Size (mm)", min_value=0.1, value=15.0)
163
-
164
- # Image upload
165
- uploaded_file = st.file_uploader(
166
- "Upload Medical Image",
167
- type=["jpg", "jpeg", "png"],
168
- help="Supported formats: JPG, JPEG, PNG"
169
  )
170
-
171
- # Predict button
172
- predict_btn = st.button("Analyze Image")
173
-
174
- # Results column
175
- with col2:
176
- st.subheader("Prediction Results")
177
-
178
- # Initialize session state for results
179
- if 'result' not in st.session_state:
180
- st.session_state.result = None
181
- st.session_state.confidence = None
182
- st.session_state.image = None
183
-
184
- # Process image when button is clicked
185
- if predict_btn and uploaded_file is not None:
186
- try:
187
- image = Image.open(uploaded_file)
188
- st.session_state.result, st.session_state.confidence, st.session_state.image = predict(image)
189
- except Exception as e:
190
- st.error(f"Error processing image: {str(e)}")
191
-
192
- # Display results if available
193
- if st.session_state.result:
194
- # Result box with color coding
195
- result_class = "malignant" if st.session_state.result == "Malignant" else "benign"
196
- st.markdown(
197
- f"<div class='result-box {result_class}'>"
198
- f"<h3>Diagnosis: {st.session_state.result}</h3>"
199
- f"<p>Confidence: {st.session_state.confidence}</p>"
200
- "</div>",
201
- unsafe_allow_html=True
 
 
 
 
 
 
 
 
 
202
  )
203
-
204
- # Display image
205
- if st.session_state.image:
206
- st.image(
207
- st.session_state.image,
208
- caption="Uploaded Image",
209
- use_container_width=True
210
- )
211
-
212
- # Show placeholder if no results
213
- elif not predict_btn:
214
- st.info("Upload an image and click 'Analyze Image' to get prediction")
215
-
216
- # Footer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  st.markdown("---")
218
- st.caption("This tool is for research purposes only. Consult a medical professional for clinical diagnosis.")
 
 
1
  import streamlit as st
 
 
2
  import tensorflow as tf
3
+ import numpy as np
4
+ import joblib
5
+ import json
6
  from PIL import Image
7
+ import pandas as pd
8
+ from lifelines import CoxPHFitter
9
 
10
+ # ---------------------------------------------------
11
+ # CONFIG
12
+ # ---------------------------------------------------
13
 
 
14
  st.set_page_config(
15
+ page_title="Breast Cancer Survival Prediction",
16
+ page_icon="🧬",
17
+ layout="wide"
 
18
  )
19
 
20
+ # CNN_MODEL_PATH = "best_breast_cancer_cnn.keras"
21
+ CNN_MODEL_PATH = "final_combined_model.keras"
22
+ DNN_MODEL_PATH = "survival_model.keras"
23
+
24
+ SCALER_PATH = "scaler.pkl"
25
+ FEATURES_PATH = "features.json"
26
+
27
+ DATASET_PATH = 'processed_breast_cancer_data(1).csv'
28
+ TIME_COL = "Overall_Survival_Months"
29
+ EVENT_COL = "Event"
30
+ ID_COL = "Patient_ID"
31
+
32
+ # ---------------------------------------------------
33
+ # LOAD MODELS
34
+ # ---------------------------------------------------
35
+
36
+ @st.cache_resource
37
+ def load_cnn():
38
+ return tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)
39
+
40
+ @st.cache_resource
41
+ def load_dnn():
42
+ return tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)
43
+
44
+ # ---------------------------------------------------
45
+ # LOAD SURVIVAL ASSETS (COMPUTE BRESLOW BASELINE)
46
+ # ---------------------------------------------------
47
+
48
+ @st.cache_resource
49
+ def load_survival_assets():
50
+
51
+ scaler = joblib.load(SCALER_PATH)
52
+ features = json.load(open(FEATURES_PATH))
53
+
54
+ df = pd.read_csv(DATASET_PATH)
55
+
56
+ feature_df = df[features].copy()
57
+ feature_df["duration"] = df[TIME_COL]
58
+ feature_df["event"] = df[EVENT_COL]
59
+
60
+ cox = CoxPHFitter()
61
+ cox.fit(feature_df, duration_col="duration", event_col="event")
62
+
63
+ baseline = cox.baseline_cumulative_hazard_
64
+
65
+ breslow_times = baseline.index.values
66
+ breslow_H0 = baseline.values.flatten()
67
+
68
+ return scaler, features, breslow_times, breslow_H0
69
+
70
+
71
+ cnn_model = load_cnn()
72
+ dnn_model = load_dnn()
73
+ scaler, feature_cols, breslow_times, breslow_H0 = load_survival_assets()
74
+
75
+ # ---------------------------------------------------
76
+ # IMAGE PREPROCESSING
77
+ # ---------------------------------------------------
78
+
79
  def preprocess_image(image):
80
+
81
+ if image.mode != "L":
82
+ image = image.convert("L")
83
+
84
+ image = image.resize((224, 224))
85
+
86
+ img = np.array(image) / 255.0
87
+ img = img[np.newaxis, ..., np.newaxis]
88
+
89
+ return img
90
+
91
+ # ---------------------------------------------------
92
+ # CNN PREDICTION
93
+ # ---------------------------------------------------
94
+
95
+ def predict_cancer(image):
96
+
97
+ img = preprocess_image(image)
98
+
99
+ pred = cnn_model.predict(img, verbose=0)[0][0]
100
+
101
+ result = "Malignant" if pred > 0.5 else "Benign"
102
+
103
+ confidence = pred if pred > 0.5 else 1 - pred
104
+
105
+ return result, confidence, pred
106
+
107
+ # ---------------------------------------------------
108
+ # SURVIVAL FUNCTION
109
+ # ---------------------------------------------------
110
+
111
+ def survival_prob(risk, t):
112
+
113
+ idx = np.searchsorted(breslow_times, t, side="right") - 1
114
+
115
+ if idx < 0:
116
+ return 1.0
117
+
118
+ h0 = breslow_H0[idx]
119
+
120
+ return float(np.exp(-h0 * np.exp(risk)))
121
+
122
+ # ---------------------------------------------------
123
+ # SURVIVAL PREDICTION
124
+ # ---------------------------------------------------
125
+
126
+ def predict_survival(feature_values):
127
+
128
+ row = np.array([feature_values], dtype=np.float32)
129
+
130
+ row = scaler.transform(row)
131
+
132
+ risk = float(dnn_model.predict(row, verbose=0)[0][0])
133
+
134
+ s1 = survival_prob(risk, 12) * 100
135
+ s3 = survival_prob(risk, 36) * 100
136
+ s5 = survival_prob(risk, 60) * 100
137
+
138
+ return risk, s1, s3, s5
139
+
140
+ # ---------------------------------------------------
141
+ # UI
142
+ # ---------------------------------------------------
143
+
144
+ st.title("🧬 Breast Cancer AI Diagnosis & Survival System")
145
+
146
+ st.markdown(
147
+ """
148
+ This system integrates two AI models:
149
+
150
+ • **CNN model** → detects tumor malignancy from medical images
151
+ **Survival DNN** → predicts patient survival probabilities
152
+ """
153
+ )
154
+
155
+ tab1, tab2 = st.tabs(["🔬 Image Diagnosis", "📈 Survival Analysis"])
156
+
157
+ # ---------------------------------------------------
158
+ # TAB 1 : IMAGE PREDICTION
159
+ # ---------------------------------------------------
160
+
161
+ with tab1:
162
+
163
+ st.header("Tumor Image Classification")
164
+
165
+ uploaded = st.file_uploader(
166
+ "Upload Histopathology Image",
167
+ type=["png", "jpg", "jpeg"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  )
169
+
170
+ if uploaded:
171
+
172
+ image = Image.open(uploaded)
173
+
174
+ st.image(image, width=300)
175
+
176
+ if st.button("Analyze Image"):
177
+
178
+ result, conf, score = predict_cancer(image)
179
+
180
+ st.subheader("Prediction Result")
181
+
182
+ col1, col2 = st.columns(2)
183
+
184
+ col1.metric("Diagnosis", result)
185
+
186
+ col2.metric("Confidence", f"{conf*100:.2f}%")
187
+
188
+ st.write("Prediction Score:", round(score, 4))
189
+
190
+ # ---------------------------------------------------
191
+ # TAB 2 : SURVIVAL ANALYSIS
192
+ # ---------------------------------------------------
193
+
194
+ with tab2:
195
+
196
+ st.header("Patient Survival Prediction")
197
+
198
+ st.write("Enter patient clinical features")
199
+
200
+ inputs = []
201
+
202
+ cols = st.columns(3)
203
+
204
+ for i, f in enumerate(feature_cols):
205
+
206
+ value = cols[i % 3].number_input(
207
+ f,
208
+ value=0.0,
209
+ step=0.1
210
  )
211
+
212
+ inputs.append(value)
213
+
214
+ if st.button("Predict Survival"):
215
+
216
+ risk, s1, s3, s5 = predict_survival(inputs)
217
+
218
+ st.subheader("Risk Score")
219
+
220
+ st.metric("Risk Score", round(risk, 4))
221
+
222
+ st.subheader("Survival Probability")
223
+
224
+ c1, c2, c3 = st.columns(3)
225
+
226
+ c1.metric("1-Year Survival", f"{s1:.1f}%")
227
+ c2.metric("3-Year Survival", f"{s3:.1f}%")
228
+ c3.metric("5-Year Survival", f"{s5:.1f}%")
229
+
230
+ if risk > 0:
231
+ st.error("High Risk Category")
232
+ else:
233
+ st.success("Low Risk Category")
234
+
235
+ # ---------------------------------------------------
236
+ # FOOTER
237
+ # ---------------------------------------------------
238
+
239
  st.markdown("---")
240
+ st.caption("AI-assisted clinical decision support system")