leygit commited on
Commit
fd64ed7
·
verified ·
1 Parent(s): 6760c84

Rename app2test.py to app.py

Browse files
Files changed (1) hide show
  1. app2test.py → app.py +32 -111
app2test.py → app.py RENAMED
@@ -1,4 +1,4 @@
1
- #DISTILLBERT RUN 3 , added weight_decay=0.01
2
  import pandas as pd
3
  import torch
4
  import torch.nn as nn
@@ -8,77 +8,24 @@ from torch.utils.data import Dataset, DataLoader
8
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
9
  from sklearn.model_selection import train_test_split
10
  from sklearn.metrics import classification_report
11
- from transformers import BertTokenizer
12
-
13
- # Load dataset
14
- file_path = 'spam_ham_dataset.csv'
15
- df = pd.read_csv(file_path)
16
 
17
- # Convert labels to numeric
18
- df['label_num'] = df['label'].map({'ham': 0, 'spam': 1})
19
 
20
  # Load tokenizer
21
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
22
-
23
- # Tokenize dataset
24
- encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
25
- labels = torch.tensor(df['label_num'].values)
26
-
27
- # Custom Dataset
28
- class SpamDataset(Dataset):
29
- def __init__(self, encodings, labels):
30
- self.encodings = encodings
31
- self.labels = labels
32
-
33
- def __len__(self):
34
- return len(self.labels)
35
-
36
- def __getitem__(self, idx):
37
- item = {key: val[idx] for key, val in self.encodings.items()}
38
- item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
39
- return item
40
-
41
- # Create dataset
42
- dataset = SpamDataset(encodings, labels)
43
-
44
- # Split dataset (80% train, 20% validation)
45
- train_size = int(0.8 * len(dataset))
46
- val_size = len(dataset) - train_size
47
- train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
48
-
49
- # DataLoader with batch size
50
- def collate_fn(batch):
51
- keys = batch[0].keys()
52
- return {key: torch.stack([b[key] for b in batch]) for key in keys}
53
-
54
- train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
55
- val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
56
 
57
  # Load the trained model
58
  def load_model(model_path="distilbert_spam_model.pt"):
59
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
60
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model weights
 
61
  model.eval() # Set model to evaluation mode
62
  return model
63
 
64
- # Evaluation
65
- model.eval()
66
- correct = 0
67
- total = 0
68
- with torch.no_grad():
69
- for batch in val_loader:
70
- inputs = {key: val.to(device) for key, val in batch.items()}
71
- labels = inputs.pop("labels").to(device)
72
-
73
- outputs = model(**inputs)
74
- predictions = torch.argmax(outputs.logits, dim=1)
75
- correct += (predictions == labels).sum().item()
76
- total += labels.size(0)
77
-
78
- accuracy = correct / total
79
- print(f"Validation Accuracy: {accuracy:.4f}")
80
-
81
-
82
 
83
  # Classification function
84
  def classify_email(email_text):
@@ -144,73 +91,47 @@ def evaluate_model_with_report(val_loader):
144
 
145
  return accuracy
146
 
147
- # Run evaluation with classification report
148
- accuracy = evaluate_model_with_report(val_loader)
149
- print(f"Model Validation Accuracy: {accuracy:.4f}")
150
-
151
- ## Gradio Interface
 
 
 
 
152
 
153
- import gradio as gr
154
 
155
- # Create Gradio Interface
156
  def create_interface():
157
- performance_metrics = generate_performance_metrics()
158
-
159
- # Introduction - Title + Brief Description
160
- with gr.Blocks(css=custom_css) as interface:
161
  gr.Markdown("Spam Email Classification")
162
- gr.Markdown(
163
- """
164
- Brief description of the project here
165
- """
166
- )
167
 
168
  # Email Text Input
169
- with gr.Row():
170
- email_input = gr.Textbox(
171
- lines=8, placeholder="Type or paste your email content here...", label="Email Content"
172
- )
173
 
174
  # Email Text Results and Analysis
175
- with gr.Row():
176
- result_output = gr.HTML(label="Classification Result") # label = [function that prints classification result]
177
- confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
178
- accuracy_output = gr.Textbox(label="Accuracy", interactive=False)
179
-
180
 
181
  analyze_button = gr.Button("Analyze Email 🕵️‍♂️")
182
 
183
  analyze_button.click(
184
- fn=email_analysis_pipeline,
185
  inputs=email_input,
186
  outputs=[result_output, confidence_output, accuracy_output]
187
  )
188
 
189
- # Analysis
190
  gr.Markdown("## 📊 Model Performance Analytics")
191
  with gr.Row():
192
- with gr.Column():
193
- gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False, elem_classes=["metric"])
194
- gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False, elem_classes=["metric"])
195
- gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False, elem_classes=["metric"])
196
- gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False, elem_classes=["metric"])
197
- with gr.Column():
198
- gr.Markdown("### Confusion Matrix")
199
- gr.HTML(f"<img src='data:image/png;base64,{performance_metrics['confusion_matrix_plot']}' style='max-width: 100%; height: auto;' />")
200
-
201
- gr.Markdown("## 📘 Glossary and Explanation of Labels")
202
- gr.Markdown(
203
- """
204
- ### Labels:
205
- - **Spam:** Unwanted or harmful emails flagged by the system.
206
- - **Ham:** Legitimate, safe emails.
207
- ### Metrics:
208
- - **Accuracy:** The percentage of correct classifications.
209
- - **Precision:** Out of predicted Spam, how many are actually Spam.
210
- - **Recall:** Out of all actual Spam emails, how many are predicted as Spam.
211
- - **F1 Score:** Harmonic mean of Precision and Recall.
212
- """
213
- )
214
 
215
  return interface
216
 
 
1
+ # DISTILLBERT RUN 3 , added weight_decay=0.01
2
  import pandas as pd
3
  import torch
4
  import torch.nn as nn
 
8
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
9
  from sklearn.model_selection import train_test_split
10
  from sklearn.metrics import classification_report
11
+ import gradio as gr
 
 
 
 
12
 
13
+ # Define device
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Load tokenizer
17
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Load the trained model
20
  def load_model(model_path="distilbert_spam_model.pt"):
21
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
22
+ model.load_state_dict(torch.load(model_path, map_location=device)) # Load model weights
23
+ model.to(device)
24
  model.eval() # Set model to evaluation mode
25
  return model
26
 
27
+ # Load model globally
28
+ model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Classification function
31
  def classify_email(email_text):
 
91
 
92
  return accuracy
93
 
94
+ # Performance metrics
95
+ def generate_performance_metrics():
96
+ return {
97
+ "accuracy": "N/A",
98
+ "precision": "N/A",
99
+ "recall": "N/A",
100
+ "f1_score": "N/A",
101
+ "confusion_matrix_plot": "",
102
+ }
103
 
104
+ performance_metrics = generate_performance_metrics()
105
 
106
+ # Gradio Interface
107
  def create_interface():
108
+ with gr.Blocks() as interface:
 
 
 
109
  gr.Markdown("Spam Email Classification")
 
 
 
 
 
110
 
111
  # Email Text Input
112
+ email_input = gr.Textbox(
113
+ lines=8, placeholder="Type or paste your email content here...", label="Email Content"
114
+ )
 
115
 
116
  # Email Text Results and Analysis
117
+ result_output = gr.Textbox(label="Classification Result")
118
+ confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
119
+ accuracy_output = gr.Textbox(label="Accuracy", interactive=False)
 
 
120
 
121
  analyze_button = gr.Button("Analyze Email 🕵️‍♂️")
122
 
123
  analyze_button.click(
124
+ fn=classify_email,
125
  inputs=email_input,
126
  outputs=[result_output, confidence_output, accuracy_output]
127
  )
128
 
 
129
  gr.Markdown("## 📊 Model Performance Analytics")
130
  with gr.Row():
131
+ gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False)
132
+ gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False)
133
+ gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False)
134
+ gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  return interface
137