leygit committed
Commit ba8b794 · verified · 1 Parent(s): c7e6ce7

Update app.py

Files changed (1)
  1. app.py +17 -9
app.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch.utils.data import Dataset, DataLoader
 from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import classification_report
+from sklearn.metrics import classification_report, accuracy_score
 import gradio as gr
 
 # Define device
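
The import line gains accuracy_score because the rewritten generate_performance_metrics below calls it alongside classification_report. A toy check of both helpers; the labels here are illustrative, not from the project's data:

from sklearn.metrics import classification_report, accuracy_score

y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]

print(accuracy_score(y_true, y_pred))  # 0.75
report = classification_report(y_true, y_pred, output_dict=True)
print(f"{report['1']['f1-score']:.2%}")  # 66.67%, same formatting app.py uses
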
@@ -93,8 +93,18 @@ def evaluate_model_with_report(val_loader):
 
 # Performance metrics
 def generate_performance_metrics():
-    y_pred = model.predict(X_test)
-    accuracy = accuracy_score(y_test,y_pred)
+    model.eval()  # Set model to evaluation mode
+
+    y_pred = []
+
+    with torch.no_grad():
+        for email in X_test:
+            inputs = tokenizer(email, padding=True, truncation=True, max_length=128, return_tensors="pt")
+            outputs = model(**inputs)
+            prediction = torch.argmax(outputs.logits, dim=1).item()
+            y_pred.append(prediction)
+
+    accuracy = accuracy_score(y_test, y_pred)
     report = classification_report(y_test, y_pred, output_dict=True)
 
     return {
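
The new loop tokenizes and classifies one email at a time, which is correct but slow on large test sets. A minimal batched sketch of the same prediction step, assuming the diff's tokenizer, model, and X_test are in scope; batch_size is an illustrative parameter, not from the commit:

import torch

def predict_in_batches(texts, batch_size=32):
    # tokenizer and model are assumed to exist as in app.py
    model.eval()
    preds = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = list(texts[i:i + batch_size])
            enc = tokenizer(batch, padding=True, truncation=True,
                            max_length=128, return_tensors="pt")
            logits = model(**enc).logits
            preds.extend(torch.argmax(logits, dim=1).tolist())
    return preds

# Drop-in for the diff's loop: y_pred = predict_in_batches(X_test)
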
@@ -104,6 +114,7 @@ def generate_performance_metrics():
         "f1_score": f"{report['1']['f1-score']:.2%}",
     }
 
+
 # Gradio Interface
 
 def create_interface():
@@ -127,16 +138,13 @@ def create_interface():
             results = classify_email(email_text)
             return (
                 results["result"],
-                results["confidence"],
-                results["highlighted"],
-                results["spammy_keywords"],
-                results["advice"]
+                results["confidence"]
             )
 
         analyze_button.click(
             fn=classify_email,
             inputs=email_input,
-            outputs=[result_output, confidence_output, accuracy_output]
+            outputs=[result_output, confidence_output]
         )
 
         gr.Markdown("## 📊 Model Performance Analytics")
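
After this hunk, classify_email feeds exactly two components. A hypothetical, self-contained sketch of that wiring; the component names follow the diff, and classify_email's body is a placeholder for the real model call:

import gradio as gr

def classify_email(email_text):
    # Placeholder: app.py's real implementation runs the DistilBERT model.
    return "Spam", "97.2%"

with gr.Blocks() as interface:
    email_input = gr.Textbox(label="Email text")
    analyze_button = gr.Button("Analyze")
    result_output = gr.Textbox(label="Result")
    confidence_output = gr.Textbox(label="Confidence")
    analyze_button.click(
        fn=classify_email,
        inputs=email_input,
        outputs=[result_output, confidence_output],
    )
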
@@ -146,7 +154,7 @@ def create_interface():
         gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False)
         gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False)
 
-        return interface
+    return interface
 
 # Launch the interface
 interface = create_interface()
 