Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -40,7 +40,20 @@ def load_data(file):
|
|
| 40 |
except Exception as e:
|
| 41 |
return None, [], pd.DataFrame(), "", f"β Error loading file: {e}"
|
| 42 |
|
| 43 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
if df is None or df.empty:
|
| 45 |
return "Please upload a valid dataset first.", None, ""
|
| 46 |
if target_col not in df.columns:
|
|
@@ -52,37 +65,37 @@ def train_model(df, target_col, feature_cols):
|
|
| 52 |
if df_clean.empty:
|
| 53 |
return "No data left after removing missing values.", None, ""
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
|
| 61 |
-
try:
|
| 62 |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
| 63 |
-
except ValueError as e:
|
| 64 |
-
return f"Error splitting data: {e}", None, ""
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
|
| 77 |
-
plt.tight_layout()
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
plt.close(fig)
|
| 82 |
-
img_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" alt="Confusion Matrix"/>'
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
def generate_help_text(report_text):
|
| 88 |
try:
|
|
@@ -126,9 +139,13 @@ with gr.Blocks() as demo:
|
|
| 126 |
data_summary = gr.Markdown()
|
| 127 |
|
| 128 |
with gr.Row():
|
| 129 |
-
target_col = gr.Dropdown(label="
|
| 130 |
feature_cols = gr.CheckboxGroup(label="π Feature Columns")
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
train_btn = gr.Button("π Train Model")
|
| 133 |
|
| 134 |
output = gr.Textbox(label="π Classification Report", lines=10)
|
|
@@ -143,8 +160,9 @@ with gr.Blocks() as demo:
|
|
| 143 |
|
| 144 |
train_btn.click(
|
| 145 |
fn=train_model,
|
| 146 |
-
inputs=[df_state, target_col, feature_cols],
|
| 147 |
outputs=[output, confusion_plot, help_box]
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
|
|
|
|
|
| 40 |
except Exception as e:
|
| 41 |
return None, [], pd.DataFrame(), "", f"β Error loading file: {e}"
|
| 42 |
|
| 43 |
+
def preprocess_features(df, feature_cols, recategorize_quartiles=False, count_words=False):
|
| 44 |
+
processed_df = df.copy()
|
| 45 |
+
|
| 46 |
+
for col in feature_cols:
|
| 47 |
+
if recategorize_quartiles and pd.api.types.is_numeric_dtype(processed_df[col]):
|
| 48 |
+
processed_df[col] = pd.qcut(processed_df[col], q=4, duplicates='drop').astype(str)
|
| 49 |
+
|
| 50 |
+
if count_words and processed_df[col].dtype == object:
|
| 51 |
+
processed_df[col] = processed_df[col].fillna("").apply(lambda x: len(str(x).split()))
|
| 52 |
+
|
| 53 |
+
X = pd.get_dummies(processed_df[feature_cols])
|
| 54 |
+
return X
|
| 55 |
+
|
| 56 |
+
def train_model(df, target_col, feature_cols, recategorize_quartiles=False, count_words=False):
|
| 57 |
if df is None or df.empty:
|
| 58 |
return "Please upload a valid dataset first.", None, ""
|
| 59 |
if target_col not in df.columns:
|
|
|
|
| 65 |
if df_clean.empty:
|
| 66 |
return "No data left after removing missing values.", None, ""
|
| 67 |
|
| 68 |
+
try:
|
| 69 |
+
X = preprocess_features(df_clean, feature_cols, recategorize_quartiles, count_words)
|
| 70 |
+
y = df_clean[target_col]
|
| 71 |
|
| 72 |
+
if y.nunique() < 2:
|
| 73 |
+
return "Target must have at least two classes.", None, ""
|
| 74 |
|
|
|
|
| 75 |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
model = RandomForestClassifier(random_state=42)
|
| 78 |
+
model.fit(X_train, y_train)
|
| 79 |
+
y_pred = model.predict(X_test)
|
| 80 |
+
|
| 81 |
+
report = classification_report(y_test, y_pred)
|
| 82 |
|
| 83 |
+
cm = confusion_matrix(y_test, y_pred)
|
| 84 |
+
fig, ax = plt.subplots(figsize=(6, 5))
|
| 85 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
|
| 86 |
+
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
|
| 87 |
+
plt.tight_layout()
|
| 88 |
|
| 89 |
+
buf = io.BytesIO()
|
| 90 |
+
plt.savefig(buf, format="png")
|
| 91 |
+
plt.close(fig)
|
| 92 |
+
img_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" alt="Confusion Matrix"/>'
|
|
|
|
|
|
|
| 93 |
|
| 94 |
+
help_text = generate_help_text(report)
|
| 95 |
+
return report, img_html, help_text
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
except Exception as e:
|
| 98 |
+
return f"Error during training: {e}", None, ""
|
| 99 |
|
| 100 |
def generate_help_text(report_text):
|
| 101 |
try:
|
|
|
|
| 139 |
data_summary = gr.Markdown()
|
| 140 |
|
| 141 |
with gr.Row():
|
| 142 |
+
target_col = gr.Dropdown(label="π Target Column")
|
| 143 |
feature_cols = gr.CheckboxGroup(label="π Feature Columns")
|
| 144 |
|
| 145 |
+
with gr.Row():
|
| 146 |
+
recategorize_quartiles = gr.Checkbox(label="Discretize Numeric Columns into Quartiles")
|
| 147 |
+
count_words = gr.Checkbox(label="Count Words in Text Columns")
|
| 148 |
+
|
| 149 |
train_btn = gr.Button("π Train Model")
|
| 150 |
|
| 151 |
output = gr.Textbox(label="π Classification Report", lines=10)
|
|
|
|
| 160 |
|
| 161 |
train_btn.click(
|
| 162 |
fn=train_model,
|
| 163 |
+
inputs=[df_state, target_col, feature_cols, recategorize_quartiles, count_words],
|
| 164 |
outputs=[output, confusion_plot, help_box]
|
| 165 |
)
|
| 166 |
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
demo.launch()
|