leygit's picture
Update app.py
59cc262 verified
raw
history blame
9.17 kB
# DistilBERT run 3: added weight_decay=0.01
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer # Converts text into a matrix of token counts
from sklearn.metrics import classification_report, accuracy_score
import gradio as gr
# Choose the compute device up front: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the labelled email dataset (the `text` and `label` columns are used below).
file_path = 'spam_ham_dataset.csv'
df = pd.read_csv(file_path)

# Integer-encode the string labels through pandas category codes
# (0 for ham, 1 for spam, assuming those are the only two label values).
df['label_num'] = df['label'].astype('category').cat.codes

# Tokenize every email once with the uncased DistilBERT tokenizer,
# padding/truncating to at most 128 tokens and returning PyTorch tensors.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
encodings = tokenizer(
    df['text'].tolist(),
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
labels = torch.tensor(df['label_num'].to_numpy())
# Custom Dataset
class SpamDataset(Dataset):
    """Map-style dataset pairing pre-tokenized inputs with integer labels.

    Args:
        encodings: Mapping from tokenizer output name to a tensor indexed by
            example along its first dimension.
        labels: Per-example class labels, indexable by position (a tensor,
            array or list).
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors plus a ``labels`` entry."""
        # Index each tokenizer tensor; values stay PyTorch tensors.
        item = {key: val[idx] for key, val in self.encodings.items()}
        # `self.labels[idx]` may already be a tensor. torch.tensor() on a tensor
        # emits a copy-construct UserWarning for every item. torch.as_tensor
        # accepts tensors, NumPy scalars and ints, and converts to int64.
        # .clone() keeps the returned label independent of `self.labels`,
        # matching torch.tensor's copy semantics.
        item['labels'] = torch.as_tensor(self.labels[idx], dtype=torch.long).clone()
        return item
# Wrap the tokenized emails and their labels in a Dataset.
dataset = SpamDataset(encodings, labels)

# Split 80% train / 20% validation. The fixed seed keeps the validation
# subset, and therefore the metrics shown in the UI, identical across restarts.
# NOTE(review): the checkpoint is loaded from disk, not trained here. This
# split is not guaranteed to match the one used during training, so some
# validation rows may have been seen in training. Confirm against the
# training script.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)
def get_top_words(corpus, n=None):
    """Return the most frequent non-stop-words in ``corpus``.

    Args:
        corpus: Iterable of documents (strings).
        n: Maximum number of entries to return; ``None`` returns every word.

    Returns:
        List of ``(word, count)`` tuples sorted by descending count, using
        scikit-learn's built-in English stop-word list.
    """
    vectorizer = CountVectorizer(stop_words='english')
    # fit_transform builds the vocabulary and the count matrix in one pass.
    # The previous fit() + transform() tokenized the corpus twice and failed
    # on one-shot iterables such as generators.
    counts = vectorizer.fit_transform(corpus).sum(axis=0)
    words_freq = [
        (word, int(counts[0, col])) for word, col in vectorizer.vocabulary_.items()
    ]
    # Stable sort: words with equal counts keep vocabulary order.
    words_freq.sort(key=lambda pair: pair[1], reverse=True)
    return words_freq[:n]
# DataLoader Function (Fix Collate)
def collate_fn(batch):
    """Merge a list of per-example dicts into one dict of stacked tensors.

    The keys come from the first example. Every example must supply a
    same-shaped tensor for each key, because they are combined with
    ``torch.stack`` along a new leading batch dimension.
    """
    field_names = batch[0].keys()
    stacked = {}
    for name in field_names:
        stacked[name] = torch.stack([example[name] for example in batch])
    return stacked
# Batch both splits 8 at a time with the custom collate; only the training
# loader shuffles.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)
# Load the trained model
def load_model(model_path="distilbert_spam_model.pt"):
    """Build a 2-label DistilBERT classifier and load fine-tuned weights into it.

    Args:
        model_path: Path to a saved ``state_dict`` for
            ``DistilBertForSequenceClassification``.

    Returns:
        The model moved to ``device`` and switched to evaluation mode.
    """
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )
    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Load the model once at import time; the functions below use this global.
model = load_model()
# Classification function
def classify_email(email_text):
    """Classify one email with the global model.

    Args:
        email_text: Raw email content.

    Returns:
        Tuple ``(label, confidence)``.
        - ``label`` is ``"Spam"`` when the predicted class index is 1, else ``"Ham"``.
        - ``confidence`` is the highest softmax probability, formatted as a
          percentage string with two decimals.
    """
    model.eval()
    with torch.no_grad():
        # NOTE(review): the module-level dataset tokenization uses
        # max_length=128 while inference truncates at 256; confirm intended.
        encoded = tokenizer(
            email_text,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        logits = model(**model_inputs).logits
        predicted_class = logits.argmax(dim=1).item()
        top_probability = F.softmax(logits, dim=1).max().item()
    label = "Ham" if predicted_class != 1 else "Spam"
    return label, f"{top_probability * 100:.2f}%"
# Evaluation function with detailed classification report
def evaluate_model_with_report(val_loader):
    """Evaluate the global model on ``val_loader`` and print a report.

    Prints the overall accuracy followed by a per-class scikit-learn
    classification report (class 0 shown as "Ham", class 1 as "Spam").

    Args:
        val_loader: Iterable of batches, each a dict of tensors that includes
            a ``labels`` entry (as produced by ``collate_fn``).

    Returns:
        Accuracy as a float; ``0`` when the loader yields no examples.
    """
    model.eval()
    y_true = []
    y_pred = []
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            # Labels were already moved to `device` together with the inputs.
            batch_labels = inputs.pop("labels")
            predictions = torch.argmax(model(**inputs).logits, dim=1)
            y_true.extend(batch_labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())
            correct += (predictions == batch_labels).sum().item()
            total += batch_labels.size(0)

    accuracy = correct / total if total > 0 else 0
    print(f"Validation Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    if total == 0:
        # classification_report cannot summarise an empty evaluation.
        print("No validation examples.")
        return accuracy
    # Fixing labels=[0, 1] keeps the two target_names aligned even when one
    # class is missing from this split. Without it, sklearn raises a
    # class-count mismatch. zero_division=0 reports 0 instead of warning.
    print(classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["Ham", "Spam"],
        zero_division=0,
    ))
    return accuracy
# Performance metrics
def generate_performance_metrics(loader=None):
    """Compute headline validation metrics for the global model.

    Args:
        loader: Batches to evaluate, each a dict of tensors with a ``labels``
            entry. Defaults to the module-level ``val_loader``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1_score``,
        each formatted as a percentage string. Precision, recall and F1 are
        for label ``1`` (spam).
    """
    if loader is None:
        loader = val_loader
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in loader:
            inputs = {key: val.to(device) for key, val in batch.items()}
            # Labels were already moved to `device` together with the inputs.
            batch_labels = inputs.pop("labels")
            predictions = torch.argmax(model(**inputs).logits, dim=1)
            y_true.extend(batch_labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    # labels=[0, 1] guarantees a '1' entry even if the split contains no
    # spam. Without it, report['1'] raised KeyError. zero_division=0 yields 0
    # for undefined ratios instead of a warning.
    report = classification_report(
        y_true, y_pred, labels=[0, 1], output_dict=True, zero_division=0
    )
    spam = report['1']
    return {
        "accuracy": f"{accuracy:.2%}",
        "precision": f"{spam['precision']:.2%}",
        "recall": f"{spam['recall']:.2%}",
        "f1_score": f"{spam['f1-score']:.2%}",
    }
# Gradio Interface
# Intro text for the Demo tab. Kept at column 0 so Markdown does not render
# indented lines as a code block.
_DEMO_INTRO = """
Welcome to the Spam and Phishing Email Detection Demo! This tool leverages DistilBERT, a lightweight yet powerful transformer model, to classify emails as ham (legitimate) or spam (including phishing) based on their content.

To provide a comprehensive overview of the system, the demo is divided into three key sections:

- Detection Demo – Input an email and see real-time classification results.
- Metrics Analysis – Gain insights into the performance of the model, including accuracy, precision, recall, and F1-score.
- Credits – Acknowledging the datasets, tools, and frameworks that made this project possible.

This project aims to enhance email security by identifying malicious messages with high accuracy, reducing the risk of scams and fraud. Feel free to explore the demo and see how AI is improving cybersecurity!
"""


def _build_demo_tab():
    """Add the intro text, the email input box and the classification outputs."""
    gr.Markdown("Spam and Phishing Email Detection")
    gr.Markdown(_DEMO_INTRO)
    email_input = gr.Textbox(
        lines=8, placeholder="Type or paste your email content here...", label="Email Content"
    )
    result_output = gr.Textbox(label="Classification Result")
    confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
    analyze_button = gr.Button("Analyze Email")
    # classify_email already returns (label, confidence) in output order.
    analyze_button.click(
        fn=classify_email,
        inputs=email_input,
        outputs=[result_output, confidence_output],
    )


def _top_words_table(label, n=10):
    """Return the n most frequent words in emails with `label` as a DataFrame."""
    texts = df[df['label'] == label]['text']
    return pd.DataFrame(get_top_words(texts, n=n), columns=["Word", "Count"])


def _build_analysis_tab(performance_metrics):
    """Add the dataset overview, top-word tables and validation metric boxes."""
    gr.Markdown("## Dataset Overview")
    gr.Markdown("### Dataset Headers")
    gr.DataFrame(df)
    gr.Markdown("### Top Spam Words")
    gr.DataFrame(_top_words_table("spam"))
    gr.Markdown("### Top Ham Words")
    gr.DataFrame(_top_words_table("ham"))
    gr.Markdown("## 📊 Model Performance Analytics")
    with gr.Row():
        for key, title in (
            ("accuracy", "Accuracy"),
            ("precision", "Precision"),
            ("recall", "Recall"),
            ("f1_score", "F1 Score"),
        ):
            gr.Textbox(value=performance_metrics[key], label=title, interactive=False)


def create_interface():
    """Build the Gradio app with Demo, Analysis and Glossary tabs.

    Validation metrics are computed once, when the interface is built.

    Returns:
        The top-level ``gr.Blocks`` instance, ready for ``launch()``.
    """
    performance_metrics = generate_performance_metrics()
    # Tab content is added directly to the tabs. The previous nested
    # `with gr.Blocks() as interface` inside the Analysis tab rebound
    # `interface`, so the function returned the inner Analysis-only Blocks.
    with gr.Blocks() as interface:
        with gr.Tab("Demo"):
            _build_demo_tab()
        with gr.Tab("Analysis"):
            _build_analysis_tab(performance_metrics)
        with gr.Tab("Glossary"):
            gr.Markdown(" ## Credits and Reference ")
    return interface
# Launch the interface
# Build the Gradio app and keep a module-level handle to it.
interface = create_interface()
# Serve the app; share=True is forwarded to Gradio's launch().
interface.launch(
    share=True,
)