clementBE commited on
Commit
26083b5
·
verified ·
1 Parent(s): 8d1aad9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -55
app.py CHANGED
@@ -2,7 +2,10 @@ import gradio as gr
2
  import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
  from sklearn.ensemble import RandomForestClassifier
5
- from sklearn.metrics import classification_report
 
 
 
6
 
7
  def load_data(file):
8
  if file is None:
@@ -13,93 +16,96 @@ def load_data(file):
13
  else:
14
  df = pd.read_excel(file.name)
15
  columns = list(df.columns)
16
- return df, columns, df.head(100)
17
- except Exception:
18
  return None, [], pd.DataFrame()
19
 
20
- def generate_dynamic_help(report_dict):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  help_lines = []
22
- classes = [k for k in report_dict if k not in ('accuracy', 'macro avg', 'weighted avg')]
23
-
24
- # Find class with lowest recall and lowest precision
25
- lowest_recall_class = min(classes, key=lambda c: report_dict[c]['recall'])
26
- lowest_precision_class = min(classes, key=lambda c: report_dict[c]['precision'])
27
-
28
- # Overall accuracy
29
- accuracy = report_dict.get('accuracy', None)
30
- if accuracy is not None:
31
- help_lines.append(f"**Overall accuracy** of the model is {accuracy:.2f}.")
32
-
33
- # Comment on recall
34
- recall_val = report_dict[lowest_recall_class]['recall']
35
- if recall_val < 0.5:
36
- help_lines.append(f"Class '{lowest_recall_class}' has a low recall ({recall_val:.2f}), meaning many true instances of this class are missed.")
37
- else:
38
- help_lines.append(f"Class '{lowest_recall_class}' has the lowest recall ({recall_val:.2f}), but it's above 0.5, which is reasonable.")
39
-
40
- # Comment on precision
41
- precision_val = report_dict[lowest_precision_class]['precision']
42
- if precision_val < 0.5:
43
- help_lines.append(f"Class '{lowest_precision_class}' has a low precision ({precision_val:.2f}), indicating many false positives.")
44
- else:
45
- help_lines.append(f"Class '{lowest_precision_class}' has the lowest precision ({precision_val:.2f}), which is acceptable.")
46
-
47
- # Warn about low support classes
48
- low_support_classes = [c for c in classes if report_dict[c]['support'] < 10]
49
- if low_support_classes:
50
- help_lines.append(f"Note: Classes {low_support_classes} have very few samples (support < 10), which may affect metric reliability.")
51
-
52
- # General advice
53
- help_lines.append("Consider collecting more data or tuning the model if some classes show poor performance.")
54
-
55
  return "\n\n".join(help_lines)
56
 
57
  def train_model(df, target_col, feature_cols):
58
  if df is None or df.empty:
59
- return "Please upload a valid dataset first.", ""
60
  if target_col not in df.columns:
61
- return "Target column not found in dataset.", ""
62
  if not feature_cols:
63
- return "Please select at least one feature column.", ""
64
 
65
  df_clean = df[[target_col] + feature_cols].dropna()
66
  if df_clean.empty:
67
- return "No data left after removing missing values.", ""
68
 
69
  X = df_clean[feature_cols]
70
  y = df_clean[target_col]
71
 
72
  if y.nunique() < 2:
73
- return "Target must have at least 2 classes.", ""
74
 
75
  X_enc = pd.get_dummies(X)
76
 
77
  try:
78
  X_train, X_test, y_train, y_test = train_test_split(X_enc, y, test_size=0.2, random_state=42)
79
  except ValueError as e:
80
- return f"Error splitting data: {e}", ""
81
 
82
  if X_train.shape[0] == 0 or X_test.shape[0] == 0:
83
- return "Empty train or test set after splitting.", ""
84
 
85
  model = RandomForestClassifier(random_state=42)
86
  model.fit(X_train, y_train)
87
  y_pred = model.predict(X_test)
88
 
89
- report_dict = classification_report(y_test, y_pred, output_dict=True)
90
- report_str = classification_report(y_test, y_pred)
91
- help_str = generate_dynamic_help(report_dict)
92
 
93
- return report_str, help_str
 
 
 
94
 
95
  def on_file_change(file):
96
  df, columns, preview = load_data(file)
97
  if df is None:
98
- return None, gr.update(choices=[], value=None), gr.update(choices=[], value=[]), pd.DataFrame()
99
- return df, gr.update(choices=columns, value=None), gr.update(choices=columns, value=[]), preview
100
 
101
  with gr.Blocks() as demo:
102
- gr.Markdown("# XLSX/CSV Classification App with Dynamic Help")
103
 
104
  df_state = gr.State(None)
105
 
@@ -111,9 +117,14 @@ with gr.Blocks() as demo:
111
  target_col = gr.Dropdown(label="Select Target Column", choices=[])
112
  with gr.Row():
113
  feature_cols = gr.CheckboxGroup(label="Select Feature Columns", choices=[])
114
- train_btn = gr.Button("Train Model")
115
- output_report = gr.Textbox(label="Classification Report", lines=10)
116
- output_help = gr.Markdown(label="Dynamic Help")
 
 
 
 
 
117
 
118
  file_input.change(
119
  fn=on_file_change,
@@ -124,7 +135,7 @@ with gr.Blocks() as demo:
124
  train_btn.click(
125
  fn=train_model,
126
  inputs=[df_state, target_col, feature_cols],
127
- outputs=[output_report, output_help]
128
  )
129
 
130
  demo.launch()
 
2
  import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
  from sklearn.ensemble import RandomForestClassifier
5
+ from sklearn.metrics import classification_report, confusion_matrix
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import io
9
 
10
  def load_data(file):
11
  if file is None:
 
16
  else:
17
  df = pd.read_excel(file.name)
18
  columns = list(df.columns)
19
+ return df, columns, df.head(100) # Show first 100 rows as preview
20
+ except Exception as e:
21
  return None, [], pd.DataFrame()
22
 
23
+ def plot_confusion_matrix(y_true, y_pred, labels):
24
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
25
+ plt.figure(figsize=(6,5))
26
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
27
+ xticklabels=labels, yticklabels=labels)
28
+ plt.xlabel("Predicted")
29
+ plt.ylabel("Actual")
30
+ plt.title("Confusion Matrix")
31
+ buf = io.BytesIO()
32
+ plt.savefig(buf, format='png')
33
+ plt.close()
34
+ buf.seek(0)
35
+ return buf
36
+
37
+ def generate_dynamic_help(report):
38
+ # Simple example: check for precision or recall < 0.5 and suggest caution
39
+ lines = report.splitlines()
40
  help_lines = []
41
+ for line in lines:
42
+ if line.strip() == "":
43
+ continue
44
+ parts = line.split()
45
+ if len(parts) >= 4 and parts[0] not in ("accuracy", "macro", "weighted"):
46
+ try:
47
+ precision = float(parts[1])
48
+ recall = float(parts[2])
49
+ f1 = float(parts[3])
50
+ cls = parts[0]
51
+ if precision < 0.5:
52
+ help_lines.append(f"⚠️ Precision for class **{cls}** is low ({precision:.2f}). The model often misclassifies samples as this class.")
53
+ if recall < 0.5:
54
+ help_lines.append(f"⚠️ Recall for class **{cls}** is low ({recall:.2f}). The model misses many samples of this class.")
55
+ except:
56
+ continue
57
+ if not help_lines:
58
+ return "✅ Model performance looks good across all classes."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  return "\n\n".join(help_lines)
60
 
61
  def train_model(df, target_col, feature_cols):
62
  if df is None or df.empty:
63
+ return "Please upload a valid dataset first.", "", None
64
  if target_col not in df.columns:
65
+ return "Target column not found in dataset.", "", None
66
  if not feature_cols:
67
+ return "Please select at least one feature column.", "", None
68
 
69
  df_clean = df[[target_col] + feature_cols].dropna()
70
  if df_clean.empty:
71
+ return "No data left after removing missing values.", "", None
72
 
73
  X = df_clean[feature_cols]
74
  y = df_clean[target_col]
75
 
76
  if y.nunique() < 2:
77
+ return "Target must have at least 2 classes.", "", None
78
 
79
  X_enc = pd.get_dummies(X)
80
 
81
  try:
82
  X_train, X_test, y_train, y_test = train_test_split(X_enc, y, test_size=0.2, random_state=42)
83
  except ValueError as e:
84
+ return f"Error splitting data: {e}", "", None
85
 
86
  if X_train.shape[0] == 0 or X_test.shape[0] == 0:
87
+ return "Empty train or test set after splitting.", "", None
88
 
89
  model = RandomForestClassifier(random_state=42)
90
  model.fit(X_train, y_train)
91
  y_pred = model.predict(X_test)
92
 
93
+ report = classification_report(y_test, y_pred)
94
+ dynamic_help = generate_dynamic_help(report)
 
95
 
96
+ labels = sorted(y_test.unique())
97
+ cm_buf = plot_confusion_matrix(y_test, y_pred, labels)
98
+
99
+ return report, dynamic_help, cm_buf
100
 
101
  def on_file_change(file):
102
  df, columns, preview = load_data(file)
103
  if df is None:
104
+ return None, gr.Dropdown.update(choices=[], value=None), gr.CheckboxGroup.update(choices=[], value=[]), pd.DataFrame()
105
+ return df, gr.Dropdown.update(choices=columns, value=None), gr.CheckboxGroup.update(choices=columns, value=[]), preview
106
 
107
  with gr.Blocks() as demo:
108
+ gr.Markdown("# XLSX/CSV Classification App with Table Preview and Visualization")
109
 
110
  df_state = gr.State(None)
111
 
 
117
  target_col = gr.Dropdown(label="Select Target Column", choices=[])
118
  with gr.Row():
119
  feature_cols = gr.CheckboxGroup(label="Select Feature Columns", choices=[])
120
+ with gr.Row():
121
+ train_btn = gr.Button("Train Model")
122
+ with gr.Row():
123
+ output_report = gr.Textbox(label="Classification Report", lines=10)
124
+ with gr.Row():
125
+ output_help = gr.Markdown(label="Model Performance Help")
126
+ with gr.Row():
127
+ cm_image = gr.Image(label="Confusion Matrix")
128
 
129
  file_input.change(
130
  fn=on_file_change,
 
135
  train_btn.click(
136
  fn=train_model,
137
  inputs=[df_state, target_col, feature_cols],
138
+ outputs=[output_report, output_help, cm_image]
139
  )
140
 
141
  demo.launch()