Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.metrics import classification_report, accuracy_score | |
| from PIL import Image | |
# Shared state for the Gradio callbacks in this file: the uploaded frame
# (set by load_data), its feature-engineered copy (set by process_data), the
# last fitted model and the one-hot encoded feature names it was trained on
# (both set by train_model).
original_df = processed_df = trained_model = None
processed_X_columns = None  # Keep processed features list for importances
def load_data(file):
    """Load an uploaded CSV or Excel file into the module-level ``original_df``.

    Args:
        file: Value from the file-upload component: either an object exposing
            the file's path as ``.name`` or a plain path string. May be ``None``
            (e.g. when the upload is cleared).

    Returns:
        ``(preview, status, help_text)``: the first 10 rows (an empty frame when
        nothing could be loaded), a short status message, and guidance for the
        next step.
    """
    global original_df, processed_df
    # Any (re)load invalidates data derived from a previous file, so a failed
    # or cleared upload can never leave stale frames behind for later steps.
    original_df = None
    processed_df = None
    if file is None:
        return pd.DataFrame(), "No file loaded.", "Please upload a valid CSV or Excel file."

    # Accept both file-like wrappers (path in ``.name``) and plain path strings.
    path = str(getattr(file, "name", file))
    try:
        # Case-insensitive so e.g. "DATA.CSV" is not handed to read_excel.
        if path.lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
    except Exception as e:
        return pd.DataFrame(), f"β Error loading file: {e}", "Please upload a valid CSV or Excel file."

    original_df = df
    help_text = (
        "Step 1: Data loaded successfully!\n"
        "- Preview shows first 10 rows.\n"
        "- Next: Click 'Process Data' to discretize numeric columns and add word counts for text."
    )
    return df.head(10), "β File loaded successfully.", help_text
def _add_derived_features(df):
    """Return ``(new_df, skipped)``: a copy of *df* with binned numeric columns and text word counts.

    For each numeric column ``c`` of the input, adds ``c_qbin`` (quartile bin
    index) and ``c_decil`` (decile bin index) via ``pd.qcut`` with
    ``duplicates='drop'``, so heavily tied columns may get fewer bins. For each
    object column ``c``, adds ``c_wordcount`` (whitespace-separated tokens;
    missing values count as 0). ``skipped`` lists the derived column names
    that ``pd.qcut`` could not build.
    """
    out = df.copy()
    # Snapshot the column lists before adding anything; otherwise the freshly
    # added (numeric) *_qbin columns would themselves be re-binned into deciles.
    numeric_cols = out.select_dtypes(include=np.number).columns.tolist()
    text_cols = out.select_dtypes(include='object').columns.tolist()
    skipped = []
    # Outer loop over bin sizes keeps the original column order:
    # all quartile columns first, then all decile columns.
    for suffix, n_bins in (("_qbin", 4), ("_decil", 10)):
        for col in numeric_cols:
            new_col = f"{col}{suffix}"
            try:
                out[new_col] = pd.qcut(out[col], n_bins, labels=False, duplicates='drop')
            except Exception:
                # Best-effort: some columns cannot be binned; record and move on.
                skipped.append(new_col)
    for col in text_cols:
        # fillna("") so missing text counts as 0 words instead of the 1-word "nan".
        out[f"{col}_wordcount"] = out[col].fillna("").astype(str).str.split().str.len()
    return out, skipped


def process_data():
    """Derive binned and word-count features from ``original_df`` (Step 2).

    Stores the result in the module-level ``processed_df``.

    Returns:
        ``(preview, target_update, feature_update, status, help_text)``: the
        first 10 processed rows, ``gr.update`` objects that set the target and
        feature selectors' choices to every processed column, a status message,
        and help text. If no dataset has been loaded, returns an empty preview,
        empty selector choices, and a warning.
    """
    global original_df, processed_df
    if original_df is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), "β οΈ Please load a dataset first.", ""

    df, skipped = _add_derived_features(original_df)
    processed_df = df
    all_columns = df.columns.tolist()
    help_text = (
        "Step 2: Data processed!\n"
        "- Numeric columns discretized into quartiles and deciles.\n"
        "- Word counts added for text columns.\n"
        "- You can now select your target and feature columns."
    )
    if skipped:
        # Tell the user which derived columns are missing rather than failing silently.
        help_text += "\n- Could not discretize: " + ", ".join(skipped)
    return df.head(10), gr.update(choices=all_columns), gr.update(choices=all_columns), "β Data processed.", help_text
def _feature_importance_heatmap(feature_names, importances, top_n=20):
    """Render the *top_n* largest importances as a one-row heatmap and return it as a PIL image."""
    fi_df = (
        pd.DataFrame({'Feature': feature_names, 'Importance': importances})
        .sort_values(by='Importance', ascending=False)
        .head(top_n)
    )
    fig = plt.figure(figsize=(10, 6))
    try:
        sns.heatmap(fi_df.set_index('Feature').T, annot=True, cmap="YlGnBu", cbar_kws={'label': 'Feature Importance'})
        plt.title(f"Feature Importances Heatmap (Top {top_n})")
        plt.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting fails, so repeated
        # training runs do not accumulate open pyplot figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def train_model(target_col, feature_cols):
    """Train a random-forest classifier on ``processed_df`` and report the results.

    Args:
        target_col: Name of the column to predict.
        feature_cols: Names of predictor columns. The target is ignored if it
            was also selected here.

    Returns:
        ``(report, heatmap, help_text)``: scikit-learn's classification report
        text, a PIL image of the top-20 feature importances, and an explanation
        of the results. On failure, the first slot holds an error message and
        the others are ``None`` and ``""``.

    Side effects:
        Stores the fitted model in ``trained_model`` and the one-hot encoded
        feature names in ``processed_X_columns``.
    """
    global processed_df, trained_model, processed_X_columns
    if processed_df is None:
        return "β οΈ Please process your data first.", None, ""
    if not target_col or not feature_cols:
        return "β οΈ Please select a target and at least one feature.", None, ""
    # Using the target as one of its own predictors leaks the answer and
    # yields a meaningless near-perfect accuracy, so drop it from the features.
    feature_cols = [col for col in feature_cols if col != target_col]
    if not feature_cols:
        return "β οΈ Please select at least one feature other than the target.", None, ""
    try:
        data = processed_df[feature_cols + [target_col]]
        # Rows without a label can be used neither for fitting nor scoring.
        data = data[data[target_col].notna()]
        # One-hot encode categorical features, if any.
        X = pd.get_dummies(data[feature_cols])
        y = data[target_col]
        processed_X_columns = X.columns.tolist()

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        clf = RandomForestClassifier(random_state=42)
        clf.fit(X_train, y_train)
        trained_model = clf

        y_pred = clf.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred)

        img = _feature_importance_heatmap(processed_X_columns, clf.feature_importances_)

        help_text = (
            f"π Model type: Random Forest Classifier\n"
            f"π― Target: '{target_col}'\n"
            f"π§ͺ Features used: {len(feature_cols)}\n"
            f"β Accuracy on test set: {accuracy:.2%}\n\n"
            "π Classification Report Explanation:\n"
            "- Precision: Of predicted positives, how many are correct?\n"
            "- Recall: Of actual positives, how many were found?\n"
            "- F1-Score: Harmonic mean of precision & recall.\n\n"
            "π‘οΈ Heatmap Explanation:\n"
            "- Shows top 20 most important features by model.\n"
            "- Darker cells = higher influence on predictions.\n"
            "- Use this to understand which variables drive decisions."
        )
        return report, img, help_text
    except Exception as e:
        return f"β Model training failed: {e}", None, ""
with gr.Blocks(title="Step-by-Step Model Trainer with Help and Heatmap") as app:
    gr.Markdown("## π§ Step-by-Step Model Trainer\nUpload your data, process it, train a model, and get help at each step!")

    # Step 1: upload a file and preview it.
    with gr.Row():
        file_input = gr.File(label="π Upload CSV or Excel file")
        load_status = gr.Textbox(label="βΉοΈ File Load Status", interactive=False)
    original_preview = gr.DataFrame(label="π Original Data Preview (first 10 rows)")
    load_help = gr.Textbox(label="π Step 1 Help", interactive=False)

    # Step 2: derive features and populate the column selectors.
    process_button = gr.Button("βοΈ Process Data")
    processed_preview = gr.DataFrame(label="π¬ Processed Data Preview (first 10 rows)")
    process_status = gr.Textbox(label="βΉοΈ Process Status", interactive=False)
    process_help = gr.Textbox(label="π Step 2 Help", interactive=False)

    # Step 3: choose columns, train, and inspect the results.
    target_selector = gr.Dropdown(label="π― Select Target Column", choices=[])
    feature_selector = gr.CheckboxGroup(label="π Select Feature Columns", choices=[])
    train_button = gr.Button("π Train Model")
    train_output = gr.Textbox(label="π Classification Report", lines=15)
    heatmap_output = gr.Image(label="π‘οΈ Feature Importance Heatmap")
    train_help = gr.Textbox(label="π Help to read results", interactive=False, lines=12)

    # Event wiring as (listener, handler, inputs, outputs), registered in order.
    wiring = (
        (file_input.change, load_data,
         [file_input],
         [original_preview, load_status, load_help]),
        (process_button.click, process_data,
         [],
         [processed_preview, target_selector, feature_selector, process_status, process_help]),
        (train_button.click, train_model,
         [target_selector, feature_selector],
         [train_output, heatmap_output, train_help]),
    )
    for listen, handler, ins, outs in wiring:
        listen(fn=handler, inputs=ins, outputs=outs)

app.launch()