saherPervaiz commited on
Commit
7b01798
·
verified ·
1 Parent(s): dacbe8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -83,8 +83,11 @@ if uploaded_file is not None:
83
  X = df_cleaned[features]
84
  y = df_cleaned[target]
85
 
86
- # Determine if the target is continuous or categorical
87
- is_classification = len(y.unique()) <= 10 # If target has fewer than or equal to 10 unique values, treat as classification
 
 
 
88
 
89
  # Ensure there is enough data before proceeding with train-test split
90
  if len(X) == 0 or len(y) == 0:
@@ -98,7 +101,6 @@ if uploaded_file is not None:
98
  results = []
99
 
100
  # Model Selection and Evaluation
101
- models = []
102
  if is_classification:
103
  model_choices = [
104
  ("Random Forest", RandomForestClassifier(n_estimators=50)),
@@ -116,7 +118,7 @@ if uploaded_file is not None:
116
  class_report = classification_report(y_test, y_pred)
117
  results.append([name, accuracy, class_report])
118
 
119
- else:
120
  model_choices = [
121
  ("Random Forest", RandomForestRegressor(n_estimators=50)),
122
  ("Linear Regression", LinearRegression()),
 
83
  X = df_cleaned[features]
84
  y = df_cleaned[target]
85
 
86
+ # Check if the target is continuous (for regression) or categorical (for classification)
87
+ if y.dtype == 'O' or len(y.unique()) <= 10: # Treat as classification if target is categorical or has <= 10 unique values
88
+ is_classification = True
89
+ else:
90
+ is_classification = False
91
 
92
  # Ensure there is enough data before proceeding with train-test split
93
  if len(X) == 0 or len(y) == 0:
 
101
  results = []
102
 
103
  # Model Selection and Evaluation
 
104
  if is_classification:
105
  model_choices = [
106
  ("Random Forest", RandomForestClassifier(n_estimators=50)),
 
118
  class_report = classification_report(y_test, y_pred)
119
  results.append([name, accuracy, class_report])
120
 
121
+ else: # Regression models
122
  model_choices = [
123
  ("Random Forest", RandomForestRegressor(n_estimators=50)),
124
  ("Linear Regression", LinearRegression()),