MohammedAH commited on
Commit
02cb7a2
Β·
verified Β·
1 Parent(s): 0eef1c8

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ survival_model.keras filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import joblib
5
+ import json
6
+ from PIL import Image
7
+ import pandas as pd
8
+ from lifelines import CoxPHFitter
9
+ import huggingface_hub
10
+
11
+ # ---------------------------------------------------
12
+ # CONFIG
13
+ # ---------------------------------------------------
14
+
15
+ st.set_page_config(
16
+ page_title="Breast Cancer Survival Prediction",
17
+ page_icon="🧬",
18
+ layout="wide"
19
+ )
20
+
21
+ # CNN_MODEL_PATH = "best_breast_cancer_cnn.keras"
22
+ CNN_MODEL_PATH = "hf://MohammedAH/BreastCancerPrediction"
23
+ DNN_MODEL_PATH = "survival_model.keras"
24
+
25
+ SCALER_PATH = "scaler.pkl"
26
+ FEATURES_PATH = "features.json"
27
+
28
+ DATASET_PATH = 'processed_breast_cancer_data(1).csv'
29
+ TIME_COL = "Overall_Survival_Months"
30
+ EVENT_COL = "Event"
31
+ ID_COL = "Patient_ID"
32
+
33
+ # ---------------------------------------------------
34
+ # LOAD MODELS
35
+ # ---------------------------------------------------
36
+
37
+ @st.cache_resource
38
+ def load_cnn():
39
+ return tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)
40
+
41
+ @st.cache_resource
42
+ def load_dnn():
43
+ return tf.keras.models.load_model(DNN_MODEL_PATH, compile=False)
44
+
45
+ # ---------------------------------------------------
46
+ # LOAD SURVIVAL ASSETS (COMPUTE BRESLOW BASELINE)
47
+ # ---------------------------------------------------
48
+
49
+ @st.cache_resource
50
+ def load_survival_assets():
51
+
52
+ scaler = joblib.load(SCALER_PATH)
53
+ features = json.load(open(FEATURES_PATH))
54
+
55
+ df = pd.read_csv(DATASET_PATH)
56
+
57
+ feature_df = df[features].copy()
58
+ feature_df["duration"] = df[TIME_COL]
59
+ feature_df["event"] = df[EVENT_COL]
60
+
61
+ cox = CoxPHFitter()
62
+ cox.fit(feature_df, duration_col="duration", event_col="event")
63
+
64
+ baseline = cox.baseline_cumulative_hazard_
65
+
66
+ breslow_times = baseline.index.values
67
+ breslow_H0 = baseline.values.flatten()
68
+
69
+ return scaler, features, breslow_times, breslow_H0
70
+
71
+
72
+ cnn_model = load_cnn()
73
+ dnn_model = load_dnn()
74
+ scaler, feature_cols, breslow_times, breslow_H0 = load_survival_assets()
75
+
76
+ # ---------------------------------------------------
77
+ # IMAGE PREPROCESSING
78
+ # ---------------------------------------------------
79
+
80
+ def preprocess_image(image):
81
+
82
+ if image.mode != "L":
83
+ image = image.convert("L")
84
+
85
+ image = image.resize((224, 224))
86
+
87
+ img = np.array(image) / 255.0
88
+ img = img[np.newaxis, ..., np.newaxis]
89
+
90
+ return img
91
+
92
+ # ---------------------------------------------------
93
+ # CNN PREDICTION
94
+ # ---------------------------------------------------
95
+
96
+ def predict_cancer(image):
97
+
98
+ img = preprocess_image(image)
99
+
100
+ pred = cnn_model.predict(img, verbose=0)[0][0]
101
+
102
+ result = "Malignant" if pred > 0.5 else "Benign"
103
+
104
+ confidence = pred if pred > 0.5 else 1 - pred
105
+
106
+ return result, confidence, pred
107
+
108
+ # ---------------------------------------------------
109
+ # SURVIVAL FUNCTION
110
+ # ---------------------------------------------------
111
+
112
+ def survival_prob(risk, t):
113
+
114
+ idx = np.searchsorted(breslow_times, t, side="right") - 1
115
+
116
+ if idx < 0:
117
+ return 1.0
118
+
119
+ h0 = breslow_H0[idx]
120
+
121
+ return float(np.exp(-h0 * np.exp(risk)))
122
+
123
+ # ---------------------------------------------------
124
+ # SURVIVAL PREDICTION
125
+ # ---------------------------------------------------
126
+
127
+ def predict_survival(feature_values):
128
+
129
+ row = np.array([feature_values], dtype=np.float32)
130
+
131
+ row = scaler.transform(row)
132
+
133
+ risk = float(dnn_model.predict(row, verbose=0)[0][0])
134
+
135
+ s1 = survival_prob(risk, 12) * 100
136
+ s3 = survival_prob(risk, 36) * 100
137
+ s5 = survival_prob(risk, 60) * 100
138
+
139
+ return risk, s1, s3, s5
140
+
141
+ # ---------------------------------------------------
142
+ # UI
143
+ # ---------------------------------------------------
144
+
145
+ st.title("🧬 Breast Cancer AI Diagnosis & Survival System")
146
+
147
+ st.markdown(
148
+ """
149
+ This system integrates two AI models:
150
+
151
+ β€’ **CNN model** β†’ detects tumor malignancy from medical images
152
+ β€’ **Survival DNN** β†’ predicts patient survival probabilities
153
+ """
154
+ )
155
+
156
+ tab1, tab2 = st.tabs(["πŸ”¬ Image Diagnosis", "πŸ“ˆ Survival Analysis"])
157
+
158
+ # ---------------------------------------------------
159
+ # TAB 1 : IMAGE PREDICTION
160
+ # ---------------------------------------------------
161
+
162
+ with tab1:
163
+
164
+ st.header("Tumor Image Classification")
165
+
166
+ uploaded = st.file_uploader(
167
+ "Upload Histopathology Image",
168
+ type=["png", "jpg", "jpeg"]
169
+ )
170
+
171
+ if uploaded:
172
+
173
+ image = Image.open(uploaded)
174
+
175
+ st.image(image, width=300)
176
+
177
+ if st.button("Analyze Image"):
178
+
179
+ result, conf, score = predict_cancer(image)
180
+
181
+ st.subheader("Prediction Result")
182
+
183
+ col1, col2 = st.columns(2)
184
+
185
+ col1.metric("Diagnosis", result)
186
+
187
+ col2.metric("Confidence", f"{conf*100:.2f}%")
188
+
189
+ st.write("Prediction Score:", round(score, 4))
190
+
191
+ # ---------------------------------------------------
192
+ # TAB 2 : SURVIVAL ANALYSIS
193
+ # ---------------------------------------------------
194
+
195
+ with tab2:
196
+
197
+ st.header("Patient Survival Prediction")
198
+
199
+ st.write("Enter patient clinical features")
200
+
201
+ inputs = []
202
+
203
+ cols = st.columns(3)
204
+
205
+ for i, f in enumerate(feature_cols):
206
+
207
+ value = cols[i % 3].number_input(
208
+ f,
209
+ value=0.0,
210
+ step=0.1
211
+ )
212
+
213
+ inputs.append(value)
214
+
215
+ if st.button("Predict Survival"):
216
+
217
+ risk, s1, s3, s5 = predict_survival(inputs)
218
+
219
+ st.subheader("Risk Score")
220
+
221
+ st.metric("Risk Score", round(risk, 4))
222
+
223
+ st.subheader("Survival Probability")
224
+
225
+ c1, c2, c3 = st.columns(3)
226
+
227
+ c1.metric("1-Year Survival", f"{s1:.1f}%")
228
+ c2.metric("3-Year Survival", f"{s3:.1f}%")
229
+ c3.metric("5-Year Survival", f"{s5:.1f}%")
230
+
231
+ if risk > 0:
232
+ st.error("High Risk Category")
233
+ else:
234
+ st.success("Low Risk Category")
235
+
236
+ # ---------------------------------------------------
237
+ # FOOTER
238
+ # ---------------------------------------------------
239
+
240
+ st.markdown("---")
241
+ st.caption("AI-assisted clinical decision support system")
features.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["Age at Diagnosis", "Lymph nodes examined positive", "Tumor Size", "Mutation Count", "Nottingham prognostic index", "Tumor Stage_encoded", "Neoplasm Histologic Grade_encoded", "Cellularity_encoded", "ER Status_encoded", "HER2 Status_encoded", "Hormone Therapy_encoded", "Chemotherapy_encoded", "Inferred Menopausal State_encoded", "Type of Breast Surgery_encoded", "PR Status_encoded", "Integrative Cluster_target_enc", "tumor_size_log", "lymph_node_ratio", "age_stage_interaction", "favorable_biomarker", "high_risk_molecular", "high_prolif", "treatment_intensity", "early_event"]
km_curve.png ADDED
processed_breast_cancer_data(1).csv ADDED
The diff for this file is too large to render. See raw diff
 
scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdb09afa6aa89eb5c5a431717373bcb56645c64254757cfad2828ac0fce96032
3
+ size 1191
survival_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da1f58915a573ecd0fd0126b501f39e760a628f2ecc9dd17c400dd24288b93d7
3
+ size 265034