|
|
|
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.feature_extraction.text import CountVectorizer |
|
|
from sklearn.metrics import classification_report, accuracy_score |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# --- Data loading ------------------------------------------------------------
file_path = 'spam_ham_dataset.csv'
df = pd.read_csv(file_path)

# Turn the string labels into integer category codes. The rest of this file
# compares df['label'] against "ham" and "spam"; with those two values the
# codes are presumably ham=0 / spam=1 (pandas sorts inferred categories).
df['label_num'] = df['label'].astype('category').cat.codes

# --- Runtime device ----------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Tokenization ------------------------------------------------------------
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# The whole corpus is tokenized eagerly, truncated to 128 tokens per text and
# returned as PyTorch tensors.
encodings = tokenizer(
    df['text'].tolist(),
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
labels = torch.tensor(df['label_num'].values)
|
|
|
|
|
|
|
|
class SpamDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels.

    Args:
        encodings: Mapping from field name to an indexable container; each
            value is indexed with the sample index in ``__getitem__``.
        labels: Indexable collection of integer class labels, one per sample.
            May be a tensor or a plain Python sequence.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of its fields plus ``'labels'``.

        The ``'labels'`` entry is always an independent ``torch.long`` tensor.
        """
        item = {key: val[idx] for key, val in self.encodings.items()}
        # torch.tensor() on an existing tensor emits a copy-construct
        # UserWarning for every sample. as_tensor() converts without the
        # warning and also accepts plain ints; clone() keeps the returned
        # label independent of the stored labels even when no dtype
        # conversion was needed.
        item['labels'] = torch.as_tensor(self.labels[idx], dtype=torch.long).clone()
        return item
|
|
|
|
|
|
|
|
dataset = SpamDataset(encodings, labels)

# 80/20 train/validation split; the validation part takes the remainder so the
# two sizes always add up to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the validation subset, and therefore every metric computed
# on it, is the same on each run instead of changing with the global RNG state.
# NOTE(review): if the checkpoint was trained on a different split, validation
# samples may overlap its training data; confirm against the training script.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)
|
|
|
|
|
def get_top_words(corpus, n=None):
    """Return the most frequent non-stop-word tokens in ``corpus``.

    Args:
        corpus: Iterable of text documents.
        n: Number of (word, count) pairs to return; ``None`` returns all of
            them.

    Returns:
        List of ``(word, count)`` tuples sorted by descending count, where
        ``count`` is a plain ``int`` total over the whole corpus.
    """
    vectorizer = CountVectorizer(stop_words='english')
    # fit_transform builds the vocabulary and the count matrix in one pass,
    # instead of tokenizing the corpus twice with fit() then transform(). It
    # also works when corpus is a one-shot iterable.
    bag_of_words = vectorizer.fit_transform(corpus)
    totals = bag_of_words.sum(axis=0)
    # int() turns the NumPy scalar into a plain Python int for display.
    words_freq = [
        (word, int(totals[0, col]))
        for word, col in vectorizer.vocabulary_.items()
    ]
    words_freq.sort(key=lambda pair: pair[1], reverse=True)
    return words_freq[:n]
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Combine a list of per-sample dicts into one dict of stacked tensors.

    Args:
        batch: Non-empty list of dicts that share the keys of the first
            sample; each value must be a tensor accepted by ``torch.stack``.

    Returns:
        Dict with the first sample's keys, each mapped to the samples' tensors
        stacked along a new leading dimension.
    """
    field_names = batch[0].keys()
    stacked = {}
    for name in field_names:
        per_sample = [sample[name] for sample in batch]
        stacked[name] = torch.stack(per_sample)
    return stacked
|
|
|
|
|
|
|
|
# Batches of 8 samples, assembled by collate_fn. Only the training loader
# shuffles; the validation loader keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)
|
|
|
|
|
|
|
|
def load_model(model_path="distilbert_spam_model.pt"):
    """Build a two-label DistilBERT classifier and load saved weights into it.

    Args:
        model_path: Path to a state dict saved for this architecture.

    Returns:
        The model, moved to the module-level ``device`` and in eval mode.
    """
    classifier = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=2,
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state_dict = torch.load(model_path, map_location=device)
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    classifier.eval()
    return classifier
|
|
|
|
|
|
|
|
# Module-wide classifier instance, loaded from the default checkpoint path; the
# inference and evaluation functions in this file read it as a global.
model = load_model()
|
|
|
|
|
|
|
|
def classify_email(email_text):
    """Classify one email with the module-level model and tokenizer.

    Args:
        email_text: Raw email text handed to the tokenizer.

    Returns:
        Tuple ``(label, confidence)``: ``label`` is "Spam" when the predicted
        class index is 1 and "Ham" otherwise; ``confidence`` is the largest
        softmax probability as a percentage string with two decimals.
    """
    model.eval()

    with torch.no_grad():
        # NOTE(review): max_length is 256 here while the dataset encodings in
        # this file are truncated to 128; confirm which length the checkpoint
        # was trained with.
        encoded = tokenizer(
            email_text,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        logits = model(**encoded).logits
        predicted = torch.argmax(logits, dim=1)
        class_probs = F.softmax(logits, dim=1)
        confidence = torch.max(class_probs).item() * 100

    if predicted.item() == 1:
        result = "Spam"
    else:
        result = "Ham"
    return result, f"{confidence:.2f}%"
|
|
|
|
|
|
|
|
def evaluate_model_with_report(val_loader):
    """Evaluate the module-level model on ``val_loader`` and print a report.

    Prints the validation accuracy followed by a per-class classification
    report with classes 0 and 1 named "Ham" and "Spam".

    Args:
        val_loader: Iterable of batches; each batch is a dict of tensors that
            includes a ``"labels"`` entry.

    Returns:
        Fraction of correctly classified samples, or 0 when the loader yields
        no samples.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            # Already moved to the device together with the other fields.
            labels = inputs.pop("labels")

            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(predictions.cpu().tolist())

    total = len(y_true)
    correct = sum(truth == guess for truth, guess in zip(y_true, y_pred))
    accuracy = correct / total if total > 0 else 0
    print(f"Validation Accuracy: {accuracy:.4f}")

    print("\nClassification Report:")
    if total > 0:
        # Pin labels to both classes: without it the report raises when only
        # one class shows up in the data, because target_names would no longer
        # match the labels that were found.
        print(
            classification_report(
                y_true,
                y_pred,
                labels=[0, 1],
                target_names=["Ham", "Spam"],
                zero_division=0,
            )
        )
    else:
        print("No validation samples to report on.")

    return accuracy
|
|
|
|
|
|
|
|
def generate_performance_metrics():
    """Compute headline metrics for the model on the module-level val_loader.

    Returns:
        Dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1_score"``, each a percentage string with two decimals. Precision,
        recall and F1 are the values reported for class 1.
    """
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            # Already moved to the device together with the other fields.
            labels = inputs.pop("labels")

            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(predictions.cpu().tolist())

    accuracy = accuracy_score(y_true, y_pred)
    # Pin labels to both classes so the '1' entry always exists; otherwise the
    # lookup below raises KeyError when class 1 is in neither the ground truth
    # nor the predictions. zero_division=0 reports 0 for an empty class instead
    # of warning.
    report = classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        output_dict=True,
        zero_division=0,
    )
    class_one = report['1']

    return {
        "accuracy": f"{accuracy:.2%}",
        "precision": f"{class_one['precision']:.2%}",
        "recall": f"{class_one['recall']:.2%}",
        "f1_score": f"{class_one['f1-score']:.2%}",
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio app with its Demo, Analysis and Glossary tabs.

    Calls ``generate_performance_metrics()`` once while building, and reads
    the module-level ``df`` for the dataset tables and top-word lists.

    Returns:
        The top-level ``gr.Blocks`` app; the caller is expected to launch it.
    """
    performance_metrics = generate_performance_metrics()

    # Built from unindented pieces so Markdown never sees leading whitespace,
    # which it would render as a code block.
    intro_text = (
        "\n"
        "Welcome to the Spam and Phishing Email Detection Demo! This tool "
        "leverages DistilBERT, a lightweight yet powerful transformer model, "
        "to classify emails as ham (legitimate), spam, or phishing based on "
        "their content.\n"
        "\n"
        "To provide a comprehensive overview of the system, the demo is "
        "divided into three key sections:\n"
        "\n"
        "Detection Demo – Input an email and see real-time classification "
        "results.\n"
        "\n"
        "Metrics Analysis – Gain insights into the performance of the model, "
        "including accuracy, precision, recall, and F1-score.\n"
        "\n"
        "Credits – Acknowledging the datasets, tools, and frameworks that "
        "made this project possible.\n"
        "This project aims to enhance email security by identifying malicious "
        "messages with high accuracy, reducing the risk of scams and fraud. "
        "Feel free to explore the demo and see how AI is improving "
        "cybersecurity!\n"
    )

    with gr.Blocks() as interface:
        with gr.Tab("Demo"):
            gr.Markdown("Spam and Phishing Email Detection")
            gr.Markdown(intro_text)

            email_input = gr.Textbox(
                lines=8,
                placeholder="Type or paste your email content here...",
                label="Email Content",
            )

            result_output = gr.Textbox(label="Classification Result")
            confidence_output = gr.Textbox(label="Confidence Score", interactive=False)

            analyze_button = gr.Button("Analyze Email")

            # classify_email returns a (label, confidence) tuple, which maps
            # directly onto the two output boxes. The previous unused wrapper
            # indexed that tuple with string keys and was removed.
            analyze_button.click(
                fn=classify_email,
                inputs=email_input,
                outputs=[result_output, confidence_output],
            )

        with gr.Tab("Analysis"):
            # No nested gr.Blocks() here: opening a second Blocks "as
            # interface" rebound the name, so the function returned the inner
            # container instead of the whole app.
            gr.Markdown("## Dataset Overview")
            gr.Markdown("### Dataset Headers")
            gr.DataFrame(df)

            gr.Markdown("### Top Spam Words")
            top_spam_words = get_top_words(df[df['label'] == "spam"]['text'], n=10)
            gr.DataFrame(top_spam_words)

            gr.Markdown("### Top Ham Words")
            top_ham_words = get_top_words(df[df['label'] == "ham"]['text'], n=10)
            gr.DataFrame(top_ham_words)

            gr.Markdown("## 📊 Model Performance Analytics")
            with gr.Row():
                gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False)
                gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False)
                gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False)
                gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False)

        with gr.Tab("Glossary"):
            gr.Markdown(" ## Credits and Reference ")

    return interface
|
|
|
|
|
|
|
|
# Built at import time so `interface` stays available as a module attribute.
interface = create_interface()

# Start the web server only when this file is run as a script, so importing
# the module (for tests or reuse) does not launch it.
if __name__ == "__main__":
    # share=True asks Gradio to create a publicly reachable share link.
    interface.launch(share=True)
|
|
|