varshitha22 commited on
Commit
5c15b59
·
verified ·
1 Parent(s): 7dff983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -7,9 +7,10 @@ from sklearn.impute import SimpleImputer
7
  from sklearn.compose import ColumnTransformer
8
  from sklearn.model_selection import train_test_split
9
  from sklearn.tree import DecisionTreeClassifier
10
- from sklearn.svm import SVC
11
  from sklearn.linear_model import LogisticRegression
12
  from sklearn.neighbors import KNeighborsClassifier
 
 
13
 
14
  # Load dataset
15
  def load_data():
@@ -44,9 +45,10 @@ def preprocess_data(df):
44
  def train_model(X_train, y_train, preprocess, model_name):
45
  models = {
46
  'Decision Tree': DecisionTreeClassifier(),
47
- 'SVM': SVC(),
48
  'Logistic Regression': LogisticRegression(),
49
- 'KNN': KNeighborsClassifier()
 
 
50
  }
51
  pipeline = Pipeline([
52
  ('preprocessor', preprocess),
@@ -61,7 +63,7 @@ st.set_page_config(page_title='Cancer Prediction App', layout='wide')
61
  with st.sidebar:
62
  st.image('https://via.placeholder.com/300x150.png?text=Cancer+Prediction')
63
  st.markdown("### Select Machine Learning Model")
64
- model_name = st.radio("Choose a Model", ['Decision Tree', 'SVM', 'Logistic Regression', 'KNN'])
65
  if st.button("Train Model"):
66
  df = load_data()
67
  (X_train, X_test, y_train, y_test), preprocess = preprocess_data(df)
 
7
  from sklearn.compose import ColumnTransformer
8
  from sklearn.model_selection import train_test_split
9
  from sklearn.tree import DecisionTreeClassifier
 
10
  from sklearn.linear_model import LogisticRegression
11
  from sklearn.neighbors import KNeighborsClassifier
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from xgboost import XGBClassifier
14
 
15
  # Load dataset
16
  def load_data():
 
45
  def train_model(X_train, y_train, preprocess, model_name):
46
  models = {
47
  'Decision Tree': DecisionTreeClassifier(),
 
48
  'Logistic Regression': LogisticRegression(),
49
+ 'KNN': KNeighborsClassifier(),
50
+ 'Random Forest': RandomForestClassifier(),
51
+ 'XGBoost': XGBClassifier()
52
  }
53
  pipeline = Pipeline([
54
  ('preprocessor', preprocess),
 
63
  with st.sidebar:
64
  st.image('https://via.placeholder.com/300x150.png?text=Cancer+Prediction')
65
  st.markdown("### Select Machine Learning Model")
66
+ model_name = st.radio("Choose a Model", ['Decision Tree', 'Logistic Regression', 'KNN', 'Random Forest', 'XGBoost'])
67
  if st.button("Train Model"):
68
  df = load_data()
69
  (X_train, X_test, y_train, y_test), preprocess = preprocess_data(df)