Spaces:
Sleeping
Sleeping
Update cancer.py
Browse files
cancer.py
CHANGED
|
@@ -10,6 +10,7 @@ from sklearn.linear_model import LogisticRegression
|
|
| 10 |
from sklearn.neighbors import KNeighborsClassifier
|
| 11 |
from sklearn.ensemble import RandomForestClassifier
|
| 12 |
from xgboost import XGBClassifier
|
|
|
|
| 13 |
|
| 14 |
# Load dataset
|
| 15 |
def load_data():
|
|
@@ -34,21 +35,28 @@ def preprocess_data(df):
|
|
| 34 |
('imputer', SimpleImputer(strategy='most_frequent')),
|
| 35 |
('encoder', OneHotEncoder(sparse_output=False, handle_unknown='ignore'))
|
| 36 |
]), nominal)
|
| 37 |
-
], remainder='
|
| 38 |
|
| 39 |
x = df.drop('Cancer_Present', axis=1)
|
| 40 |
y = df['Cancer_Present']
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Train Model
|
| 44 |
def train_model(x_train, y_train, preprocess, model_name):
|
| 45 |
models = {
|
| 46 |
-
'Decision Tree': DecisionTreeClassifier(),
|
| 47 |
-
'Logistic Regression': LogisticRegression(),
|
| 48 |
-
'KNN': KNeighborsClassifier(),
|
| 49 |
-
'Random Forest': RandomForestClassifier(),
|
| 50 |
-
'XGBoost': XGBClassifier()
|
| 51 |
}
|
|
|
|
| 52 |
pipeline = Pipeline([
|
| 53 |
('preprocessor', preprocess),
|
| 54 |
('classifier', models[model_name])
|
|
@@ -62,9 +70,10 @@ st.set_page_config(page_title='Cancer Prediction App', layout='wide')
|
|
| 62 |
with st.sidebar:
|
| 63 |
st.markdown("### Select Machine Learning Model")
|
| 64 |
model_name = st.radio("Choose a Model", ['Decision Tree', 'Logistic Regression', 'KNN', 'Random Forest', 'XGBoost'])
|
|
|
|
| 65 |
if st.button("Train Model"):
|
| 66 |
df = load_data()
|
| 67 |
-
|
| 68 |
model = train_model(x_train, y_train, preprocess, model_name)
|
| 69 |
accuracy = model.score(x_test, y_test)
|
| 70 |
st.session_state['trained_model'] = model
|
|
@@ -85,8 +94,8 @@ with col1:
|
|
| 85 |
|
| 86 |
with col2:
|
| 87 |
smoking_history = st.selectbox("Smoking History", ['Never Smoker', 'Former Smoker', 'Current Smoker'])
|
| 88 |
-
alcohol_consumption = st.selectbox("Alcohol Consumption", ['Low','Moderate','High'])
|
| 89 |
-
exercise_frequency = st.selectbox("Exercise Frequency", ['Rarely', 'Occasionally', 'Regularly','Never'])
|
| 90 |
gender = st.selectbox("Gender", ['Male', 'Female'])
|
| 91 |
family_history = st.selectbox("Family History", ["No", "Yes"])
|
| 92 |
|
|
@@ -105,11 +114,8 @@ if st.button("Predict Cancer Presence"):
|
|
| 105 |
for col in ['Age', 'Tumor_Size']:
|
| 106 |
input_df[col] = pd.to_numeric(input_df[col], errors='coerce')
|
| 107 |
|
| 108 |
-
# Apply preprocessing
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# Make prediction
|
| 112 |
-
prediction = model.named_steps['classifier'].predict(input_transformed)
|
| 113 |
|
| 114 |
if prediction[0] == 1:
|
| 115 |
st.markdown("<h3 style='color: red;'>Cancer Prediction: Positive 🟥</h3>", unsafe_allow_html=True)
|
|
@@ -120,3 +126,4 @@ if st.button("Predict Cancer Presence"):
|
|
| 120 |
else:
|
| 121 |
st.error("Please train a model first!")
|
| 122 |
|
|
|
|
|
|
| 10 |
from sklearn.neighbors import KNeighborsClassifier
|
| 11 |
from sklearn.ensemble import RandomForestClassifier
|
| 12 |
from xgboost import XGBClassifier
|
| 13 |
+
from imblearn.over_sampling import SMOTE # For handling class imbalance
|
| 14 |
|
| 15 |
# Load dataset
|
| 16 |
def load_data():
|
|
|
|
| 35 |
('imputer', SimpleImputer(strategy='most_frequent')),
|
| 36 |
('encoder', OneHotEncoder(sparse_output=False, handle_unknown='ignore'))
|
| 37 |
]), nominal)
|
| 38 |
+
], remainder='drop') # Drop unlisted columns
|
| 39 |
|
| 40 |
x = df.drop('Cancer_Present', axis=1)
|
| 41 |
y = df['Cancer_Present']
|
| 42 |
+
|
| 43 |
+
# Handling class imbalance using SMOTE
|
| 44 |
+
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=23, stratify=y)
|
| 45 |
+
smote = SMOTE(random_state=23)
|
| 46 |
+
x_train, y_train = smote.fit_resample(x_train, y_train)
|
| 47 |
+
|
| 48 |
+
return x_train, x_test, y_train, y_test, preprocess
|
| 49 |
|
| 50 |
# Train Model
|
| 51 |
def train_model(x_train, y_train, preprocess, model_name):
|
| 52 |
models = {
|
| 53 |
+
'Decision Tree': DecisionTreeClassifier(max_depth=5),
|
| 54 |
+
'Logistic Regression': LogisticRegression(max_iter=1000),
|
| 55 |
+
'KNN': KNeighborsClassifier(n_neighbors=5),
|
| 56 |
+
'Random Forest': RandomForestClassifier(n_estimators=100, max_depth=5),
|
| 57 |
+
'XGBoost': XGBClassifier(use_label_encoder=False, eval_metric='logloss')
|
| 58 |
}
|
| 59 |
+
|
| 60 |
pipeline = Pipeline([
|
| 61 |
('preprocessor', preprocess),
|
| 62 |
('classifier', models[model_name])
|
|
|
|
| 70 |
with st.sidebar:
|
| 71 |
st.markdown("### Select Machine Learning Model")
|
| 72 |
model_name = st.radio("Choose a Model", ['Decision Tree', 'Logistic Regression', 'KNN', 'Random Forest', 'XGBoost'])
|
| 73 |
+
|
| 74 |
if st.button("Train Model"):
|
| 75 |
df = load_data()
|
| 76 |
+
x_train, x_test, y_train, y_test, preprocess = preprocess_data(df)
|
| 77 |
model = train_model(x_train, y_train, preprocess, model_name)
|
| 78 |
accuracy = model.score(x_test, y_test)
|
| 79 |
st.session_state['trained_model'] = model
|
|
|
|
| 94 |
|
| 95 |
with col2:
|
| 96 |
smoking_history = st.selectbox("Smoking History", ['Never Smoker', 'Former Smoker', 'Current Smoker'])
|
| 97 |
+
alcohol_consumption = st.selectbox("Alcohol Consumption", ['Low', 'Moderate', 'High'])
|
| 98 |
+
exercise_frequency = st.selectbox("Exercise Frequency", ['Rarely', 'Occasionally', 'Regularly', 'Never'])
|
| 99 |
gender = st.selectbox("Gender", ['Male', 'Female'])
|
| 100 |
family_history = st.selectbox("Family History", ["No", "Yes"])
|
| 101 |
|
|
|
|
| 114 |
for col in ['Age', 'Tumor_Size']:
|
| 115 |
input_df[col] = pd.to_numeric(input_df[col], errors='coerce')
|
| 116 |
|
| 117 |
+
# Apply preprocessing using the same pipeline
|
| 118 |
+
prediction = model.predict(input_df)
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
if prediction[0] == 1:
|
| 121 |
st.markdown("<h3 style='color: red;'>Cancer Prediction: Positive 🟥</h3>", unsafe_allow_html=True)
|
|
|
|
| 126 |
else:
|
| 127 |
st.error("Please train a model first!")
|
| 128 |
|
| 129 |
+
|