jaker86 commited on
Commit
6753c42
·
verified ·
1 Parent(s): 439cd2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -55
app.py CHANGED
@@ -4,114 +4,141 @@ import gradio as gr
4
  from sklearn.model_selection import train_test_split
5
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
6
  from sklearn.metrics import confusion_matrix, classification_report, mean_squared_error, r2_score
7
- from sklearn.cluster import KMeans
8
  from sklearn.decomposition import PCA
 
 
9
  import matplotlib.pyplot as plt
10
  import seaborn as sns
11
  import io
12
 
13
- def analyze_csv(file, label_col):
14
  try:
15
  df = pd.read_csv(file.name if hasattr(file, "name") else file)
16
  except Exception as e:
17
- return None, None, None, None, f"error reading csv: {e}"
18
 
19
  if label_col not in df.columns:
20
- return None, None, None, None, f"label column '{label_col}' not in data"
21
 
22
  df = df.dropna()
23
- X = df.drop(columns=[label_col])
24
  y = df[label_col]
 
 
 
25
 
26
- # use only numeric features for modeling; drop non-numeric columns
27
- X_numeric = X.select_dtypes(include=[np.number])
28
- if X_numeric.shape[1] == 0:
29
- return None, None, None, None, "no numeric features available for modeling"
30
 
31
  results_text = ""
32
- cm_img = None
33
- reg_img = None
34
- fi_img = None
35
-
36
- # placeholder for feature importances
37
- feature_importances = None
38
 
39
- # if label is numeric, treat as regression; otherwise, classification
40
  if pd.api.types.is_numeric_dtype(y):
41
- task = "regression"
42
- X_train, X_test, y_train, y_test = train_test_split(X_numeric, y, test_size=0.3, random_state=42)
43
  model = RandomForestRegressor(random_state=42)
44
  model.fit(X_train, y_train)
45
  y_pred = model.predict(X_test)
46
  mse = mean_squared_error(y_test, y_pred)
47
  r2 = r2_score(y_test, y_pred)
48
  results_text += f"regression results:\nmse: {mse:.3f}\nr2: {r2:.3f}\n"
49
- # regression scatter plot: true vs predicted
50
- plt.figure()
51
  plt.scatter(y_test, y_pred, alpha=0.7)
 
52
  plt.xlabel("true values")
53
  plt.ylabel("predicted values")
54
  plt.title("regression: true vs predicted")
55
  buf = io.BytesIO()
56
- plt.savefig(buf, format="png")
57
  plt.close()
58
  buf.seek(0)
59
- reg_img = buf
60
- # note: confusion matrix not applicable here
61
- feature_importances = model.feature_importances_
62
  else:
63
- task = "classification"
64
- # encode labels as integers
65
- y_encoded = pd.factorize(y)[0]
66
- X_train, X_test, y_train, y_test = train_test_split(X_numeric, y_encoded, test_size=0.3, random_state=42)
67
  model = RandomForestClassifier(random_state=42)
68
  model.fit(X_train, y_train)
69
  y_pred = model.predict(X_test)
70
  cm = confusion_matrix(y_test, y_pred)
71
  cr = classification_report(y_test, y_pred)
72
  results_text += f"classification results:\n{cr}\n"
73
- plt.figure()
74
  sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
75
  plt.xlabel("predicted")
76
  plt.ylabel("true")
77
  plt.title("confusion matrix")
78
  buf = io.BytesIO()
79
- plt.savefig(buf, format="png")
80
  plt.close()
81
  buf.seek(0)
82
- cm_img = buf
83
- feature_importances = model.feature_importances_
84
-
85
- # feature importance plot
86
- fi = pd.Series(feature_importances, index=X_numeric.columns).sort_values(ascending=False)
87
  plt.figure(figsize=(8,4))
88
  sns.barplot(x=fi.values, y=fi.index)
89
  plt.title("feature importances")
 
 
90
  buf = io.BytesIO()
91
  plt.savefig(buf, format="png", bbox_inches="tight")
92
  plt.close()
93
  buf.seek(0)
94
  fi_img = buf
95
 
96
- # clustering: kmeans on numeric features; use pca for 2d visualization
97
- k = 3
98
- kmeans = KMeans(n_clusters=k, random_state=42)
99
- clusters = kmeans.fit_predict(X_numeric)
100
  pca = PCA(n_components=2, random_state=42)
101
- X_pca = pca.fit_transform(X_numeric)
102
- plt.figure()
103
- scatter = plt.scatter(X_pca[:,0], X_pca[:,1], c=clusters, cmap="viridis", alpha=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  plt.xlabel("pca 1")
105
  plt.ylabel("pca 2")
106
- plt.title("kmeans clustering (k=3)")
107
- plt.colorbar(scatter, ticks=range(k))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  buf = io.BytesIO()
109
  plt.savefig(buf, format="png", bbox_inches="tight")
110
  plt.close()
111
  buf.seek(0)
112
- cluster_img = buf
113
 
114
- return cm_img, reg_img, fi_img, cluster_img, results_text
115
 
116
  def update_dropdown(file):
117
  try:
@@ -125,23 +152,26 @@ with gr.Blocks() as demo:
125
  with gr.Row():
126
  file_input = gr.File(label="upload csv", file_types=[".csv"])
127
  label_dropdown = gr.Dropdown(label="select label column", choices=[])
 
128
 
129
  file_input.change(fn=update_dropdown, inputs=file_input, outputs=label_dropdown)
130
 
131
  analyze_btn = gr.Button("analyze")
132
  with gr.Tabs():
133
  with gr.TabItem("results"):
134
- results_textbox = gr.Textbox(label="metrics & results", lines=10)
135
- with gr.TabItem("confusion matrix"):
136
- cm_output = gr.Image(label="confusion matrix")
137
- with gr.TabItem("regression plot"):
138
- reg_output = gr.Image(label="regression plot")
139
  with gr.TabItem("feature importances"):
140
  fi_output = gr.Image(label="feature importances")
141
- with gr.TabItem("clustering"):
142
- cluster_output = gr.Image(label="cluster plot")
 
 
 
 
143
 
144
- analyze_btn.click(fn=analyze_csv, inputs=[file_input, label_dropdown],
145
- outputs=[cm_output, reg_output, fi_output, cluster_output, results_textbox])
146
 
147
  demo.launch()
 
4
  from sklearn.model_selection import train_test_split
5
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
6
  from sklearn.metrics import confusion_matrix, classification_report, mean_squared_error, r2_score
7
+ from sklearn.cluster import KMeans, AgglomerativeClustering
8
  from sklearn.decomposition import PCA
9
+ from sklearn.preprocessing import StandardScaler
10
+ from sklearn.feature_selection import f_classif
11
  import matplotlib.pyplot as plt
12
  import seaborn as sns
13
  import io
14
 
15
+ def analyze_csv(file, label_col, n_clusters):
16
  try:
17
  df = pd.read_csv(file.name if hasattr(file, "name") else file)
18
  except Exception as e:
19
+ return (None,)*6 + (f"error reading csv: {e}",)
20
 
21
  if label_col not in df.columns:
22
+ return (None,)*6 + (f"label column '{label_col}' not in data",)
23
 
24
  df = df.dropna()
25
+ # separate target and features
26
  y = df[label_col]
27
+ X = df.drop(columns=[label_col])
28
+ # create one-hot encodings for non-numeric columns
29
+ X_processed = pd.get_dummies(X)
30
 
31
+ # scale features for clustering methods
32
+ scaler = StandardScaler()
33
+ X_scaled = scaler.fit_transform(X_processed)
 
34
 
35
  results_text = ""
36
+ model_img = None
 
 
 
 
 
37
 
38
+ # model training & evaluation: regression if y numeric, classification otherwise
39
  if pd.api.types.is_numeric_dtype(y):
40
+ # regression
41
+ X_train, X_test, y_train, y_test = train_test_split(X_processed, y, test_size=0.3, random_state=42)
42
  model = RandomForestRegressor(random_state=42)
43
  model.fit(X_train, y_train)
44
  y_pred = model.predict(X_test)
45
  mse = mean_squared_error(y_test, y_pred)
46
  r2 = r2_score(y_test, y_pred)
47
  results_text += f"regression results:\nmse: {mse:.3f}\nr2: {r2:.3f}\n"
48
+ # scatter plot: true vs predicted with y=x line
49
+ plt.figure(figsize=(6,4))
50
  plt.scatter(y_test, y_pred, alpha=0.7)
51
+ plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
52
  plt.xlabel("true values")
53
  plt.ylabel("predicted values")
54
  plt.title("regression: true vs predicted")
55
  buf = io.BytesIO()
56
+ plt.savefig(buf, format="png", bbox_inches="tight")
57
  plt.close()
58
  buf.seek(0)
59
+ model_img = buf
 
 
60
  else:
61
+ # classification
62
+ y_encoded, uniques = pd.factorize(y)
63
+ X_train, X_test, y_train, y_test = train_test_split(X_processed, y_encoded, test_size=0.3, random_state=42)
 
64
  model = RandomForestClassifier(random_state=42)
65
  model.fit(X_train, y_train)
66
  y_pred = model.predict(X_test)
67
  cm = confusion_matrix(y_test, y_pred)
68
  cr = classification_report(y_test, y_pred)
69
  results_text += f"classification results:\n{cr}\n"
70
+ plt.figure(figsize=(6,4))
71
  sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
72
  plt.xlabel("predicted")
73
  plt.ylabel("true")
74
  plt.title("confusion matrix")
75
  buf = io.BytesIO()
76
+ plt.savefig(buf, format="png", bbox_inches="tight")
77
  plt.close()
78
  buf.seek(0)
79
+ model_img = buf
80
+
81
+ # feature importance plot (from the model)
82
+ fi = pd.Series(model.feature_importances_, index=X_processed.columns).sort_values(ascending=False)
 
83
  plt.figure(figsize=(8,4))
84
  sns.barplot(x=fi.values, y=fi.index)
85
  plt.title("feature importances")
86
+ plt.xlabel("importance")
87
+ plt.ylabel("feature")
88
  buf = io.BytesIO()
89
  plt.savefig(buf, format="png", bbox_inches="tight")
90
  plt.close()
91
  buf.seek(0)
92
  fi_img = buf
93
 
94
+ # clustering with kmeans
95
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42)
96
+ clusters_kmeans = kmeans.fit_predict(X_scaled)
 
97
  pca = PCA(n_components=2, random_state=42)
98
+ X_pca = pca.fit_transform(X_scaled)
99
+ plt.figure(figsize=(6,4))
100
+ scatter = plt.scatter(X_pca[:,0], X_pca[:,1], c=clusters_kmeans, cmap="viridis", alpha=0.7)
101
+ plt.xlabel("pca 1")
102
+ plt.ylabel("pca 2")
103
+ plt.title(f"kmeans clustering (k={n_clusters})")
104
+ plt.colorbar(scatter, ticks=range(n_clusters))
105
+ buf = io.BytesIO()
106
+ plt.savefig(buf, format="png", bbox_inches="tight")
107
+ plt.close()
108
+ buf.seek(0)
109
+ kmeans_img = buf
110
+
111
+ # clustering with agglomerative clustering
112
+ agg = AgglomerativeClustering(n_clusters=n_clusters)
113
+ clusters_agg = agg.fit_predict(X_scaled)
114
+ plt.figure(figsize=(6,4))
115
+ scatter = plt.scatter(X_pca[:,0], X_pca[:,1], c=clusters_agg, cmap="plasma", alpha=0.7)
116
  plt.xlabel("pca 1")
117
  plt.ylabel("pca 2")
118
+ plt.title(f"agglomerative clustering (k={n_clusters})")
119
+ plt.colorbar(scatter, ticks=range(n_clusters))
120
+ buf = io.BytesIO()
121
+ plt.savefig(buf, format="png", bbox_inches="tight")
122
+ plt.close()
123
+ buf.seek(0)
124
+ agg_img = buf
125
+
126
+ # differentiating features among clusters (using kmeans clusters)
127
+ f_scores, p_vals = f_classif(X_processed, clusters_kmeans)
128
+ f_series = pd.Series(f_scores, index=X_processed.columns).sort_values(ascending=False)
129
+ top_features = f_series.head(10)
130
+ plt.figure(figsize=(8,4))
131
+ sns.barplot(x=top_features.values, y=top_features.index, palette="mako")
132
+ plt.title("top differentiating features (anova f-scores)")
133
+ plt.xlabel("f-score")
134
+ plt.ylabel("feature")
135
  buf = io.BytesIO()
136
  plt.savefig(buf, format="png", bbox_inches="tight")
137
  plt.close()
138
  buf.seek(0)
139
+ diff_img = buf
140
 
141
+ return results_text, model_img, fi_img, kmeans_img, agg_img, diff_img
142
 
143
  def update_dropdown(file):
144
  try:
 
152
  with gr.Row():
153
  file_input = gr.File(label="upload csv", file_types=[".csv"])
154
  label_dropdown = gr.Dropdown(label="select label column", choices=[])
155
+ clusters_slider = gr.Slider(minimum=2, maximum=10, step=1, value=3, label="number of clusters")
156
 
157
  file_input.change(fn=update_dropdown, inputs=file_input, outputs=label_dropdown)
158
 
159
  analyze_btn = gr.Button("analyze")
160
  with gr.Tabs():
161
  with gr.TabItem("results"):
162
+ results_textbox = gr.Textbox(label="metrics & descriptions", lines=10)
163
+ with gr.TabItem("model visualization"):
164
+ model_img_output = gr.Image(label="model output (confusion matrix or regression scatter)")
 
 
165
  with gr.TabItem("feature importances"):
166
  fi_output = gr.Image(label="feature importances")
167
+ with gr.TabItem("kmeans clustering"):
168
+ kmeans_output = gr.Image(label="kmeans clustering (pca projection)")
169
+ with gr.TabItem("agglomerative clustering"):
170
+ agg_output = gr.Image(label="agglomerative clustering (pca projection)")
171
+ with gr.TabItem("cluster differentiation"):
172
+ diff_output = gr.Image(label="differentiating features among clusters")
173
 
174
+ analyze_btn.click(fn=analyze_csv, inputs=[file_input, label_dropdown, clusters_slider],
175
+ outputs=[results_textbox, model_img_output, fi_output, kmeans_output, agg_output, diff_output])
176
 
177
  demo.launch()