clementBE committed on
Commit
624ddf1
·
verified ·
1 Parent(s): 76e3b5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -27
app.py CHANGED
@@ -40,7 +40,20 @@ def load_data(file):
40
  except Exception as e:
41
  return None, [], pd.DataFrame(), "", f"❌ Error loading file: {e}"
42
 
43
- def train_model(df, target_col, feature_cols):
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if df is None or df.empty:
45
  return "Please upload a valid dataset first.", None, ""
46
  if target_col not in df.columns:
@@ -52,37 +65,37 @@ def train_model(df, target_col, feature_cols):
52
  if df_clean.empty:
53
  return "No data left after removing missing values.", None, ""
54
 
55
- X = pd.get_dummies(df_clean[feature_cols])
56
- y = df_clean[target_col]
 
57
 
58
- if y.nunique() < 2:
59
- return "Target must have at least two classes.", None, ""
60
 
61
- try:
62
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
63
- except ValueError as e:
64
- return f"Error splitting data: {e}", None, ""
65
 
66
- model = RandomForestClassifier(random_state=42)
67
- model.fit(X_train, y_train)
68
- y_pred = model.predict(X_test)
 
 
69
 
70
- report = classification_report(y_test, y_pred)
 
 
 
 
71
 
72
- # Plot confusion matrix
73
- cm = confusion_matrix(y_test, y_pred)
74
- fig, ax = plt.subplots(figsize=(6, 5))
75
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
76
- ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
77
- plt.tight_layout()
78
 
79
- buf = io.BytesIO()
80
- plt.savefig(buf, format="png")
81
- plt.close(fig)
82
- img_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" alt="Confusion Matrix"/>'
83
 
84
- help_text = generate_help_text(report)
85
- return report, img_html, help_text
86
 
87
  def generate_help_text(report_text):
88
  try:
@@ -126,9 +139,13 @@ with gr.Blocks() as demo:
126
  data_summary = gr.Markdown()
127
 
128
  with gr.Row():
129
- target_col = gr.Dropdown(label="🎯 Target Column")
130
  feature_cols = gr.CheckboxGroup(label="πŸ“Š Feature Columns")
131
 
 
 
 
 
132
  train_btn = gr.Button("πŸš€ Train Model")
133
 
134
  output = gr.Textbox(label="πŸ“‹ Classification Report", lines=10)
@@ -143,8 +160,9 @@ with gr.Blocks() as demo:
143
 
144
  train_btn.click(
145
  fn=train_model,
146
- inputs=[df_state, target_col, feature_cols],
147
  outputs=[output, confusion_plot, help_box]
148
  )
149
 
150
- demo.launch()
 
 
40
  except Exception as e:
41
  return None, [], pd.DataFrame(), "", f"❌ Error loading file: {e}"
42
 
43
def preprocess_features(df, feature_cols, recategorize_quartiles=False, count_words=False):
    """Build the one-hot encoded feature matrix for the selected columns.

    Each selected column receives at most one optional transformation,
    chosen from the dtype it has in the input frame:

    * numeric (non-boolean) columns are binned into quartile labels when
      ``recategorize_quartiles`` is true;
    * text columns (object or pandas string dtype) are replaced by their
      whitespace-separated word count when ``count_words`` is true.

    The result is then passed through ``pd.get_dummies``. The input frame
    is never modified.

    Args:
        df: Source DataFrame containing every column in ``feature_cols``.
        feature_cols: Names of the columns to use as features.
        recategorize_quartiles: Discretize numeric columns with
            ``pd.qcut(q=4, duplicates='drop')`` and store the bin labels as
            strings so they get one-hot encoded.
        count_words: Replace text columns by an integer word count
            (missing values count as 0 words).

    Returns:
        The DataFrame produced by ``pd.get_dummies`` on the processed columns.

    Raises:
        KeyError: If a name in ``feature_cols`` is not a column of ``df``.
    """
    cols = list(feature_cols or [])

    # Selecting with a list already yields a new frame; copy() makes the
    # independence from the caller's data explicit and avoids copying
    # columns that are not used as features.
    processed_df = df[cols].copy()

    for col in cols:
        series = processed_df[col]

        # The branches are mutually exclusive on purpose: the quartile labels
        # are strings, so a second, independent dtype check would word-count
        # labels such as "(0.999, 2.75]" and collapse the feature.
        is_binnable = (
            pd.api.types.is_numeric_dtype(series)
            # Booleans report as numeric but have no meaningful quartiles.
            and not pd.api.types.is_bool_dtype(series)
        )
        is_text = (
            pd.api.types.is_object_dtype(series)
            or pd.api.types.is_string_dtype(series)
        )

        if recategorize_quartiles and is_binnable:
            processed_df[col] = pd.qcut(series, q=4, duplicates='drop').astype(str)
        elif count_words and is_text:
            processed_df[col] = (
                series.fillna("")
                .map(lambda value: len(str(value).split()))
                .astype("int64")
            )

    return pd.get_dummies(processed_df)
55
+
56
+ def train_model(df, target_col, feature_cols, recategorize_quartiles=False, count_words=False):
57
  if df is None or df.empty:
58
  return "Please upload a valid dataset first.", None, ""
59
  if target_col not in df.columns:
 
65
  if df_clean.empty:
66
  return "No data left after removing missing values.", None, ""
67
 
68
+ try:
69
+ X = preprocess_features(df_clean, feature_cols, recategorize_quartiles, count_words)
70
+ y = df_clean[target_col]
71
 
72
+ if y.nunique() < 2:
73
+ return "Target must have at least two classes.", None, ""
74
 
 
75
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
 
76
 
77
+ model = RandomForestClassifier(random_state=42)
78
+ model.fit(X_train, y_train)
79
+ y_pred = model.predict(X_test)
80
+
81
+ report = classification_report(y_test, y_pred)
82
 
83
+ cm = confusion_matrix(y_test, y_pred)
84
+ fig, ax = plt.subplots(figsize=(6, 5))
85
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
86
+ ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
87
+ plt.tight_layout()
88
 
89
+ buf = io.BytesIO()
90
+ plt.savefig(buf, format="png")
91
+ plt.close(fig)
92
+ img_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" alt="Confusion Matrix"/>'
 
 
93
 
94
+ help_text = generate_help_text(report)
95
+ return report, img_html, help_text
 
 
96
 
97
+ except Exception as e:
98
+ return f"Error during training: {e}", None, ""
99
 
100
  def generate_help_text(report_text):
101
  try:
 
139
  data_summary = gr.Markdown()
140
 
141
  with gr.Row():
142
+ target_col = gr.Dropdown(label="🌟 Target Column")
143
  feature_cols = gr.CheckboxGroup(label="πŸ“Š Feature Columns")
144
 
145
+ with gr.Row():
146
+ recategorize_quartiles = gr.Checkbox(label="Discretize Numeric Columns into Quartiles")
147
+ count_words = gr.Checkbox(label="Count Words in Text Columns")
148
+
149
  train_btn = gr.Button("πŸš€ Train Model")
150
 
151
  output = gr.Textbox(label="πŸ“‹ Classification Report", lines=10)
 
160
 
161
  train_btn.click(
162
  fn=train_model,
163
+ inputs=[df_state, target_col, feature_cols, recategorize_quartiles, count_words],
164
  outputs=[output, confusion_plot, help_box]
165
  )
166
 
167
+ if __name__ == "__main__":
168
+ demo.launch()