"""Email spam classifier.

Trains a Multinomial Naive Bayes model on a labeled email CSV, persists the
model/vectorizer, and serves a Gradio UI for classification, performance
analytics, and incremental retraining.
"""

import base64
import re
from io import BytesIO

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

# Load and preprocess dataset (malformed CSV rows are skipped).
dataset = pd.read_csv('Email_spam_niki.csv', on_bad_lines='skip', engine='python')

# Drop rows where 'spam' or 'text' is NaN and coerce 'spam' to numeric.
dataset.dropna(subset=['spam', 'text'], inplace=True)
dataset['spam'] = pd.to_numeric(dataset['spam'], errors='coerce')

# Remove rows where the coercion failed, then store labels as integers
# (1 = spam, 0 = ham — see save_and_retrain below).
dataset.dropna(subset=['spam'], inplace=True)
dataset['spam'] = dataset['spam'].astype(int)

# Vectorize the text data (bag-of-words counts).
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(dataset['text'])
y = dataset['spam']

# Split the data into training and testing sets.
# random_state pins the split so the reported metrics are reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the Naive Bayes model.
model = MultinomialNB()
model.fit(X_train, y_train)

# Persist the model and vectorizer, then reload for consistency.
joblib.dump(model, 'spam_model.pkl')
joblib.dump(vectorizer, 'spam_vectorizer.pkl')
model = joblib.load('spam_model.pkl')
vectorizer = joblib.load('spam_vectorizer.pkl')

# Keywords commonly found in spam. Used only for UI highlighting and the
# "keywords detected" readout — the model itself does not consult this list.
spam_keywords = [
    "win", "free", "urgent", "money", "credit", "loan", "offer", "buy now",
    "limited time", "click here", "guaranteed", "congratulations", "winner"
]


def highlight_keywords(text):
    """Return *text* with whole-word spam-keyword matches wrapped in a
    highlight ``<span>`` (case-insensitive, original casing preserved).

    BUG FIX: the original substituted the bare keyword for each match, which
    produced no visible markup (and lowercased the matched text) even though
    the page CSS defines a ``.highlight`` class for exactly this purpose.
    ``re.escape`` guards against keywords containing regex metacharacters.
    """
    highlighted = text
    for keyword in spam_keywords:
        pattern = re.compile(rf"(\b{re.escape(keyword)}\b)", re.IGNORECASE)
        highlighted = pattern.sub(r"<span class='highlight'>\1</span>", highlighted)
    return highlighted


def classify_email(email_text):
    """Classify a single email and return a dict of display-ready fields.

    Keys: ``result`` (emoji + Spam/Ham), ``confidence`` (percentage string),
    ``highlighted`` (HTML with keywords marked), ``spammy_keywords``
    (comma-separated substring matches), and ``advice``.
    """
    email_vector = vectorizer.transform([email_text])
    prediction = model.predict(email_vector)
    # Confidence = probability of the winning class, as a percentage.
    confidence = model.predict_proba(email_vector).max() * 100
    result = "Spam" if prediction[0] == 1 else "Ham"
    highlighted_text = highlight_keywords(email_text)
    emoji = "📧" if result == "Ham" else "⚠️"
    advice = "Be careful! This might be a scam." if result == "Spam" else "This email seems safe."
    return {
        "result": f"{emoji} {result}",
        "confidence": f"{confidence:.2f}%",
        "highlighted": highlighted_text,
        # NOTE: substring match (not word-boundary) by design — catches
        # keywords embedded in longer phrases.
        "spammy_keywords": ", ".join(
            [kw for kw in spam_keywords if kw.lower() in email_text.lower()]
        ),
        "advice": advice
    }


def generate_performance_metrics():
    """Evaluate the model on the held-out test set.

    Returns a dict of formatted accuracy/precision/recall/F1 strings plus a
    base64-encoded PNG of the confusion matrix for inline HTML embedding.
    """
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    report = classification_report(y_test, y_pred, output_dict=True)

    # Confusion matrix plot.
    fig, ax = plt.subplots(figsize=(6, 6))
    ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=ax, cmap='Blues')
    plt.title("Confusion Matrix")
    plt.tight_layout()

    # Serialize the plot as a base64 string.
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)  # BUG FIX: release the figure so repeated calls don't leak memory
    buf.seek(0)
    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    buf.close()

    # '1' is the positive (spam) class label in the classification report.
    return {
        "accuracy": f"{accuracy:.2%}",
        "precision": f"{report['1']['precision']:.2%}",
        "recall": f"{report['1']['recall']:.2%}",
        "f1_score": f"{report['1']['f1-score']:.2%}",
        "confusion_matrix_plot": img_base64,
    }


def save_and_retrain(email_text, label):
    """Append one labeled email to the dataset, refit the vectorizer and
    model on the full corpus, and persist both.

    Returns a status message for the UI (errors are reported, not raised).
    """
    try:
        # Convert label to numeric value (0 for Ham, 1 for Spam).
        label_numeric = 1 if label == "Spam" else 0

        # Add the new data to the dataset.
        new_data = pd.DataFrame({"text": [email_text], "spam": [label_numeric]})
        global dataset, X, y, model, vectorizer
        dataset = pd.concat([dataset, new_data], ignore_index=True)

        # Re-fit the vectorizer on the updated corpus (vocabulary may grow).
        X = vectorizer.fit_transform(dataset['text'])
        y = dataset['spam']

        # Retrain the model on all data.
        model.fit(X, y)

        # Save the updated model and vectorizer.
        joblib.dump(model, 'spam_model.pkl')
        joblib.dump(vectorizer, 'spam_vectorizer.pkl')
        return "Model retrained successfully with new data!"
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return f"Error while retraining: {str(e)}"


# Custom CSS for the Gradio interface (defines the .highlight class used by
# highlight_keywords and the .metric class used by the analytics boxes).
custom_css = """
body {
    font-family: 'Arial', sans-serif;
    background-image: url('https://cdn.pixabay.com/photo/2016/11/19/15/26/email-1839873_1280.jpg');
    background-size: cover;
    background-position: center;
    background-attachment: fixed;
    color: #333;
}
h1, h2, h3 {
    text-align: center;
    color: #ffffff;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.7);
}
.gradio-container {
    background-color: rgba(255, 255, 255, 0.8);
    border-radius: 10px;
    padding: 20px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.3);
}
button {
    background-color: #1e90ff;
    color: white;
    padding: 10px 20px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    font-size: 1.2em;
    transition: transform 0.2s, background-color 0.3s;
}
button:hover {
    background-color: #1c86ee;
    transform: scale(1.05);
}
.highlight {
    background-color: #ffeb3b;
    font-weight: bold;
    padding: 0 3px;
    border-radius: 3px;
}
.metric {
    font-size: 1.2em;
    text-align: center;
    color: #ffffff;
    background-color: #4CAF50;
    border-radius: 8px;
    padding: 10px;
    margin: 10px 0;
    box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.2);
}
"""


def create_interface():
    """Build and return the Gradio Blocks UI (classification, analytics,
    retraining, and glossary sections)."""
    performance_metrics = generate_performance_metrics()

    with gr.Blocks(css=custom_css) as interface:
        gr.Markdown("# 📩 Advanced Email Spam Classifier")
        gr.Markdown(
            """
            ### Enter the content of an email below to classify it as Spam or Ham.
            The tool uses **machine learning** to analyze email content, highlights spammy keywords, and shows key performance analytics.
            """
        )
        with gr.Row():
            with gr.Column():
                email_input = gr.Textbox(
                    lines=8,
                    placeholder="Type or paste your email content here...",
                    label="Email Content"
                )
            with gr.Column():
                result_output = gr.HTML(label="Classification Result")
                confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
                highlighted_output = gr.HTML(label="Highlighted Text")
                keywords_output = gr.Textbox(label="Spam Keywords Detected", interactive=False)
                advice_output = gr.HTML(label="Advice")
        analyze_button = gr.Button("Analyze Email 🕵️‍♂️")

        def email_analysis_pipeline(email_text):
            # Adapter: unpack classify_email's dict into the output widgets.
            results = classify_email(email_text)
            return (
                results["result"],
                results["confidence"],
                results["highlighted"],
                results["spammy_keywords"],
                results["advice"]
            )

        analyze_button.click(
            fn=email_analysis_pipeline,
            inputs=email_input,
            outputs=[result_output, confidence_output, highlighted_output, keywords_output, advice_output]
        )

        gr.Markdown("## 📊 Model Performance Analytics")
        with gr.Row():
            with gr.Column():
                gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False, elem_classes=["metric"])
                gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False, elem_classes=["metric"])
                gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False, elem_classes=["metric"])
                gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False, elem_classes=["metric"])
            with gr.Column():
                gr.Markdown("### Confusion Matrix")
                # BUG FIX: the original rendered gr.HTML(f"") — an empty
                # string — discarding the base64 PNG that
                # generate_performance_metrics produced. Embed it inline.
                gr.HTML(
                    f"<img src='data:image/png;base64,{performance_metrics['confusion_matrix_plot']}' alt='Confusion Matrix'/>"
                )

        gr.Markdown("## 🛠️ Save and Retrain the Model")
        with gr.Row():
            email_for_retraining = gr.Textbox(
                lines=8,
                placeholder="Enter the email content to label as Spam or Ham and retrain",
                label="Email Content"
            )
            label_input = gr.Radio(["Spam", "Ham"], label="Label", type="value")
        retrain_button = gr.Button("Save & Retrain Model")
        retrain_result = gr.Textbox(label="Retrain Result", interactive=False)
        retrain_button.click(
            fn=save_and_retrain,
            inputs=[email_for_retraining, label_input],
            outputs=retrain_result
        )

        gr.Markdown("## 📘 Glossary and Explanation of Labels")
        gr.Markdown(
            """
            ### Labels:
            - **Spam:** Unwanted or harmful emails flagged by the system.
            - **Ham:** Legitimate, safe emails.

            ### Confusion Matrix:
            The confusion matrix shows the performance of the model by comparing the true labels with the predicted ones. It consists of:
            - **True Positives (TP):** Correctly predicted spam emails.
            - **True Negatives (TN):** Correctly predicted ham emails.
            - **False Positives (FP):** Ham emails incorrectly predicted as spam.
            - **False Negatives (FN):** Spam emails incorrectly predicted as ham.

            ### Metrics:
            - **Accuracy:** The percentage of correct classifications.
            - **Precision:** Out of predicted Spam, how many are actually Spam.
            - **Recall:** Out of all actual Spam emails, how many are predicted as Spam.
            - **F1 Score:** Harmonic mean of Precision and Recall.
            """
        )

    return interface


# Launch the interface (script entry point; share=True exposes a public URL).
interface = create_interface()
interface.launch(share=True)