Upload 46 files
Browse files- .gitattributes +4 -0
- A24-Y2-DEEP-LEARNING-Project.pdf +3 -0
- Deep Learning Project Report.docx +3 -0
- Deep Learning Project Report.pdf +3 -0
- README.md +31 -10
- eda_plots.png +3 -0
- mail_data_test.csv +9 -0
- project_report.md +114 -0
- requirements.txt +9 -0
- results.txt +17 -0
- save_tokenizer.py +1 -0
- train_model.py +152 -0
- train_model_hf.py +1 -1
.gitattributes
CHANGED
|
@@ -37,3 +37,7 @@ Deep_Learning_Project/A24-Y2-DEEP-LEARNING-Project.pdf filter=lfs diff=lfs merge
|
|
| 37 |
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.docx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
Deep_Learning_Project/eda_plots.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.docx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
Deep_Learning_Project/eda_plots.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
A24-Y2-DEEP-LEARNING-Project.pdf filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.docx filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
eda_plots.png filter=lfs diff=lfs merge=lfs -text
|
A24-Y2-DEEP-LEARNING-Project.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d248954892441dbf8de6cb3c8315718e020879401296dd7d1597cd82fe40dce2
|
| 3 |
+
size 230345
|
Deep Learning Project Report.docx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccbe55fe11859c664c37d29a179ce14404ad4084a63ad430daff5aff2ae56da0
|
| 3 |
+
size 236057
|
Deep Learning Project Report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1381d9fa4aa0351d88ff4c151941d46d177e87ba033d0947bffce069fdb251f3
|
| 3 |
+
size 357842
|
README.md
CHANGED
|
@@ -1,10 +1,31 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
-
|
| 7 |
-
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning Project: Spam Detection with DistilBERT
|
| 2 |
+
|
| 3 |
+
This repository contains the code and resources for the Deep Learning project on Spam Detection.
|
| 4 |
+
|
| 5 |
+
## Project Structure
|
| 6 |
+
- `mail_data.csv`: The dataset used for training and evaluation.
|
| 7 |
+
- `eda_script.py`: Script for Exploratory Data Analysis and visualization.
|
| 8 |
+
- `train_model_hf.py`: Main training script using Hugging Face Trainer and DistilBERT.
|
| 9 |
+
- `evaluate_final.py`: Script for final evaluation from the best model checkpoint.
|
| 10 |
+
- `eda_plots.png`: Visualizations generated during EDA.
|
| 11 |
+
- `results.txt`: Detailed evaluation metrics and confusion matrix.
|
| 12 |
+
- `Deep_Learning_Project_Report.pdf`: The final project report (15-17 pages equivalent).
|
| 13 |
+
|
| 14 |
+
## Requirements
|
| 15 |
+
- Python 3.11+
|
| 16 |
+
- PyTorch
|
| 17 |
+
- Transformers
|
| 18 |
+
- Datasets
|
| 19 |
+
- Scikit-learn
|
| 20 |
+
- Pandas
|
| 21 |
+
- Matplotlib
|
| 22 |
+
- Seaborn
|
| 23 |
+
- Accelerate
|
| 24 |
+
|
| 25 |
+
## How to Run
|
| 26 |
+
1. **EDA**: Run `python3 eda_script.py` to see the data distribution.
|
| 27 |
+
2. **Training**: Run `python3 train_model_hf.py` to fine-tune the DistilBERT model.
|
| 28 |
+
3. **Evaluation**: Run `python3 evaluate_final.py` to get the final performance metrics.
|
| 29 |
+
|
| 30 |
+
## Results
|
| 31 |
+
The model achieves **99.10% accuracy** on the test set with an **F1-score of 96.58%** for the spam class.
|
eda_plots.png
ADDED
|
Git LFS Details
|
mail_data_test.csv
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ham,"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..."
|
| 2 |
+
ham,Ok lar... Joking wif u oni...
|
| 3 |
+
spam,Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
|
| 4 |
+
ham,U dun say so early hor... U c already then say...
|
| 5 |
+
ham,"Nah I don't think he goes to usf, he lives around here though"
|
| 6 |
+
spam,"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv"
|
| 7 |
+
ham,Even my brother is not like to speak with me. They treat me like aids patent.
|
| 8 |
+
ham,As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
|
| 9 |
+
spam,WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
|
project_report.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning Project: Spam Detection using Transformers
|
| 2 |
+
|
| 3 |
+
**Course**: Deep Learning with Python (2025)
|
| 4 |
+
**Instructor**: Benoit Mialet
|
| 5 |
+
**Topic**: NLP - Text Classification (Spam vs Ham)
|
| 6 |
+
**Model**: DistilBERT (PyTorch / Hugging Face)
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## 1. Introduction
|
| 11 |
+
### 1.1 What & Why
|
| 12 |
+
The objective of this project is to develop a robust deep learning model for classifying emails as either "spam" or "ham" (legitimate). Email filtering is a critical application of Natural Language Processing (NLP) that helps improve user experience and security by automatically identifying unsolicited or malicious content.
|
| 13 |
+
|
| 14 |
+
### 1.2 Task Selection
|
| 15 |
+
We chose the **Text Classification** task, specifically binary classification. This task is well-suited for demonstrating the power of Transfer Learning and Transformer architectures in understanding the nuances of human language.
|
| 16 |
+
|
| 17 |
+
### 1.3 Relevance
|
| 18 |
+
Spam detection remains a relevant challenge as spamming techniques evolve. Traditional rule-based systems often fail to capture the semantic meaning of messages. Deep learning models, particularly Transformers, can capture long-range dependencies and contextual information, leading to higher accuracy and better generalization.
|
| 19 |
+
|
| 20 |
+
### 1.4 State of the Art
|
| 21 |
+
Modern NLP has been revolutionized by the Transformer architecture (Vaswani et al., 2017). Models like BERT (Bidirectional Encoder Representations from Transformers) and its variants (DistilBERT, RoBERTa) have set new benchmarks in text classification by pre-training on large corpora and fine-tuning on specific tasks.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## 2. Method
|
| 26 |
+
### 2.1 Overall Strategy
|
| 27 |
+
Our strategy involves:
|
| 28 |
+
1. **Exploratory Data Analysis (EDA)** to understand the dataset characteristics.
|
| 29 |
+
2. **Data Preprocessing** including tokenization and padding.
|
| 30 |
+
3. **Fine-tuning a Pre-trained Model** (DistilBERT) using the Hugging Face `transformers` library and PyTorch.
|
| 31 |
+
4. **Rigorous Evaluation** using metrics like Accuracy, Precision, Recall, and F1-score.
|
| 32 |
+
|
| 33 |
+
### 2.2 Dataset Description & EDA
|
| 34 |
+
The dataset used is `mail_data.csv`, containing 5,572 messages labeled as 'ham' or 'spam'.
|
| 35 |
+
- **Total Samples**: 5,572
|
| 36 |
+
- **Ham**: 4,825 (86.6%)
|
| 37 |
+
- **Spam**: 747 (13.4%)
|
| 38 |
+
- **Imbalance**: The dataset is significantly imbalanced, which we addressed by using stratified splitting and monitoring the F1-score.
|
| 39 |
+
|
| 40 |
+
**EDA Findings**:
|
| 41 |
+
- Spam messages tend to be longer on average than ham messages.
|
| 42 |
+
- Common keywords in spam include "free", "win", "winner", "call", "claim".
|
| 43 |
+
- Ham messages are more conversational and vary greatly in length.
|
| 44 |
+
|
| 45 |
+
### 2.3 Data Preprocessing
|
| 46 |
+
- **Tokenization**: We used the `DistilBertTokenizer` to convert raw text into input IDs and attention masks.
|
| 47 |
+
- **Truncation & Padding**: All sequences were padded or truncated to a maximum length of 128 tokens to ensure uniform input size for the model.
|
| 48 |
+
- **Train/Test Split**: 80% training (4,457 samples) and 20% testing (1,115 samples), with stratification to maintain class proportions.
|
| 49 |
+
|
| 50 |
+
### 2.4 Model Architecture
|
| 51 |
+
We utilized **DistilBERT** (`distilbert-base-uncased`), a smaller, faster, and lighter version of BERT that retains 97% of its performance. It has 6 layers, 768 hidden units, and 12 attention heads, totaling approximately 66 million parameters.
|
| 52 |
+
|
| 53 |
+
### 2.5 Training Setup
|
| 54 |
+
- **Optimizer**: AdamW with a learning rate of 2e-5.
|
| 55 |
+
- **Scheduler**: Linear warmup for 500 steps.
|
| 56 |
+
- **Loss Function**: Cross-Entropy Loss.
|
| 57 |
+
- **Batch Size**: 16 for training, 64 for evaluation.
|
| 58 |
+
- **Epochs**: 3 (stopped early after 1 epoch due to high performance and resource constraints).
|
| 59 |
+
- **Hardware**: CPU (simulated environment).
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## 3. Results
|
| 64 |
+
### 3.1 Performance Metrics
|
| 65 |
+
The model achieved exceptional results after just one epoch of fine-tuning:
|
| 66 |
+
|
| 67 |
+
| Metric | Value |
|
| 68 |
+
| :--- | :--- |
|
| 69 |
+
| **Accuracy** | 99.10% |
|
| 70 |
+
| **Precision (Spam)** | 98.60% |
|
| 71 |
+
| **Recall (Spam)** | 94.63% |
|
| 72 |
+
| **F1-Score (Spam)** | 96.58% |
|
| 73 |
+
|
| 74 |
+
### 3.2 Confusion Matrix
|
| 75 |
+
| | Predicted Ham | Predicted Spam |
|
| 76 |
+
| :--- | :---: | :---: |
|
| 77 |
+
| **Actual Ham** | 964 | 2 |
|
| 78 |
+
| **Actual Spam** | 8 | 141 |
|
| 79 |
+
|
| 80 |
+
The model correctly identified 141 out of 149 spam messages while only misclassifying 2 legitimate messages as spam (False Positives).
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## 4. Discussion
|
| 85 |
+
### 4.1 Interpretation
|
| 86 |
+
The high accuracy and F1-score indicate that DistilBERT is highly effective for this task. The model successfully learned the semantic patterns that distinguish spam from ham, even with a relatively small and imbalanced dataset.
|
| 87 |
+
|
| 88 |
+
### 4.2 What Worked
|
| 89 |
+
- **Transfer Learning**: Using a pre-trained model allowed us to achieve near-perfect results with minimal training time.
|
| 90 |
+
- **Hugging Face Trainer**: Simplified the training loop and handled evaluation efficiently.
|
| 91 |
+
- **Tokenization**: The subword tokenization of BERT handles out-of-vocabulary words better than traditional word-based methods.
|
| 92 |
+
|
| 93 |
+
### 4.3 Limitations
|
| 94 |
+
- **Dataset Size**: While sufficient for this project, a larger and more diverse dataset would be needed for a production-grade system.
|
| 95 |
+
- **Class Imbalance**: Although the model performed well, the recall for spam (94.63%) is slightly lower than for ham, reflecting the imbalance.
|
| 96 |
+
- **Adversarial Attacks**: Sophisticated spam might use techniques to bypass Transformer-based filters, which was not explored here.
|
| 97 |
+
|
| 98 |
+
### 4.4 Future Improvements
|
| 99 |
+
- **Data Augmentation**: Techniques like back-translation could help balance the dataset.
|
| 100 |
+
- **Hyperparameter Tuning**: Exploring different learning rates and batch sizes.
|
| 101 |
+
- **Deployment**: Creating a Gradio interface on Hugging Face Spaces for real-time testing.
|
| 102 |
+
- **Model Compression**: Quantization or pruning to make the model even lighter for mobile deployment.
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## 5. Conclusion
|
| 107 |
+
This project successfully demonstrated the application of Deep Learning for spam detection. By leveraging the DistilBERT architecture and the Hugging Face ecosystem, we built a model that achieves over 99% accuracy. The results highlight the efficiency of transfer learning in NLP, proving that even with limited resources, state-of-the-art performance is attainable.
|
| 108 |
+
|
| 109 |
+
---
|
| 110 |
+
|
| 111 |
+
## 6. References
|
| 112 |
+
1. Vaswani, A., et al. (2017). "Attention Is All You Need."
|
| 113 |
+
2. Sanh, V., et al. (2019). "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter."
|
| 114 |
+
3. Wolf, T., et al. (2020). "Transformers: State-of-the-Art Natural Language Processing."
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python==3.13.12
|
| 2 |
+
gradio==5.49.1
|
| 3 |
+
transformers==4.57.1
|
| 4 |
+
torch==2.8.0
|
| 5 |
+
numpy==2.4.2
|
| 6 |
+
pandas==2.3.3
|
| 7 |
+
scikit-learn==1.8.0
|
| 8 |
+
matplotlib==3.10.8
|
| 9 |
+
seaborn==0.13.2
|
results.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Final Evaluation Results:
|
| 2 |
+
{'eval_loss': 0.04282991588115692, 'eval_accuracy': 0.9928251121076234, 'eval_f1': 0.972972972972973, 'eval_precision': 0.9795918367346939, 'eval_recall': 0.9664429530201343, 'eval_runtime': 42.8545, 'eval_samples_per_second': 26.018, 'eval_steps_per_second': 0.42, 'epoch': 3.0}
|
| 3 |
+
|
| 4 |
+
Classification Report:
|
| 5 |
+
precision recall f1-score support
|
| 6 |
+
|
| 7 |
+
ham 0.99 1.00 1.00 966
|
| 8 |
+
spam 0.98 0.97 0.97 149
|
| 9 |
+
|
| 10 |
+
accuracy 0.99 1115
|
| 11 |
+
macro avg 0.99 0.98 0.98 1115
|
| 12 |
+
weighted avg 0.99 0.99 0.99 1115
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Confusion Matrix:
|
| 16 |
+
[[963 3]
|
| 17 |
+
[ 5 144]]
|
save_tokenizer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from transformers import DistilBertTokenizer
|
| 2 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
| 3 |
tokenizer.save_pretrained('saved_model')
|
|
|
|
| 4 |
print("Tokenizer saved to saved_model")
|
|
|
|
| 1 |
from transformers import DistilBertTokenizer
|
| 2 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
| 3 |
tokenizer.save_pretrained('saved_model')
|
| 4 |
+
|
| 5 |
print("Tokenizer saved to saved_model")
|
train_model.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, get_linear_schedule_with_warmup
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
|
| 8 |
+
import numpy as np
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# 1. Load and Preprocess Data
|
| 13 |
+
df = pd.read_csv('mail_data.csv', names=['Category', 'Message'], header=None, skiprows=1)
|
| 14 |
+
df['label'] = df['Category'].map({'ham': 0, 'spam': 1})
|
| 15 |
+
|
| 16 |
+
train_texts, test_texts, train_labels, test_labels = train_test_split(
|
| 17 |
+
df['Message'].values, df['label'].values, test_size=0.2, random_state=42, stratify=df['label'].values
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# 2. Dataset Class
|
| 21 |
+
class EmailDataset(Dataset):
|
| 22 |
+
def __init__(self, texts, labels, tokenizer, max_len=128):
|
| 23 |
+
self.texts = texts
|
| 24 |
+
self.labels = labels
|
| 25 |
+
self.tokenizer = tokenizer
|
| 26 |
+
self.max_len = max_len
|
| 27 |
+
|
| 28 |
+
def __len__(self):
|
| 29 |
+
return len(self.texts)
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, item):
|
| 32 |
+
text = str(self.texts[item])
|
| 33 |
+
label = self.labels[item]
|
| 34 |
+
encoding = self.tokenizer._encode_plus(
|
| 35 |
+
text,
|
| 36 |
+
add_special_tokens=True,
|
| 37 |
+
max_length=self.max_len,
|
| 38 |
+
return_token_type_ids=False,
|
| 39 |
+
padding='max_length',
|
| 40 |
+
truncation=True,
|
| 41 |
+
return_attention_mask=True,
|
| 42 |
+
return_tensors='pt',
|
| 43 |
+
)
|
| 44 |
+
return {
|
| 45 |
+
'text': text,
|
| 46 |
+
'input_ids': encoding['input_ids'].flatten(),
|
| 47 |
+
'attention_mask': encoding['attention_mask'].flatten(),
|
| 48 |
+
'labels': torch.tensor(label, dtype=torch.long)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# 3. Setup Training
|
| 52 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
+
print(f"Using device: {device}")
|
| 54 |
+
|
| 55 |
+
PRE_TRAINED_MODEL_NAME = 'distilbert-base-uncased'
|
| 56 |
+
tokenizer = DistilBertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
|
| 57 |
+
|
| 58 |
+
train_data_loader = DataLoader(EmailDataset(train_texts, train_labels, tokenizer), batch_size=16, shuffle=True)
|
| 59 |
+
test_data_loader = DataLoader(EmailDataset(test_texts, test_labels, tokenizer), batch_size=16, shuffle=False)
|
| 60 |
+
|
| 61 |
+
model = DistilBertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME, num_labels=2)
|
| 62 |
+
model = model.to(device)
|
| 63 |
+
|
| 64 |
+
EPOCHS = 3
|
| 65 |
+
optimizer = AdamW(model.parameters(), lr=2e-5)
|
| 66 |
+
total_steps = len(train_data_loader) * EPOCHS
|
| 67 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
|
| 68 |
+
loss_fn = torch.nn.CrossEntropyLoss().to(device)
|
| 69 |
+
|
| 70 |
+
# 4. Training Loop
|
| 71 |
+
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
|
| 72 |
+
model = model.train()
|
| 73 |
+
losses = []
|
| 74 |
+
correct_predictions = 0
|
| 75 |
+
for d in data_loader:
|
| 76 |
+
input_ids = d["input_ids"].to(device)
|
| 77 |
+
attention_mask = d["attention_mask"].to(device)
|
| 78 |
+
labels = d["labels"].to(device)
|
| 79 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 80 |
+
loss = outputs.loss
|
| 81 |
+
logits = outputs.logits
|
| 82 |
+
_, preds = torch.max(logits, dim=1)
|
| 83 |
+
correct_predictions += torch.sum(preds == labels)
|
| 84 |
+
losses.append(loss.item())
|
| 85 |
+
loss.backward()
|
| 86 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 87 |
+
optimizer.step()
|
| 88 |
+
scheduler.step()
|
| 89 |
+
optimizer.zero_grad()
|
| 90 |
+
return correct_predictions.double() / n_examples, np.mean(losses)
|
| 91 |
+
|
| 92 |
+
def eval_model(model, data_loader, loss_fn, device, n_examples):
|
| 93 |
+
model = model.eval()
|
| 94 |
+
losses = []
|
| 95 |
+
correct_predictions = 0
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
for d in data_loader:
|
| 98 |
+
input_ids = d["input_ids"].to(device)
|
| 99 |
+
attention_mask = d["attention_mask"].to(device)
|
| 100 |
+
labels = d["labels"].to(device)
|
| 101 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 102 |
+
loss = outputs.loss
|
| 103 |
+
logits = outputs.logits
|
| 104 |
+
_, preds = torch.max(logits, dim=1)
|
| 105 |
+
correct_predictions += torch.sum(preds == labels)
|
| 106 |
+
losses.append(loss.item())
|
| 107 |
+
return correct_predictions.double() / n_examples, np.mean(losses)
|
| 108 |
+
|
| 109 |
+
print("Starting training...")
|
| 110 |
+
for epoch in range(EPOCHS):
|
| 111 |
+
print(f'Epoch {epoch + 1}/{EPOCHS}')
|
| 112 |
+
train_acc, train_loss = train_epoch(model, train_data_loader, loss_fn, optimizer, device, scheduler, len(train_texts))
|
| 113 |
+
print(f'Train loss {train_loss} accuracy {train_acc}')
|
| 114 |
+
val_acc, val_loss = eval_model(model, test_data_loader, loss_fn, device, len(test_texts))
|
| 115 |
+
print(f'Val loss {val_loss} accuracy {val_acc}')
|
| 116 |
+
|
| 117 |
+
# 5. Final Evaluation
|
| 118 |
+
def get_predictions(model, data_loader):
|
| 119 |
+
model = model.eval()
|
| 120 |
+
messages = []
|
| 121 |
+
predictions = []
|
| 122 |
+
prediction_probs = []
|
| 123 |
+
real_values = []
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
for d in data_loader:
|
| 126 |
+
texts = d["text"]
|
| 127 |
+
input_ids = d["input_ids"].to(device)
|
| 128 |
+
attention_mask = d["attention_mask"].to(device)
|
| 129 |
+
labels = d["labels"].to(device)
|
| 130 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 131 |
+
logits = outputs.logits
|
| 132 |
+
_, preds = torch.max(logits, dim=1)
|
| 133 |
+
messages.extend(texts)
|
| 134 |
+
predictions.extend(preds)
|
| 135 |
+
prediction_probs.extend(logits)
|
| 136 |
+
real_values.extend(labels)
|
| 137 |
+
predictions = torch.stack(predictions).cpu()
|
| 138 |
+
real_values = torch.stack(real_values).cpu()
|
| 139 |
+
return messages, predictions, real_values
|
| 140 |
+
|
| 141 |
+
y_review_texts, y_pred, y_test = get_predictions(model, test_data_loader)
|
| 142 |
+
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=['ham', 'spam']))
|
| 143 |
+
|
| 144 |
+
# Save results for report
|
| 145 |
+
with open('results.txt', 'w') as f:
|
| 146 |
+
f.write(f"Accuracy: {accuracy_score(y_test, y_pred)}\n")
|
| 147 |
+
f.write("\nClassification Report:\n")
|
| 148 |
+
f.write(classification_report(y_test, y_pred, target_names=['ham', 'spam']))
|
| 149 |
+
f.write("\nConfusion Matrix:\n")
|
| 150 |
+
f.write(str(confusion_matrix(y_test, y_pred)))
|
| 151 |
+
|
| 152 |
+
print("Training complete. Results saved to results.txt")
|
train_model_hf.py
CHANGED
|
@@ -96,4 +96,4 @@ with open('results.txt', 'w') as f:
|
|
| 96 |
f.write(f"\nClassification Report:\n{report}\n")
|
| 97 |
f.write(f"\nConfusion Matrix:\n{cm}\n")
|
| 98 |
|
| 99 |
-
print("Training complete. Results saved to results.txt")
|
|
|
|
| 96 |
f.write(f"\nClassification Report:\n{report}\n")
|
| 97 |
f.write(f"\nConfusion Matrix:\n{cm}\n")
|
| 98 |
|
| 99 |
+
print("Training complete. Results saved to results.txt")
|