clementBE commited on
Commit
6849a4f
ยท
verified ยท
1 Parent(s): 95316bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -10
app.py CHANGED
@@ -8,7 +8,9 @@ from sklearn.metrics import classification_report, accuracy_score, precision_sco
8
  df_train = None
9
  model = None
10
  vectorizer = None
11
- test_metrics = None # To store metrics after training
 
 
12
 
13
  def load_training_file(file):
14
  global df_train
@@ -18,7 +20,7 @@ def load_training_file(file):
18
  df_train = pd.read_excel(file.name)
19
  col_names = list(df_train.columns)
20
 
21
- return f"โœ… Loaded file with {len(df_train)} rows", gr.update(choices=col_names, value=col_names[0]), gr.update(choices=col_names, value=col_names[-1])
22
 
23
  def train_model(text_column, target_column):
24
  global model, vectorizer, test_metrics, df_train
@@ -31,7 +33,6 @@ def train_model(text_column, target_column):
31
 
32
  df_filtered = df_train.dropna(subset=[text_column, target_column])
33
 
34
- # Split train/test
35
  X_train, X_test, y_train, y_test = train_test_split(
36
  df_filtered[text_column], df_filtered[target_column], test_size=0.2, random_state=42
37
  )
@@ -43,12 +44,10 @@ def train_model(text_column, target_column):
43
  model = LogisticRegression(max_iter=1000)
44
  model.fit(X_train_vec, y_train)
45
 
46
- # Predict on test set
47
  y_pred = model.predict(X_test_vec)
48
 
49
- # Compute metrics
50
  accuracy = accuracy_score(y_test, y_pred)
51
- precision = precision_score(y_test, y_pred, average='weighted', zero_division=0) # weighted average for multiclass
52
  report = classification_report(y_test, y_pred, zero_division=0)
53
 
54
  test_metrics = f"Accuracy: {accuracy:.2%}\nPrecision (weighted): {precision:.2%}\n\nClassification Report:\n{report}"
@@ -61,16 +60,46 @@ def predict_label(text_input):
61
 
62
  X = vectorizer.transform([text_input])
63
  prediction = model.predict(X)[0]
64
- proba = model.predict_proba(X).max() # highest confidence for predicted class
65
 
66
  return f"๐Ÿ”ฎ Prediction: {prediction} (confidence: {proba:.2%})"
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with gr.Blocks() as demo:
69
  gr.Markdown("# ๐Ÿง  Text Classification App")
70
 
71
  with gr.Row():
72
- file_input = gr.File(label="Upload Excel File (.xlsx)", file_types=[".xlsx"])
73
- load_button = gr.Button("๐Ÿ“‚ Load File")
74
 
75
  status_output = gr.Markdown()
76
  with gr.Row():
@@ -82,10 +111,22 @@ with gr.Blocks() as demo:
82
 
83
  with gr.Row():
84
  input_text = gr.Textbox(label="Enter text to classify")
85
- predict_button = gr.Button("๐Ÿ” Predict")
86
 
87
  prediction_output = gr.Markdown()
88
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  load_button.click(
90
  fn=load_training_file,
91
  inputs=file_input,
@@ -104,5 +145,17 @@ with gr.Blocks() as demo:
104
  outputs=prediction_output
105
  )
106
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if __name__ == "__main__":
108
  demo.launch()
 
8
  df_train = None
9
  model = None
10
  vectorizer = None
11
+ test_metrics = None
12
+
13
+ df_predict = None # for batch prediction file
14
 
15
  def load_training_file(file):
16
  global df_train
 
20
  df_train = pd.read_excel(file.name)
21
  col_names = list(df_train.columns)
22
 
23
+ return f"โœ… Loaded training file with {len(df_train)} rows", gr.update(choices=col_names, value=col_names[0]), gr.update(choices=col_names, value=col_names[-1])
24
 
25
  def train_model(text_column, target_column):
26
  global model, vectorizer, test_metrics, df_train
 
33
 
34
  df_filtered = df_train.dropna(subset=[text_column, target_column])
35
 
 
36
  X_train, X_test, y_train, y_test = train_test_split(
37
  df_filtered[text_column], df_filtered[target_column], test_size=0.2, random_state=42
38
  )
 
44
  model = LogisticRegression(max_iter=1000)
45
  model.fit(X_train_vec, y_train)
46
 
 
47
  y_pred = model.predict(X_test_vec)
48
 
 
49
  accuracy = accuracy_score(y_test, y_pred)
50
+ precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
51
  report = classification_report(y_test, y_pred, zero_division=0)
52
 
53
  test_metrics = f"Accuracy: {accuracy:.2%}\nPrecision (weighted): {precision:.2%}\n\nClassification Report:\n{report}"
 
60
 
61
  X = vectorizer.transform([text_input])
62
  prediction = model.predict(X)[0]
63
+ proba = model.predict_proba(X).max()
64
 
65
  return f"๐Ÿ”ฎ Prediction: {prediction} (confidence: {proba:.2%})"
66
 
67
+ # New function for batch prediction
68
+ def load_prediction_file(file):
69
+ global df_predict
70
+ if file is None:
71
+ return "โŒ Please upload a prediction file.", gr.update(choices=[], value=None)
72
+ df_predict = pd.read_excel(file.name)
73
+ col_names = list(df_predict.columns)
74
+ return f"โœ… Loaded prediction file with {len(df_predict)} rows", gr.update(choices=col_names, value=col_names[0])
75
+
76
+ def run_batch_prediction(text_column):
77
+ global df_predict, model, vectorizer
78
+ if model is None or vectorizer is None:
79
+ return "โŒ Model is not trained yet."
80
+ if df_predict is None:
81
+ return "โŒ No prediction file loaded."
82
+ if text_column not in df_predict.columns:
83
+ return "โŒ Invalid text column selected."
84
+
85
+ df_filtered = df_predict.dropna(subset=[text_column]).copy()
86
+ X = vectorizer.transform(df_filtered[text_column])
87
+ preds = model.predict(X)
88
+ probs = model.predict_proba(X).max(axis=1)
89
+
90
+ df_filtered["Prediction"] = preds
91
+ df_filtered["Confidence"] = probs
92
+
93
+ # Show top 10 results as preview in the UI
94
+ preview = df_filtered.head(10).to_dict(orient="records")
95
+ return preview
96
+
97
  with gr.Blocks() as demo:
98
  gr.Markdown("# ๐Ÿง  Text Classification App")
99
 
100
  with gr.Row():
101
+ file_input = gr.File(label="Upload Training Excel File (.xlsx)", file_types=[".xlsx"])
102
+ load_button = gr.Button("๐Ÿ“‚ Load Training File")
103
 
104
  status_output = gr.Markdown()
105
  with gr.Row():
 
111
 
112
  with gr.Row():
113
  input_text = gr.Textbox(label="Enter text to classify")
114
+ predict_button = gr.Button("๐Ÿ” Predict Single")
115
 
116
  prediction_output = gr.Markdown()
117
 
118
+ # New part for batch prediction
119
+ with gr.Row():
120
+ pred_file_input = gr.File(label="Upload Prediction Excel File (.xlsx)", file_types=[".xlsx"])
121
+ load_pred_button = gr.Button("๐Ÿ“‚ Load Prediction File")
122
+
123
+ pred_status = gr.Markdown()
124
+ pred_text_column_dropdown = gr.Dropdown(label="Text column for Prediction")
125
+
126
+ batch_pred_button = gr.Button("โšก Run Batch Prediction")
127
+ batch_pred_output = gr.Dataframe(headers=["All columns from input + Prediction + Confidence"], interactive=False)
128
+
129
+ # Link buttons and functions
130
  load_button.click(
131
  fn=load_training_file,
132
  inputs=file_input,
 
145
  outputs=prediction_output
146
  )
147
 
148
+ load_pred_button.click(
149
+ fn=load_prediction_file,
150
+ inputs=pred_file_input,
151
+ outputs=[pred_status, pred_text_column_dropdown]
152
+ )
153
+
154
+ batch_pred_button.click(
155
+ fn=run_batch_prediction,
156
+ inputs=pred_text_column_dropdown,
157
+ outputs=batch_pred_output
158
+ )
159
+
160
  if __name__ == "__main__":
161
  demo.launch()