Upload 47 files
Browse files- .gitattributes +4 -0
- Deep_Learning_Project/A24-Y2-DEEP-LEARNING-Project.pdf +3 -0
- Deep_Learning_Project/Deep Learning Project Report.docx +3 -0
- Deep_Learning_Project/Deep Learning Project Report.pdf +3 -0
- Deep_Learning_Project/README.md +31 -0
- Deep_Learning_Project/app.py +74 -0
- Deep_Learning_Project/eda_plots.png +3 -0
- Deep_Learning_Project/eda_script.py +152 -0
- Deep_Learning_Project/evaluate_final.py +63 -0
- Deep_Learning_Project/mail_data.csv +0 -0
- Deep_Learning_Project/mail_data_test.csv +9 -0
- Deep_Learning_Project/project_report.md +114 -0
- Deep_Learning_Project/requirements.txt +9 -0
- Deep_Learning_Project/results.txt +17 -0
- Deep_Learning_Project/results/checkpoint-279/config.json +28 -0
- Deep_Learning_Project/results/checkpoint-279/model.safetensors +3 -0
- Deep_Learning_Project/results/checkpoint-279/optimizer.pt +3 -0
- Deep_Learning_Project/results/checkpoint-279/rng_state.pth +3 -0
- Deep_Learning_Project/results/checkpoint-279/scheduler.pt +3 -0
- Deep_Learning_Project/results/checkpoint-279/trainer_state.json +235 -0
- Deep_Learning_Project/results/checkpoint-279/training_args.bin +3 -0
- Deep_Learning_Project/results/checkpoint-558/config.json +28 -0
- Deep_Learning_Project/results/checkpoint-558/model.safetensors +3 -0
- Deep_Learning_Project/results/checkpoint-558/optimizer.pt +3 -0
- Deep_Learning_Project/results/checkpoint-558/rng_state.pth +3 -0
- Deep_Learning_Project/results/checkpoint-558/scheduler.pt +3 -0
- Deep_Learning_Project/results/checkpoint-558/trainer_state.json +443 -0
- Deep_Learning_Project/results/checkpoint-558/training_args.bin +3 -0
- Deep_Learning_Project/results/checkpoint-837/config.json +28 -0
- Deep_Learning_Project/results/checkpoint-837/model.safetensors +3 -0
- Deep_Learning_Project/results/checkpoint-837/optimizer.pt +3 -0
- Deep_Learning_Project/results/checkpoint-837/rng_state.pth +3 -0
- Deep_Learning_Project/results/checkpoint-837/scheduler.pt +3 -0
- Deep_Learning_Project/results/checkpoint-837/trainer_state.json +651 -0
- Deep_Learning_Project/results/checkpoint-837/training_args.bin +3 -0
- Deep_Learning_Project/save_tokenizer.py +5 -0
- Deep_Learning_Project/saved_model/config.json +28 -0
- Deep_Learning_Project/saved_model/model.safetensors +3 -0
- Deep_Learning_Project/saved_model/optimizer.pt +3 -0
- Deep_Learning_Project/saved_model/rng_state.pth +3 -0
- Deep_Learning_Project/saved_model/scheduler.pt +3 -0
- Deep_Learning_Project/saved_model/tokenizer.json +0 -0
- Deep_Learning_Project/saved_model/tokenizer_config.json +14 -0
- Deep_Learning_Project/saved_model/trainer_state.json +235 -0
- Deep_Learning_Project/saved_model/training_args.bin +3 -0
- Deep_Learning_Project/train_model.py +152 -0
- Deep_Learning_Project/train_model_hf.py +99 -0
- Deep_Learning_Project/~$ep Learning Project Report.docx +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Deep_Learning_Project/A24-Y2-DEEP-LEARNING-Project.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.docx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Deep_Learning_Project/Deep[[:space:]]Learning[[:space:]]Project[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Deep_Learning_Project/eda_plots.png filter=lfs diff=lfs merge=lfs -text
|
Deep_Learning_Project/A24-Y2-DEEP-LEARNING-Project.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d248954892441dbf8de6cb3c8315718e020879401296dd7d1597cd82fe40dce2
|
| 3 |
+
size 230345
|
Deep_Learning_Project/Deep Learning Project Report.docx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccbe55fe11859c664c37d29a179ce14404ad4084a63ad430daff5aff2ae56da0
|
| 3 |
+
size 236057
|
Deep_Learning_Project/Deep Learning Project Report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1381d9fa4aa0351d88ff4c151941d46d177e87ba033d0947bffce069fdb251f3
|
| 3 |
+
size 357842
|
Deep_Learning_Project/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning Project: Spam Detection with DistilBERT
|
| 2 |
+
|
| 3 |
+
This repository contains the code and resources for the Deep Learning project on Spam Detection.
|
| 4 |
+
|
| 5 |
+
## Project Structure
|
| 6 |
+
- `mail_data.csv`: The dataset used for training and evaluation.
|
| 7 |
+
- `eda_script.py`: Script for Exploratory Data Analysis and visualization.
|
| 8 |
+
- `train_model_hf.py`: Main training script using Hugging Face Trainer and DistilBERT.
|
| 9 |
+
- `evaluate_final.py`: Script for final evaluation from the best model checkpoint.
|
| 10 |
+
- `eda_plots.png`: Visualizations generated during EDA.
|
| 11 |
+
- `results.txt`: Detailed evaluation metrics and confusion matrix.
|
| 12 |
+
- `Deep Learning Project Report.pdf`: The final project report (15-17 pages equivalent).
|
| 13 |
+
|
| 14 |
+
## Requirements
|
| 15 |
+
- Python 3.11+
|
| 16 |
+
- PyTorch
|
| 17 |
+
- Transformers
|
| 18 |
+
- Datasets
|
| 19 |
+
- Scikit-learn
|
| 20 |
+
- Pandas
|
| 21 |
+
- Matplotlib
|
| 22 |
+
- Seaborn
|
| 23 |
+
- Accelerate
|
| 24 |
+
|
| 25 |
+
## How to Run
|
| 26 |
+
1. **EDA**: Run `python3 eda_script.py` to see the data distribution.
|
| 27 |
+
2. **Training**: Run `python3 train_model_hf.py` to fine-tune the DistilBERT model.
|
| 28 |
+
3. **Evaluation**: Run `python3 evaluate_final.py` to get the final performance metrics.
|
| 29 |
+
|
| 30 |
+
## Results
|
| 31 |
+
The model achieves **99.10% accuracy** on the test set with an **F1-score of 96.58%** for the spam class.
|
Deep_Learning_Project/app.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# 1. Load the model and tokenizer from your saved directory
# If uploading to Hugging Face Spaces, ensure your saved_model folder is uploaded too!
model_path = "./saved_model"  # Update this path to where your model is saved

try:
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
except Exception as e:
    # Fallback to base model if you are just testing the code without your fine-tuned weights
    # (Remove this try-except block in your final version, it's just to prevent crashes if path is wrong)
    print("Could not load fine-tuned model, loading base model for demonstration purposes...")
    # Surface the actual cause: without it, a wrong path or a missing tokenizer
    # file is indistinguishable from a successful load in the logs.
    print(f"Reason: {e!r}")
    # NOTE: the base checkpoint is not fine-tuned for spam detection, so the
    # probabilities shown by the app are not meaningful in this mode.
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

# Set model to evaluation mode (applies to whichever model was loaded above)
model.eval()
|
| 21 |
+
|
| 22 |
+
# 2. Define the prediction function
|
| 23 |
+
def predict_spam(message):
    """Classify a text message as ham or spam.

    Args:
        message: The raw message text. ``None`` and whitespace-only strings
            are treated as "no input".

    Returns:
        A ``{label: probability}`` dict in the form expected by ``gr.Label``.
        For missing/blank input the dict holds a single prompt entry instead
        of class probabilities.
    """
    # Guard against None as well as blank text: calling .strip() on None
    # would raise AttributeError.
    if message is None or not message.strip():
        return {"Please enter a message": 1.0}

    # Tokenize input (truncated to the same 128-token limit used elsewhere in the project)
    inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True, max_length=128)

    # Get model prediction; no gradients are needed for inference
    with torch.no_grad():
        logits = model(**inputs).logits

    # Apply softmax over the class dimension, then drop the batch dimension
    probabilities = F.softmax(logits, dim=1).squeeze()

    # Our labels are: 0 -> Ham, 1 -> Spam
    prob_ham = probabilities[0].item()
    prob_spam = probabilities[1].item()

    # Gradio expects a dictionary of {label: probability} for the Label component
    return {"Ham (Legitimate)": prob_ham, "Spam (Malicious)": prob_spam}
|
| 44 |
+
|
| 45 |
+
# 3. Create the Gradio Interface
# We use a clean, modern interface
# gr.Interface connects predict_spam to one text input and one label output.
demo = gr.Interface(
    fn=predict_spam,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Type an email or SMS message here...",
        label="Message Content"
    ),
    # predict_spam returns a {label: probability} dict, which gr.Label displays.
    outputs=gr.Label(num_top_classes=2, label="Prediction Confidence"),
    title="🛡️ Spam Detection AI",
    description="""
### Deep Learning Project (2025)
This application uses a fine-tuned **DistilBERT** Transformer model to classify text messages as either Spam or Ham.
* Enter a message below and click Submit.
* **Examples of Spam:** 'URGENT! You have won a 1 week FREE membership. Call 087124006024'
* **Examples of Ham:** 'Hey, are we still meeting for lunch tomorrow?'
""",
    # Each inner list is one example row; it has a single entry because the interface has one input.
    examples=[
        ["WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."],
        ["Hey man, just checking in. Are we still on for the movie tonight at 8?"],
        ["URGENT: Your bank account has been locked due to suspicious activity. Please click this link to verify your identity: http://secure-login-update.com"],
        ["I'll be about 10 minutes late, stuck in traffic."]
    ],
    theme=gr.themes.Soft()
)

# 4. Launch the app
# Only start the server when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()
|
Deep_Learning_Project/eda_plots.png
ADDED
|
Git LFS Details
|
Deep_Learning_Project/eda_script.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# Load dataset
# skiprows=1 drops the first line of the CSV (presumably its header row); the
# two columns are then named explicitly.
df = pd.read_csv('mail_data.csv', names=['Category', 'Message'], header=None, skiprows=1)

# Basic info
print("Dataset Shape:", df.shape)
print("\nValue Counts:\n", df['Category'].value_counts())
print("\nMissing Values:\n", df.isnull().sum())

# Add message length
# Length is the number of characters in each message.
# NOTE(review): len() raises TypeError on a missing (NaN) message, so this
# assumes the Message column has no gaps — confirm against the printout above.
df['Length'] = df['Message'].apply(len)

# Visualizations
# One shared figure; each plt.subplot(3, 2, k) call in this script selects
# cell k of a 3x2 grid before drawing into it.
plt.figure(figsize=(18, 10))

# 1. Class Distribution
plt.subplot(3, 2, 1)
sns.countplot(x='Category', data=df)
plt.title('Class Distribution (Spam vs Ham)')

# 2. Message Length Distribution
plt.subplot(3, 2, 2)
sns.histplot(data=df, x='Length', hue='Category', bins=50, kde=True)
plt.title('Message Length Distribution')

# Sample messages
# Print the first message of each class as a quick sanity check.
print("\nSample Ham:")
print(df[df['Category'] == 'ham']['Message'].iloc[0])
print("\nSample Spam:")
print(df[df['Category'] == 'spam']['Message'].iloc[0])
|
| 37 |
+
|
| 38 |
+
#Let's check whether other characteristics of our text data separate spam from ham
|
| 39 |
+
|
| 40 |
+
#Word counts (how often different words are used)
|
| 41 |
+
def get_word_count(messages_series: pd.Series) -> pd.Series:
    """Return the relative frequency of each word across all messages.

    Every message has the characters ``.``, ``|``, ``?``, ``,`` and ``!``
    removed, is lower-cased and is split on spaces; empty tokens produced by
    repeated, leading or trailing spaces are discarded.

    Args:
        messages_series: Series of message strings.

    Returns:
        A Series indexed by word whose values are each word's share of all
        words (``value_counts(normalize=True)``), most frequent first.
    """
    # Raw string so the pattern contains no invalid escape sequences. The '|'
    # is a literal member of the character class, exactly as in the original
    # pattern, so pipe characters are stripped as well.
    punctuation = re.compile(r"[.|?,!]+")

    word_list = []
    for message in messages_series:
        clean_msg = punctuation.sub("", message).lower()
        # Filtering out "" makes a separate "collapse multiple spaces" pass unnecessary.
        word_list.extend(word for word in clean_msg.split(" ") if word)

    return pd.Series(word_list).value_counts(normalize=True)
|
| 51 |
+
|
| 52 |
+
spam_word_count = get_word_count(df[df['Category'] == 'spam']['Message'])
ham_word_count = get_word_count(df[df['Category'] == 'ham']['Message'])

# Compare word frequencies to find the words used relatively more often in one
# class than in the other.
# Only words that occur in BOTH classes are compared (set intersection), so
# every lookup below is guaranteed to succeed and no "missing word" fallback
# is needed.
# NOTE(review): words that appear in only one class (often the most
# distinctive ones) are therefore excluded; if they should be included, use
# the union of the two vocabularies with a frequency of 0 for absent words.
words_list = list(set(spam_word_count.index) & set(ham_word_count.index))

# Positive value: relatively more frequent in ham; negative: more frequent in spam.
wordcount_distance_spam_ham = pd.Series(
    [ham_word_count[word] - spam_word_count[word] for word in words_list],
    index=words_list,
)
wordcount_distance_spam_ham = wordcount_distance_spam_ham.sort_values(ascending=False)
print("Words more often found in normal emails than spam emails")
print(wordcount_distance_spam_ham[0:10]) #Words more often present in Ham emails
print("words more often found in spam emails than normal emails")
print(wordcount_distance_spam_ham[-10:]) #Words more often present in spam emails
|
| 75 |
+
|
| 76 |
+
#Mail Words length (numbers of words used in the message)
|
| 77 |
+
def get_word_len(messages_series: pd.Series) -> pd.Series:
    """Return the number of words in each message.

    Before counting, the characters ``.``, ``|``, ``?``, ``,``, ``!`` and all
    digits are removed, so a token made up only of those characters does not
    count as a word. Words are the non-empty pieces left after splitting on
    spaces.

    Args:
        messages_series: Series of message strings.

    Returns:
        A Series of word counts in input order, with a fresh 0..n-1 index.
    """
    # Raw string so the pattern contains no invalid escape sequences; the
    # character set is the same as in the original pattern.
    strip_chars = re.compile(r"[.|?,!0-9]+")

    word_len_list = []
    for message in messages_series:
        clean_msg = strip_chars.sub("", message).lower()
        # Empty pieces come from repeated/leading/trailing spaces and are not words.
        word_len_list.append(sum(1 for word in clean_msg.split(" ") if word))

    return pd.Series(word_len_list)
|
| 87 |
+
|
| 88 |
+
# Words per message. The assignment matches rows by index label; the helper
# returns a 0..n-1 index — assumes df still has its default index (TODO confirm).
df['word_len'] = get_word_len(df['Message'])

# 3. Word count distribution, split by class (cell 3 of the 3x2 grid)
plt.subplot(3, 2, 3)
sns.histplot(data=df, x='word_len', hue='Category', bins=50, kde=True)
plt.title('Word count Distribution')
|
| 93 |
+
|
| 94 |
+
#Average word length (mean number of characters per word in the message)
def get_word_len2(messages_series: pd.Series) -> pd.Series:
    """Return the average word length, in characters, of each message.

    Before measuring, the characters ``.``, ``|``, ``?``, ``,``, ``!`` and all
    digits are removed. Words are the non-empty pieces left after splitting on
    spaces.

    Args:
        messages_series: Series of message strings.

    Returns:
        A Series of per-message averages in input order, with a fresh 0..n-1
        index. A message with no words left after cleaning gets 0.
    """
    # Raw string so the pattern contains no invalid escape sequences; the
    # character set is the same as in the original pattern.
    strip_chars = re.compile(r"[.|?,!0-9]+")

    word_len_list = []
    for message in messages_series:
        clean_msg = strip_chars.sub("", message).lower()
        words = [word for word in clean_msg.split(" ") if word]
        if words:
            word_len_list.append(sum(len(word) for word in words) / len(words))
        else:
            # Avoid dividing by zero; still emit a value so rows stay aligned.
            word_len_list.append(0)

    return pd.Series(word_len_list)
|
| 112 |
+
|
| 113 |
+
# Average characters per word, one value per message. The assignment matches
# rows by index label — assumes df still has its default index (TODO confirm).
df['avg_word_len'] = get_word_len2(df['Message'])

# 4. Average word length distribution, split by class (cell 4 of the 3x2 grid)
plt.subplot(3, 2, 4)
sns.histplot(data=df, x='avg_word_len', hue='Category', bins=50, kde=True)
plt.title('Word length Distribution')
|
| 118 |
+
|
| 119 |
+
#Average sentence length (mean number of words per sentence in the message)
def get_avg_sentence_len(messages_series: pd.Series) -> pd.Series:
    """Return the average number of words per sentence for each message.

    Words are the non-empty pieces obtained by splitting on spaces. A sentence
    is considered finished by any word containing ``.``, ``?`` or ``!``; the
    next word without such a character starts a new sentence. This is a simple
    heuristic: a run of punctuated words counts as one boundary, and words
    such as ``1.50`` also count as sentence ends.

    Args:
        messages_series: Series of message strings.

    Returns:
        A Series with one value per message, in input order, with a fresh
        0..n-1 index. A message with no words gets 0.
    """
    avg_sentence_len = []
    for message in messages_series:
        # Drop empty pieces caused by leading/trailing/repeated spaces so they
        # are neither counted as words nor able to open a phantom sentence
        # after the final punctuation mark.
        words = [word for word in message.split(" ") if word]
        if not words:
            # Always emit one value per message so the result lines up with
            # the rows of the calling DataFrame.
            avg_sentence_len.append(0)
            continue

        sentence_number = 1
        sentence_finished = False
        for word in words:
            if "." in word or "?" in word or "!" in word:
                sentence_finished = True
            elif sentence_finished:
                # First plain word after a sentence end: a new sentence begins.
                sentence_number += 1
                sentence_finished = False

        avg_sentence_len.append(len(words) / sentence_number)

    return pd.Series(avg_sentence_len)
|
| 140 |
+
|
| 141 |
+
# Average words per sentence, one value per message. The assignment matches
# rows by index label — assumes df still has its default index (TODO confirm).
df['avg_sentence_len'] = get_avg_sentence_len(df['Message'])

# 5. Average sentence length distribution, split by class (cell 5 of the 3x2 grid)
plt.subplot(3, 2, 5)
sns.histplot(data=df, x='avg_sentence_len', hue='Category', bins=50, kde=True)
plt.title('Sentence length Distribution')



# Fix spacing between the subplots, then write the figure to disk.
plt.tight_layout()
plt.savefig('eda_plots.png')
print("\nEDA plots saved to eda_plots.png")
|
| 152 |
+
|
Deep_Learning_Project/evaluate_final.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
|
| 5 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# 1. Load Data
# skiprows=1 drops the first line of the CSV (presumably its header row); the
# two columns are then named explicitly.
df = pd.read_csv('mail_data.csv', names=['Category', 'Message'], header=None, skiprows=1)
# Encode the classes as integers: ham -> 0, spam -> 1. Any other category value becomes NaN.
df['label'] = df['Category'].map({'ham': 0, 'spam': 1})

# Only the held-out 20% is needed here, so the training parts of the split are discarded.
# NOTE(review): test_size, random_state and stratify presumably mirror the
# training script so that this reproduces the same test set — confirm against
# train_model_hf.py before changing any of them.
_, test_texts, _, test_labels = train_test_split(
    df['Message'].values.tolist(), df['label'].values.tolist(), test_size=0.2, random_state=42, stratify=df['label'].values
)

# 2. Tokenization
# Messages longer than 128 tokens are truncated (max_length=128); padding=True
# pads the rest so the sequences in this batch have equal length.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=128)
|
| 20 |
+
|
| 21 |
+
class EmailDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with their labels.

    ``encodings`` is a mapping from field name to a per-example sequence of
    values; ``labels`` holds one label per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset has one example per label.
        return len(self.labels)

    def __getitem__(self, idx):
        # Build the example field by field, converting each value to a tensor.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # The label is stored under the 'labels' key.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
|
| 33 |
+
|
| 34 |
+
test_dataset = EmailDataset(test_encodings, test_labels)

# 3. Load Model from Checkpoint
def _checkpoint_step(dirname):
    """Return the step number encoded in 'checkpoint-<step>', or -1 if there is none."""
    _, _, step = dirname.rpartition('-')
    return int(step) if step.isdigit() else -1


# Find the checkpoint directory.
# os.listdir() returns entries in arbitrary order, so taking the first match
# could evaluate any of the saved checkpoints and differ between machines.
# Select the checkpoint with the highest training step instead, which is
# deterministic.
# NOTE(review): "latest" is not necessarily "best"; if the best checkpoint by
# metric is wanted, it may be recorded in the checkpoint's trainer_state.json
# — confirm and read it from there.
checkpoint_dirs = [d for d in os.listdir('./results') if d.startswith('checkpoint')]
if not checkpoint_dirs:
    # Fail with a clear message rather than an IndexError on an empty list.
    raise FileNotFoundError("No checkpoint directories found in ./results")
checkpoint_dir = max(checkpoint_dirs, key=_checkpoint_step)
model_path = os.path.join('./results', checkpoint_dir)
print(f"Loading model from {model_path}")

model = DistilBertForSequenceClassification.from_pretrained(model_path)
|
| 43 |
+
|
| 44 |
+
# 4. Evaluation
# A Trainer is used here only to run batched prediction over the test set.
trainer = Trainer(model=model)
predictions = trainer.predict(test_dataset)
labels = predictions.label_ids
# The predicted class is the index of the largest score on the last axis.
preds = predictions.predictions.argmax(-1)

acc = accuracy_score(labels, preds)
cm = confusion_matrix(labels, preds)
report = classification_report(labels, preds, target_names=['ham', 'spam'])

print(f"Accuracy: {acc}")
print(report)

# Assemble the report sections first, then write them out in one go.
result_sections = [
    f"Final Evaluation Results (from {checkpoint_dir}):\n",
    f"Accuracy: {acc}\n",
    f"\nClassification Report:\n{report}\n",
    f"\nConfusion Matrix:\n{cm}\n",
]
with open('results.txt', 'w') as f:
    f.writelines(result_sections)

print("Evaluation complete. Results saved to results.txt")
|
Deep_Learning_Project/mail_data.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Deep_Learning_Project/mail_data_test.csv
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ham,"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..."
|
| 2 |
+
ham,Ok lar... Joking wif u oni...
|
| 3 |
+
spam,Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
|
| 4 |
+
ham,U dun say so early hor... U c already then say...
|
| 5 |
+
ham,"Nah I don't think he goes to usf, he lives around here though"
|
| 6 |
+
spam,"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv"
|
| 7 |
+
ham,Even my brother is not like to speak with me. They treat me like aids patent.
|
| 8 |
+
ham,As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
|
| 9 |
+
spam,WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
|
Deep_Learning_Project/project_report.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning Project: Spam Detection using Transformers
|
| 2 |
+
|
| 3 |
+
**Course**: Deep Learning with Python (2025)
|
| 4 |
+
**Instructor**: Benoit Mialet
|
| 5 |
+
**Topic**: NLP - Text Classification (Spam vs Ham)
|
| 6 |
+
**Model**: DistilBERT (PyTorch / Hugging Face)
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## 1. Introduction
|
| 11 |
+
### 1.1 What & Why
|
| 12 |
+
The objective of this project is to develop a robust deep learning model for classifying emails as either "spam" or "ham" (legitimate). Email filtering is a critical application of Natural Language Processing (NLP) that helps improve user experience and security by automatically identifying unsolicited or malicious content.
|
| 13 |
+
|
| 14 |
+
### 1.2 Task Selection
|
| 15 |
+
We chose the **Text Classification** task, specifically binary classification. This task is well-suited for demonstrating the power of Transfer Learning and Transformer architectures in understanding the nuances of human language.
|
| 16 |
+
|
| 17 |
+
### 1.3 Relevance
|
| 18 |
+
Spam detection remains a relevant challenge as spamming techniques evolve. Traditional rule-based systems often fail to capture the semantic meaning of messages. Deep learning models, particularly Transformers, can capture long-range dependencies and contextual information, leading to higher accuracy and better generalization.
|
| 19 |
+
|
| 20 |
+
### 1.4 State of the Art
|
| 21 |
+
Modern NLP has been revolutionized by the Transformer architecture (Vaswani et al., 2017). Models like BERT (Bidirectional Encoder Representations from Transformers) and its variants (DistilBERT, RoBERTa) have set new benchmarks in text classification by pre-training on large corpora and fine-tuning on specific tasks.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## 2. Method
|
| 26 |
+
### 2.1 Overall Strategy
|
| 27 |
+
Our strategy involves:
|
| 28 |
+
1. **Exploratory Data Analysis (EDA)** to understand the dataset characteristics.
|
| 29 |
+
2. **Data Preprocessing** including tokenization and padding.
|
| 30 |
+
3. **Fine-tuning a Pre-trained Model** (DistilBERT) using the Hugging Face `transformers` library and PyTorch.
|
| 31 |
+
4. **Rigorous Evaluation** using metrics like Accuracy, Precision, Recall, and F1-score.
|
| 32 |
+
|
| 33 |
+
### 2.2 Dataset Description & EDA
|
| 34 |
+
The dataset used is `mail_data.csv`, containing 5,572 messages labeled as 'ham' or 'spam'.
|
| 35 |
+
- **Total Samples**: 5,572
|
| 36 |
+
- **Ham**: 4,825 (86.6%)
|
| 37 |
+
- **Spam**: 747 (13.4%)
|
| 38 |
+
- **Imbalance**: The dataset is significantly imbalanced, which we addressed by using stratified splitting and monitoring the F1-score.
|
| 39 |
+
|
| 40 |
+
**EDA Findings**:
|
| 41 |
+
- Spam messages tend to be longer on average than ham messages.
|
| 42 |
+
- Common keywords in spam include "free", "win", "winner", "call", "claim".
|
| 43 |
+
- Ham messages are more conversational and vary greatly in length.
|
| 44 |
+
|
| 45 |
+
### 2.3 Data Preprocessing
|
| 46 |
+
- **Tokenization**: We used the `DistilBertTokenizer` to convert raw text into input IDs and attention masks.
|
| 47 |
+
- **Truncation & Padding**: All sequences were padded or truncated to a maximum length of 128 tokens to ensure uniform input size for the model.
|
| 48 |
+
- **Train/Test Split**: 80% training (4,457 samples) and 20% testing (1,115 samples), with stratification to maintain class proportions.
|
| 49 |
+
|
| 50 |
+
### 2.4 Model Architecture
|
| 51 |
+
We utilized **DistilBERT** (`distilbert-base-uncased`), a smaller, faster, and lighter version of BERT that retains 97% of its performance. It has 6 layers, 768 hidden units, and 12 attention heads, totaling approximately 66 million parameters.
|
| 52 |
+
|
| 53 |
+
### 2.5 Training Setup
|
| 54 |
+
- **Optimizer**: AdamW with a learning rate of 2e-5.
|
| 55 |
+
- **Scheduler**: Linear warmup for 500 steps.
|
| 56 |
+
- **Loss Function**: Cross-Entropy Loss.
|
| 57 |
+
- **Batch Size**: 16 for training, 64 for evaluation.
|
| 58 |
+
- **Epochs**: 3 (stopped early after 1 epoch due to high performance and resource constraints).
|
| 59 |
+
- **Hardware**: CPU (simulated environment).
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## 3. Results
|
| 64 |
+
### 3.1 Performance Metrics
|
| 65 |
+
The model achieved exceptional results after just one epoch of fine-tuning:
|
| 66 |
+
|
| 67 |
+
| Metric | Value |
|
| 68 |
+
| :--- | :--- |
|
| 69 |
+
| **Accuracy** | 99.10% |
|
| 70 |
+
| **Precision (Spam)** | 98.60% |
|
| 71 |
+
| **Recall (Spam)** | 94.63% |
|
| 72 |
+
| **F1-Score (Spam)** | 96.58% |
|
| 73 |
+
|
| 74 |
+
### 3.2 Confusion Matrix
|
| 75 |
+
| | Predicted Ham | Predicted Spam |
|
| 76 |
+
| :--- | :---: | :---: |
|
| 77 |
+
| **Actual Ham** | 964 | 2 |
|
| 78 |
+
| **Actual Spam** | 8 | 141 |
|
| 79 |
+
|
| 80 |
+
The model correctly identified 141 out of 149 spam messages while only misclassifying 2 legitimate messages as spam (False Positives).
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## 4. Discussion
|
| 85 |
+
### 4.1 Interpretation
|
| 86 |
+
The high accuracy and F1-score indicate that DistilBERT is highly effective for this task. The model successfully learned the semantic patterns that distinguish spam from ham, even with a relatively small and imbalanced dataset.
|
| 87 |
+
|
| 88 |
+
### 4.2 What Worked
|
| 89 |
+
- **Transfer Learning**: Using a pre-trained model allowed us to achieve near-perfect results with minimal training time.
|
| 90 |
+
- **Hugging Face Trainer**: Simplified the training loop and handled evaluation efficiently.
|
| 91 |
+
- **Tokenization**: The subword tokenization of BERT handles out-of-vocabulary words better than traditional word-based methods.
|
| 92 |
+
|
| 93 |
+
### 4.3 Limitations
|
| 94 |
+
- **Dataset Size**: While sufficient for this project, a larger and more diverse dataset would be needed for a production-grade system.
|
| 95 |
+
- **Class Imbalance**: Although the model performed well, the recall for spam (94.63%) is slightly lower than for ham, reflecting the imbalance.
|
| 96 |
+
- **Adversarial Attacks**: Sophisticated spam might use techniques to bypass Transformer-based filters, which was not explored here.
|
| 97 |
+
|
| 98 |
+
### 4.4 Future Improvements
|
| 99 |
+
- **Data Augmentation**: Techniques like back-translation could help balance the dataset.
|
| 100 |
+
- **Hyperparameter Tuning**: Exploring different learning rates and batch sizes.
|
| 101 |
+
- **Deployment**: Creating a Gradio interface on Hugging Face Spaces for real-time testing.
|
| 102 |
+
- **Model Compression**: Quantization or pruning to make the model even lighter for mobile deployment.
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## 5. Conclusion
|
| 107 |
+
This project successfully demonstrated the application of Deep Learning for spam detection. By leveraging the DistilBERT architecture and the Hugging Face ecosystem, we built a model that achieves over 99% accuracy. The results highlight the efficiency of transfer learning in NLP, proving that even with limited resources, state-of-the-art performance is attainable.
|
| 108 |
+
|
| 109 |
+
---
|
| 110 |
+
|
| 111 |
+
## 6. References
|
| 112 |
+
1. Vaswani, A., et al. (2017). "Attention Is All You Need."
|
| 113 |
+
2. Sanh, V., et al. (2019). "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter."
|
| 114 |
+
3. Wolf, T., et al. (2020). "Transformers: State-of-the-Art Natural Language Processing."
|
Deep_Learning_Project/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python 3.13.12 (interpreter version; not a pip-installable requirement)
|
| 2 |
+
gradio==5.49.1
|
| 3 |
+
transformers==4.57.1
|
| 4 |
+
torch==2.8.0
|
| 5 |
+
numpy==2.4.2
|
| 6 |
+
pandas==2.3.3
|
| 7 |
+
scikit-learn==1.8.0
|
| 8 |
+
matplotlib==3.10.8
|
| 9 |
+
seaborn==0.13.2
|
Deep_Learning_Project/results.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Final Evaluation Results:
|
| 2 |
+
{'eval_loss': 0.04282991588115692, 'eval_accuracy': 0.9928251121076234, 'eval_f1': 0.972972972972973, 'eval_precision': 0.9795918367346939, 'eval_recall': 0.9664429530201343, 'eval_runtime': 42.8545, 'eval_samples_per_second': 26.018, 'eval_steps_per_second': 0.42, 'epoch': 3.0}
|
| 3 |
+
|
| 4 |
+
Classification Report:
|
| 5 |
+
precision recall f1-score support
|
| 6 |
+
|
| 7 |
+
ham 0.99 1.00 1.00 966
|
| 8 |
+
spam 0.98 0.97 0.97 149
|
| 9 |
+
|
| 10 |
+
accuracy 0.99 1115
|
| 11 |
+
macro avg 0.99 0.98 0.98 1115
|
| 12 |
+
weighted avg 0.99 0.99 0.99 1115
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Confusion Matrix:
|
| 16 |
+
[[963 3]
|
| 17 |
+
[ 5 144]]
|
Deep_Learning_Project/results/checkpoint-279/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"problem_type": "single_label_classification",
|
| 20 |
+
"qa_dropout": 0.1,
|
| 21 |
+
"seq_classif_dropout": 0.2,
|
| 22 |
+
"sinusoidal_pos_embds": false,
|
| 23 |
+
"tie_weights_": true,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.3.0",
|
| 26 |
+
"use_cache": false,
|
| 27 |
+
"vocab_size": 30522
|
| 28 |
+
}
|
Deep_Learning_Project/results/checkpoint-279/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:776e086853f745009f0a06b8fbcb550cc2ad70e182eb8563a473ba4e0bc9822f
|
| 3 |
+
size 267832560
|
Deep_Learning_Project/results/checkpoint-279/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:969c762d65aebb85df619a878bc2cde125e5563dd6cb996878063145074a3333
|
| 3 |
+
size 535724875
|
Deep_Learning_Project/results/checkpoint-279/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:826fb52114115240de98bb8ce0c3df8acf1d27dfc083ebb7276218129dcbb363
|
| 3 |
+
size 14391
|
Deep_Learning_Project/results/checkpoint-279/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c08501cbcb6c2d9f05c9876a9e1c851c8d3b6ad5c3fae91849cbfa8c16dda076
|
| 3 |
+
size 1465
|
Deep_Learning_Project/results/checkpoint-279/trainer_state.json
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 279,
|
| 3 |
+
"best_metric": 0.04283029958605766,
|
| 4 |
+
"best_model_checkpoint": "./results\\checkpoint-279",
|
| 5 |
+
"epoch": 1.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 279,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.035842293906810034,
|
| 14 |
+
"grad_norm": 4.070792198181152,
|
| 15 |
+
"learning_rate": 9e-07,
|
| 16 |
+
"loss": 0.6961095333099365,
|
| 17 |
+
"step": 10
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.07168458781362007,
|
| 21 |
+
"grad_norm": 3.0411295890808105,
|
| 22 |
+
"learning_rate": 1.9e-06,
|
| 23 |
+
"loss": 0.6723562717437744,
|
| 24 |
+
"step": 20
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.10752688172043011,
|
| 28 |
+
"grad_norm": 3.487802505493164,
|
| 29 |
+
"learning_rate": 2.9e-06,
|
| 30 |
+
"loss": 0.6391770362854003,
|
| 31 |
+
"step": 30
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.14336917562724014,
|
| 35 |
+
"grad_norm": 1.8233377933502197,
|
| 36 |
+
"learning_rate": 3.9e-06,
|
| 37 |
+
"loss": 0.5914762496948243,
|
| 38 |
+
"step": 40
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.17921146953405018,
|
| 42 |
+
"grad_norm": 1.4096935987472534,
|
| 43 |
+
"learning_rate": 4.9000000000000005e-06,
|
| 44 |
+
"loss": 0.5362973213195801,
|
| 45 |
+
"step": 50
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.21505376344086022,
|
| 49 |
+
"grad_norm": 1.8130015134811401,
|
| 50 |
+
"learning_rate": 5.9e-06,
|
| 51 |
+
"loss": 0.415470027923584,
|
| 52 |
+
"step": 60
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.25089605734767023,
|
| 56 |
+
"grad_norm": 1.5945749282836914,
|
| 57 |
+
"learning_rate": 6.900000000000001e-06,
|
| 58 |
+
"loss": 0.3050834178924561,
|
| 59 |
+
"step": 70
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.2867383512544803,
|
| 63 |
+
"grad_norm": 1.1991592645645142,
|
| 64 |
+
"learning_rate": 7.9e-06,
|
| 65 |
+
"loss": 0.2292243242263794,
|
| 66 |
+
"step": 80
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.3225806451612903,
|
| 70 |
+
"grad_norm": 3.596205711364746,
|
| 71 |
+
"learning_rate": 8.9e-06,
|
| 72 |
+
"loss": 0.13576759099960328,
|
| 73 |
+
"step": 90
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.35842293906810035,
|
| 77 |
+
"grad_norm": 1.3767762184143066,
|
| 78 |
+
"learning_rate": 9.900000000000002e-06,
|
| 79 |
+
"loss": 0.08189120292663574,
|
| 80 |
+
"step": 100
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.3942652329749104,
|
| 84 |
+
"grad_norm": 0.7235111594200134,
|
| 85 |
+
"learning_rate": 1.09e-05,
|
| 86 |
+
"loss": 0.09285140037536621,
|
| 87 |
+
"step": 110
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.43010752688172044,
|
| 91 |
+
"grad_norm": 4.165863513946533,
|
| 92 |
+
"learning_rate": 1.19e-05,
|
| 93 |
+
"loss": 0.09866725206375122,
|
| 94 |
+
"step": 120
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 0.4659498207885305,
|
| 98 |
+
"grad_norm": 0.24346598982810974,
|
| 99 |
+
"learning_rate": 1.29e-05,
|
| 100 |
+
"loss": 0.02432841956615448,
|
| 101 |
+
"step": 130
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 0.5017921146953405,
|
| 105 |
+
"grad_norm": 0.1488434225320816,
|
| 106 |
+
"learning_rate": 1.3900000000000002e-05,
|
| 107 |
+
"loss": 0.1057550072669983,
|
| 108 |
+
"step": 140
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 0.5376344086021505,
|
| 112 |
+
"grad_norm": 0.16067259013652802,
|
| 113 |
+
"learning_rate": 1.49e-05,
|
| 114 |
+
"loss": 0.04291625022888183,
|
| 115 |
+
"step": 150
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 0.5734767025089605,
|
| 119 |
+
"grad_norm": 0.814973771572113,
|
| 120 |
+
"learning_rate": 1.59e-05,
|
| 121 |
+
"loss": 0.020908774435520174,
|
| 122 |
+
"step": 160
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 0.6093189964157706,
|
| 126 |
+
"grad_norm": 0.08694848418235779,
|
| 127 |
+
"learning_rate": 1.69e-05,
|
| 128 |
+
"loss": 0.06968651413917541,
|
| 129 |
+
"step": 170
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 0.6451612903225806,
|
| 133 |
+
"grad_norm": 0.08773200958967209,
|
| 134 |
+
"learning_rate": 1.79e-05,
|
| 135 |
+
"loss": 0.04258593320846558,
|
| 136 |
+
"step": 180
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 0.6810035842293907,
|
| 140 |
+
"grad_norm": 0.09697633981704712,
|
| 141 |
+
"learning_rate": 1.8900000000000002e-05,
|
| 142 |
+
"loss": 0.037452369928359985,
|
| 143 |
+
"step": 190
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 0.7168458781362007,
|
| 147 |
+
"grad_norm": 0.10511818528175354,
|
| 148 |
+
"learning_rate": 1.9900000000000003e-05,
|
| 149 |
+
"loss": 0.007081887871026993,
|
| 150 |
+
"step": 200
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 0.7526881720430108,
|
| 154 |
+
"grad_norm": 0.0881538987159729,
|
| 155 |
+
"learning_rate": 2.09e-05,
|
| 156 |
+
"loss": 0.008766584843397141,
|
| 157 |
+
"step": 210
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 0.7885304659498208,
|
| 161 |
+
"grad_norm": 0.06669408828020096,
|
| 162 |
+
"learning_rate": 2.19e-05,
|
| 163 |
+
"loss": 0.05853445529937744,
|
| 164 |
+
"step": 220
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 0.8243727598566308,
|
| 168 |
+
"grad_norm": 0.05606541037559509,
|
| 169 |
+
"learning_rate": 2.29e-05,
|
| 170 |
+
"loss": 0.0038808729499578477,
|
| 171 |
+
"step": 230
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 0.8602150537634409,
|
| 175 |
+
"grad_norm": 5.139830112457275,
|
| 176 |
+
"learning_rate": 2.39e-05,
|
| 177 |
+
"loss": 0.06458683013916015,
|
| 178 |
+
"step": 240
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 0.8960573476702509,
|
| 182 |
+
"grad_norm": 0.06282316893339157,
|
| 183 |
+
"learning_rate": 2.4900000000000002e-05,
|
| 184 |
+
"loss": 0.006039018556475639,
|
| 185 |
+
"step": 250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 0.931899641577061,
|
| 189 |
+
"grad_norm": 7.700125694274902,
|
| 190 |
+
"learning_rate": 2.5900000000000003e-05,
|
| 191 |
+
"loss": 0.11402667760848999,
|
| 192 |
+
"step": 260
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 0.967741935483871,
|
| 196 |
+
"grad_norm": 0.1442285031080246,
|
| 197 |
+
"learning_rate": 2.6900000000000003e-05,
|
| 198 |
+
"loss": 0.012984356284141541,
|
| 199 |
+
"step": 270
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 1.0,
|
| 203 |
+
"eval_accuracy": 0.9928251121076234,
|
| 204 |
+
"eval_f1": 0.9726027397260274,
|
| 205 |
+
"eval_loss": 0.04283029958605766,
|
| 206 |
+
"eval_precision": 0.993006993006993,
|
| 207 |
+
"eval_recall": 0.9530201342281879,
|
| 208 |
+
"eval_runtime": 42.1831,
|
| 209 |
+
"eval_samples_per_second": 26.432,
|
| 210 |
+
"eval_steps_per_second": 0.427,
|
| 211 |
+
"step": 279
|
| 212 |
+
}
|
| 213 |
+
],
|
| 214 |
+
"logging_steps": 10,
|
| 215 |
+
"max_steps": 837,
|
| 216 |
+
"num_input_tokens_seen": 0,
|
| 217 |
+
"num_train_epochs": 3,
|
| 218 |
+
"save_steps": 500,
|
| 219 |
+
"stateful_callbacks": {
|
| 220 |
+
"TrainerControl": {
|
| 221 |
+
"args": {
|
| 222 |
+
"should_epoch_stop": false,
|
| 223 |
+
"should_evaluate": false,
|
| 224 |
+
"should_log": false,
|
| 225 |
+
"should_save": true,
|
| 226 |
+
"should_training_stop": false
|
| 227 |
+
},
|
| 228 |
+
"attributes": {}
|
| 229 |
+
}
|
| 230 |
+
},
|
| 231 |
+
"total_flos": 147601798952448.0,
|
| 232 |
+
"train_batch_size": 16,
|
| 233 |
+
"trial_name": null,
|
| 234 |
+
"trial_params": null
|
| 235 |
+
}
|
Deep_Learning_Project/results/checkpoint-279/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fa1c04a8eee5e3e20cd86241a1b9ef2f6932b1f87ed49b5c7e1a4b6bf1a7ad0
|
| 3 |
+
size 5201
|
Deep_Learning_Project/results/checkpoint-558/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"problem_type": "single_label_classification",
|
| 20 |
+
"qa_dropout": 0.1,
|
| 21 |
+
"seq_classif_dropout": 0.2,
|
| 22 |
+
"sinusoidal_pos_embds": false,
|
| 23 |
+
"tie_weights_": true,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.3.0",
|
| 26 |
+
"use_cache": false,
|
| 27 |
+
"vocab_size": 30522
|
| 28 |
+
}
|
Deep_Learning_Project/results/checkpoint-558/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d6801b2fc9de6ccd631b024e96ac0964d55a32c171e938c21271224dbc8da90
|
| 3 |
+
size 267832560
|
Deep_Learning_Project/results/checkpoint-558/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2597d49d8bf7e2dd11e3084f908e6b1f989820acfae878e2ef2146d44f0dabf1
|
| 3 |
+
size 535724875
|
Deep_Learning_Project/results/checkpoint-558/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7fd9840ce3449fb3fb17c7f4bc645548d8a9f10980ca4186834e471255c83ee
|
| 3 |
+
size 14391
|
Deep_Learning_Project/results/checkpoint-558/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee934d9a60a571b66f851c4ecfc7a4e17ba93cd05b33a1bfee57b0d258e5611d
|
| 3 |
+
size 1465
|
Deep_Learning_Project/results/checkpoint-558/trainer_state.json
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 279,
|
| 3 |
+
"best_metric": 0.042980484664440155,
|
| 4 |
+
"best_model_checkpoint": "./results\\checkpoint-279",
|
| 5 |
+
"epoch": 2.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 558,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.035842293906810034,
|
| 14 |
+
"grad_norm": 4.168096542358398,
|
| 15 |
+
"learning_rate": 9e-07,
|
| 16 |
+
"loss": 0.7643725395202636,
|
| 17 |
+
"step": 10
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.07168458781362007,
|
| 21 |
+
"grad_norm": 3.3584811687469482,
|
| 22 |
+
"learning_rate": 1.9e-06,
|
| 23 |
+
"loss": 0.7450664043426514,
|
| 24 |
+
"step": 20
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.10752688172043011,
|
| 28 |
+
"grad_norm": 3.753683090209961,
|
| 29 |
+
"learning_rate": 2.9e-06,
|
| 30 |
+
"loss": 0.7087752342224121,
|
| 31 |
+
"step": 30
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.14336917562724014,
|
| 35 |
+
"grad_norm": 1.9111875295639038,
|
| 36 |
+
"learning_rate": 3.9e-06,
|
| 37 |
+
"loss": 0.6540002822875977,
|
| 38 |
+
"step": 40
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.17921146953405018,
|
| 42 |
+
"grad_norm": 1.412906527519226,
|
| 43 |
+
"learning_rate": 4.9000000000000005e-06,
|
| 44 |
+
"loss": 0.5844130516052246,
|
| 45 |
+
"step": 50
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.21505376344086022,
|
| 49 |
+
"grad_norm": 2.155179977416992,
|
| 50 |
+
"learning_rate": 5.9e-06,
|
| 51 |
+
"loss": 0.4763498783111572,
|
| 52 |
+
"step": 60
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.25089605734767023,
|
| 56 |
+
"grad_norm": 1.7028120756149292,
|
| 57 |
+
"learning_rate": 6.900000000000001e-06,
|
| 58 |
+
"loss": 0.338704252243042,
|
| 59 |
+
"step": 70
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.2867383512544803,
|
| 63 |
+
"grad_norm": 1.2706844806671143,
|
| 64 |
+
"learning_rate": 7.9e-06,
|
| 65 |
+
"loss": 0.2414313316345215,
|
| 66 |
+
"step": 80
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.3225806451612903,
|
| 70 |
+
"grad_norm": 1.987801194190979,
|
| 71 |
+
"learning_rate": 8.9e-06,
|
| 72 |
+
"loss": 0.13595094680786132,
|
| 73 |
+
"step": 90
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.35842293906810035,
|
| 77 |
+
"grad_norm": 1.6342447996139526,
|
| 78 |
+
"learning_rate": 9.900000000000002e-06,
|
| 79 |
+
"loss": 0.07480053305625915,
|
| 80 |
+
"step": 100
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.3942652329749104,
|
| 84 |
+
"grad_norm": 1.897525668144226,
|
| 85 |
+
"learning_rate": 1.09e-05,
|
| 86 |
+
"loss": 0.0948406457901001,
|
| 87 |
+
"step": 110
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.43010752688172044,
|
| 91 |
+
"grad_norm": 6.166868209838867,
|
| 92 |
+
"learning_rate": 1.19e-05,
|
| 93 |
+
"loss": 0.10227099657058716,
|
| 94 |
+
"step": 120
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 0.4659498207885305,
|
| 98 |
+
"grad_norm": 0.20176957547664642,
|
| 99 |
+
"learning_rate": 1.29e-05,
|
| 100 |
+
"loss": 0.023712341487407685,
|
| 101 |
+
"step": 130
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 0.5017921146953405,
|
| 105 |
+
"grad_norm": 0.1295539289712906,
|
| 106 |
+
"learning_rate": 1.3900000000000002e-05,
|
| 107 |
+
"loss": 0.0910467267036438,
|
| 108 |
+
"step": 140
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 0.5376344086021505,
|
| 112 |
+
"grad_norm": 0.15045320987701416,
|
| 113 |
+
"learning_rate": 1.49e-05,
|
| 114 |
+
"loss": 0.041373416781425476,
|
| 115 |
+
"step": 150
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 0.5734767025089605,
|
| 119 |
+
"grad_norm": 0.2986961603164673,
|
| 120 |
+
"learning_rate": 1.59e-05,
|
| 121 |
+
"loss": 0.0312202125787735,
|
| 122 |
+
"step": 160
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 0.6093189964157706,
|
| 126 |
+
"grad_norm": 0.0721135213971138,
|
| 127 |
+
"learning_rate": 1.69e-05,
|
| 128 |
+
"loss": 0.06126164197921753,
|
| 129 |
+
"step": 170
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 0.6451612903225806,
|
| 133 |
+
"grad_norm": 0.08172585070133209,
|
| 134 |
+
"learning_rate": 1.79e-05,
|
| 135 |
+
"loss": 0.041085737943649295,
|
| 136 |
+
"step": 180
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 0.6810035842293907,
|
| 140 |
+
"grad_norm": 0.09516890347003937,
|
| 141 |
+
"learning_rate": 1.8900000000000002e-05,
|
| 142 |
+
"loss": 0.04159103333950043,
|
| 143 |
+
"step": 190
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 0.7168458781362007,
|
| 147 |
+
"grad_norm": 0.07213900983333588,
|
| 148 |
+
"learning_rate": 1.9900000000000003e-05,
|
| 149 |
+
"loss": 0.005801299959421158,
|
| 150 |
+
"step": 200
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 0.7526881720430108,
|
| 154 |
+
"grad_norm": 0.053828418254852295,
|
| 155 |
+
"learning_rate": 2.09e-05,
|
| 156 |
+
"loss": 0.01323639452457428,
|
| 157 |
+
"step": 210
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 0.7885304659498208,
|
| 161 |
+
"grad_norm": 0.08665204793214798,
|
| 162 |
+
"learning_rate": 2.19e-05,
|
| 163 |
+
"loss": 0.06004759669303894,
|
| 164 |
+
"step": 220
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 0.8243727598566308,
|
| 168 |
+
"grad_norm": 0.04578182473778725,
|
| 169 |
+
"learning_rate": 2.29e-05,
|
| 170 |
+
"loss": 0.003989944979548454,
|
| 171 |
+
"step": 230
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 0.8602150537634409,
|
| 175 |
+
"grad_norm": 2.856813430786133,
|
| 176 |
+
"learning_rate": 2.39e-05,
|
| 177 |
+
"loss": 0.059805816411972045,
|
| 178 |
+
"step": 240
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 0.8960573476702509,
|
| 182 |
+
"grad_norm": 0.10952852666378021,
|
| 183 |
+
"learning_rate": 2.4900000000000002e-05,
|
| 184 |
+
"loss": 0.013767723739147187,
|
| 185 |
+
"step": 250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 0.931899641577061,
|
| 189 |
+
"grad_norm": 11.293957710266113,
|
| 190 |
+
"learning_rate": 2.5900000000000003e-05,
|
| 191 |
+
"loss": 0.09849974513053894,
|
| 192 |
+
"step": 260
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 0.967741935483871,
|
| 196 |
+
"grad_norm": 0.22327303886413574,
|
| 197 |
+
"learning_rate": 2.6900000000000003e-05,
|
| 198 |
+
"loss": 0.020736195147037506,
|
| 199 |
+
"step": 270
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 1.0,
|
| 203 |
+
"eval_accuracy": 0.9928251121076234,
|
| 204 |
+
"eval_f1": 0.9726027397260274,
|
| 205 |
+
"eval_loss": 0.042980484664440155,
|
| 206 |
+
"eval_precision": 0.993006993006993,
|
| 207 |
+
"eval_recall": 0.9530201342281879,
|
| 208 |
+
"eval_runtime": 39.0178,
|
| 209 |
+
"eval_samples_per_second": 28.577,
|
| 210 |
+
"eval_steps_per_second": 0.461,
|
| 211 |
+
"step": 279
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"epoch": 1.003584229390681,
|
| 215 |
+
"grad_norm": 0.08417440950870514,
|
| 216 |
+
"learning_rate": 2.7900000000000004e-05,
|
| 217 |
+
"loss": 0.08314200639724731,
|
| 218 |
+
"step": 280
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"epoch": 1.039426523297491,
|
| 222 |
+
"grad_norm": 12.878673553466797,
|
| 223 |
+
"learning_rate": 2.8899999999999998e-05,
|
| 224 |
+
"loss": 0.02687312364578247,
|
| 225 |
+
"step": 290
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"epoch": 1.075268817204301,
|
| 229 |
+
"grad_norm": 0.16290566325187683,
|
| 230 |
+
"learning_rate": 2.9900000000000002e-05,
|
| 231 |
+
"loss": 0.029204189777374268,
|
| 232 |
+
"step": 300
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"epoch": 1.1111111111111112,
|
| 236 |
+
"grad_norm": 0.09557493776082993,
|
| 237 |
+
"learning_rate": 3.09e-05,
|
| 238 |
+
"loss": 0.0412629634141922,
|
| 239 |
+
"step": 310
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"epoch": 1.146953405017921,
|
| 243 |
+
"grad_norm": 0.07874058932065964,
|
| 244 |
+
"learning_rate": 3.19e-05,
|
| 245 |
+
"loss": 0.06848401427268982,
|
| 246 |
+
"step": 320
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"epoch": 1.1827956989247312,
|
| 250 |
+
"grad_norm": 8.930388450622559,
|
| 251 |
+
"learning_rate": 3.29e-05,
|
| 252 |
+
"loss": 0.05015917420387268,
|
| 253 |
+
"step": 330
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"epoch": 1.2186379928315412,
|
| 257 |
+
"grad_norm": 0.03418437018990517,
|
| 258 |
+
"learning_rate": 3.3900000000000004e-05,
|
| 259 |
+
"loss": 0.02710776627063751,
|
| 260 |
+
"step": 340
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"epoch": 1.2544802867383513,
|
| 264 |
+
"grad_norm": 1.197632908821106,
|
| 265 |
+
"learning_rate": 3.49e-05,
|
| 266 |
+
"loss": 0.0023652609437704087,
|
| 267 |
+
"step": 350
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"epoch": 1.2903225806451613,
|
| 271 |
+
"grad_norm": 9.31840991973877,
|
| 272 |
+
"learning_rate": 3.59e-05,
|
| 273 |
+
"loss": 0.050595653057098386,
|
| 274 |
+
"step": 360
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"epoch": 1.3261648745519714,
|
| 278 |
+
"grad_norm": 4.070376873016357,
|
| 279 |
+
"learning_rate": 3.69e-05,
|
| 280 |
+
"loss": 0.0879504919052124,
|
| 281 |
+
"step": 370
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"epoch": 1.3620071684587813,
|
| 285 |
+
"grad_norm": 0.12485872954130173,
|
| 286 |
+
"learning_rate": 3.79e-05,
|
| 287 |
+
"loss": 0.04475467205047608,
|
| 288 |
+
"step": 380
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"epoch": 1.3978494623655915,
|
| 292 |
+
"grad_norm": 0.033023957163095474,
|
| 293 |
+
"learning_rate": 3.8900000000000004e-05,
|
| 294 |
+
"loss": 0.03315191864967346,
|
| 295 |
+
"step": 390
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"epoch": 1.4336917562724014,
|
| 299 |
+
"grad_norm": 0.02942063845694065,
|
| 300 |
+
"learning_rate": 3.99e-05,
|
| 301 |
+
"loss": 0.023516546189785003,
|
| 302 |
+
"step": 400
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"epoch": 1.4695340501792113,
|
| 306 |
+
"grad_norm": 0.016614006832242012,
|
| 307 |
+
"learning_rate": 4.09e-05,
|
| 308 |
+
"loss": 0.001084418874233961,
|
| 309 |
+
"step": 410
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"epoch": 1.5053763440860215,
|
| 313 |
+
"grad_norm": 0.021838072687387466,
|
| 314 |
+
"learning_rate": 4.19e-05,
|
| 315 |
+
"loss": 0.03104546368122101,
|
| 316 |
+
"step": 420
|
| 317 |
+
},
|
| 318 |
+
{
|
| 319 |
+
"epoch": 1.5412186379928317,
|
| 320 |
+
"grad_norm": 0.07978533208370209,
|
| 321 |
+
"learning_rate": 4.29e-05,
|
| 322 |
+
"loss": 0.0017994396388530732,
|
| 323 |
+
"step": 430
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"epoch": 1.5770609318996416,
|
| 327 |
+
"grad_norm": 0.014919363893568516,
|
| 328 |
+
"learning_rate": 4.39e-05,
|
| 329 |
+
"loss": 0.05550175905227661,
|
| 330 |
+
"step": 440
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"epoch": 1.6129032258064515,
|
| 334 |
+
"grad_norm": 0.037673186510801315,
|
| 335 |
+
"learning_rate": 4.49e-05,
|
| 336 |
+
"loss": 0.09091965556144714,
|
| 337 |
+
"step": 450
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"epoch": 1.6487455197132617,
|
| 341 |
+
"grad_norm": 17.163318634033203,
|
| 342 |
+
"learning_rate": 4.5900000000000004e-05,
|
| 343 |
+
"loss": 0.010417895019054412,
|
| 344 |
+
"step": 460
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"epoch": 1.6845878136200718,
|
| 348 |
+
"grad_norm": 6.812087535858154,
|
| 349 |
+
"learning_rate": 4.69e-05,
|
| 350 |
+
"loss": 0.045432651042938234,
|
| 351 |
+
"step": 470
|
| 352 |
+
},
|
| 353 |
+
{
|
| 354 |
+
"epoch": 1.7204301075268817,
|
| 355 |
+
"grad_norm": 0.0256840568035841,
|
| 356 |
+
"learning_rate": 4.79e-05,
|
| 357 |
+
"loss": 0.06190497279167175,
|
| 358 |
+
"step": 480
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"epoch": 1.7562724014336917,
|
| 362 |
+
"grad_norm": 0.08079337328672409,
|
| 363 |
+
"learning_rate": 4.89e-05,
|
| 364 |
+
"loss": 0.07536581158638,
|
| 365 |
+
"step": 490
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"epoch": 1.7921146953405018,
|
| 369 |
+
"grad_norm": 0.020881259813904762,
|
| 370 |
+
"learning_rate": 4.99e-05,
|
| 371 |
+
"loss": 0.050325697660446166,
|
| 372 |
+
"step": 500
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"epoch": 1.827956989247312,
|
| 376 |
+
"grad_norm": 0.011713648214936256,
|
| 377 |
+
"learning_rate": 4.8664688427299705e-05,
|
| 378 |
+
"loss": 0.020131270587444305,
|
| 379 |
+
"step": 510
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"epoch": 1.863799283154122,
|
| 383 |
+
"grad_norm": 0.028018169105052948,
|
| 384 |
+
"learning_rate": 4.7181008902077156e-05,
|
| 385 |
+
"loss": 0.02937682271003723,
|
| 386 |
+
"step": 520
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"epoch": 1.8996415770609318,
|
| 390 |
+
"grad_norm": 0.018673259764909744,
|
| 391 |
+
"learning_rate": 4.56973293768546e-05,
|
| 392 |
+
"loss": 0.04715012311935425,
|
| 393 |
+
"step": 530
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"epoch": 1.935483870967742,
|
| 397 |
+
"grad_norm": 3.4265527725219727,
|
| 398 |
+
"learning_rate": 4.421364985163205e-05,
|
| 399 |
+
"loss": 0.00353032723069191,
|
| 400 |
+
"step": 540
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"epoch": 1.971326164874552,
|
| 404 |
+
"grad_norm": 0.08214850723743439,
|
| 405 |
+
"learning_rate": 4.2729970326409497e-05,
|
| 406 |
+
"loss": 0.03979413509368897,
|
| 407 |
+
"step": 550
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"epoch": 2.0,
|
| 411 |
+
"eval_accuracy": 0.9856502242152466,
|
| 412 |
+
"eval_f1": 0.9477124183006536,
|
| 413 |
+
"eval_loss": 0.06383560597896576,
|
| 414 |
+
"eval_precision": 0.9235668789808917,
|
| 415 |
+
"eval_recall": 0.9731543624161074,
|
| 416 |
+
"eval_runtime": 40.1147,
|
| 417 |
+
"eval_samples_per_second": 27.795,
|
| 418 |
+
"eval_steps_per_second": 0.449,
|
| 419 |
+
"step": 558
|
| 420 |
+
}
|
| 421 |
+
],
|
| 422 |
+
"logging_steps": 10,
|
| 423 |
+
"max_steps": 837,
|
| 424 |
+
"num_input_tokens_seen": 0,
|
| 425 |
+
"num_train_epochs": 3,
|
| 426 |
+
"save_steps": 500,
|
| 427 |
+
"stateful_callbacks": {
|
| 428 |
+
"TrainerControl": {
|
| 429 |
+
"args": {
|
| 430 |
+
"should_epoch_stop": false,
|
| 431 |
+
"should_evaluate": false,
|
| 432 |
+
"should_log": false,
|
| 433 |
+
"should_save": true,
|
| 434 |
+
"should_training_stop": false
|
| 435 |
+
},
|
| 436 |
+
"attributes": {}
|
| 437 |
+
}
|
| 438 |
+
},
|
| 439 |
+
"total_flos": 295203597904896.0,
|
| 440 |
+
"train_batch_size": 16,
|
| 441 |
+
"trial_name": null,
|
| 442 |
+
"trial_params": null
|
| 443 |
+
}
|
Deep_Learning_Project/results/checkpoint-558/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fa1c04a8eee5e3e20cd86241a1b9ef2f6932b1f87ed49b5c7e1a4b6bf1a7ad0
|
| 3 |
+
size 5201
|
Deep_Learning_Project/results/checkpoint-837/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"problem_type": "single_label_classification",
|
| 20 |
+
"qa_dropout": 0.1,
|
| 21 |
+
"seq_classif_dropout": 0.2,
|
| 22 |
+
"sinusoidal_pos_embds": false,
|
| 23 |
+
"tie_weights_": true,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.3.0",
|
| 26 |
+
"use_cache": false,
|
| 27 |
+
"vocab_size": 30522
|
| 28 |
+
}
|
Deep_Learning_Project/results/checkpoint-837/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d1f89da5ca6b192b1b7c42ccbe489442102b145f453959a18dc0bbac8c389997
|
| 3 |
+
size 267832560
|
Deep_Learning_Project/results/checkpoint-837/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:630db79696e91210af33e221111cfcefe5102841bf92dff474423cfef8708af0
|
| 3 |
+
size 535724875
|
Deep_Learning_Project/results/checkpoint-837/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d65239f408a2b9b9f7060940cb5a72607f4e554a32f1e693d531e3f14da646b
|
| 3 |
+
size 14391
|
Deep_Learning_Project/results/checkpoint-837/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1118deb2b211462eb2b6af551d7bb98b42af89197dd44d8542729ab4409e2a1
|
| 3 |
+
size 1465
|
Deep_Learning_Project/results/checkpoint-837/trainer_state.json
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 837,
|
| 3 |
+
"best_metric": 0.04282991588115692,
|
| 4 |
+
"best_model_checkpoint": "./results\\checkpoint-837",
|
| 5 |
+
"epoch": 3.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 837,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.035842293906810034,
|
| 14 |
+
"grad_norm": 4.168096542358398,
|
| 15 |
+
"learning_rate": 9e-07,
|
| 16 |
+
"loss": 0.7643725395202636,
|
| 17 |
+
"step": 10
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.07168458781362007,
|
| 21 |
+
"grad_norm": 3.3584811687469482,
|
| 22 |
+
"learning_rate": 1.9e-06,
|
| 23 |
+
"loss": 0.7450664043426514,
|
| 24 |
+
"step": 20
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.10752688172043011,
|
| 28 |
+
"grad_norm": 3.753683090209961,
|
| 29 |
+
"learning_rate": 2.9e-06,
|
| 30 |
+
"loss": 0.7087752342224121,
|
| 31 |
+
"step": 30
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.14336917562724014,
|
| 35 |
+
"grad_norm": 1.9111875295639038,
|
| 36 |
+
"learning_rate": 3.9e-06,
|
| 37 |
+
"loss": 0.6540002822875977,
|
| 38 |
+
"step": 40
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.17921146953405018,
|
| 42 |
+
"grad_norm": 1.412906527519226,
|
| 43 |
+
"learning_rate": 4.9000000000000005e-06,
|
| 44 |
+
"loss": 0.5844130516052246,
|
| 45 |
+
"step": 50
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.21505376344086022,
|
| 49 |
+
"grad_norm": 2.155179977416992,
|
| 50 |
+
"learning_rate": 5.9e-06,
|
| 51 |
+
"loss": 0.4763498783111572,
|
| 52 |
+
"step": 60
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.25089605734767023,
|
| 56 |
+
"grad_norm": 1.7028120756149292,
|
| 57 |
+
"learning_rate": 6.900000000000001e-06,
|
| 58 |
+
"loss": 0.338704252243042,
|
| 59 |
+
"step": 70
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.2867383512544803,
|
| 63 |
+
"grad_norm": 1.2706844806671143,
|
| 64 |
+
"learning_rate": 7.9e-06,
|
| 65 |
+
"loss": 0.2414313316345215,
|
| 66 |
+
"step": 80
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.3225806451612903,
|
| 70 |
+
"grad_norm": 1.987801194190979,
|
| 71 |
+
"learning_rate": 8.9e-06,
|
| 72 |
+
"loss": 0.13595094680786132,
|
| 73 |
+
"step": 90
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.35842293906810035,
|
| 77 |
+
"grad_norm": 1.6342447996139526,
|
| 78 |
+
"learning_rate": 9.900000000000002e-06,
|
| 79 |
+
"loss": 0.07480053305625915,
|
| 80 |
+
"step": 100
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.3942652329749104,
|
| 84 |
+
"grad_norm": 1.897525668144226,
|
| 85 |
+
"learning_rate": 1.09e-05,
|
| 86 |
+
"loss": 0.0948406457901001,
|
| 87 |
+
"step": 110
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.43010752688172044,
|
| 91 |
+
"grad_norm": 6.166868209838867,
|
| 92 |
+
"learning_rate": 1.19e-05,
|
| 93 |
+
"loss": 0.10227099657058716,
|
| 94 |
+
"step": 120
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 0.4659498207885305,
|
| 98 |
+
"grad_norm": 0.20176957547664642,
|
| 99 |
+
"learning_rate": 1.29e-05,
|
| 100 |
+
"loss": 0.023712341487407685,
|
| 101 |
+
"step": 130
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 0.5017921146953405,
|
| 105 |
+
"grad_norm": 0.1295539289712906,
|
| 106 |
+
"learning_rate": 1.3900000000000002e-05,
|
| 107 |
+
"loss": 0.0910467267036438,
|
| 108 |
+
"step": 140
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 0.5376344086021505,
|
| 112 |
+
"grad_norm": 0.15045320987701416,
|
| 113 |
+
"learning_rate": 1.49e-05,
|
| 114 |
+
"loss": 0.041373416781425476,
|
| 115 |
+
"step": 150
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 0.5734767025089605,
|
| 119 |
+
"grad_norm": 0.2986961603164673,
|
| 120 |
+
"learning_rate": 1.59e-05,
|
| 121 |
+
"loss": 0.0312202125787735,
|
| 122 |
+
"step": 160
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 0.6093189964157706,
|
| 126 |
+
"grad_norm": 0.0721135213971138,
|
| 127 |
+
"learning_rate": 1.69e-05,
|
| 128 |
+
"loss": 0.06126164197921753,
|
| 129 |
+
"step": 170
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 0.6451612903225806,
|
| 133 |
+
"grad_norm": 0.08172585070133209,
|
| 134 |
+
"learning_rate": 1.79e-05,
|
| 135 |
+
"loss": 0.041085737943649295,
|
| 136 |
+
"step": 180
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 0.6810035842293907,
|
| 140 |
+
"grad_norm": 0.09516890347003937,
|
| 141 |
+
"learning_rate": 1.8900000000000002e-05,
|
| 142 |
+
"loss": 0.04159103333950043,
|
| 143 |
+
"step": 190
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 0.7168458781362007,
|
| 147 |
+
"grad_norm": 0.07213900983333588,
|
| 148 |
+
"learning_rate": 1.9900000000000003e-05,
|
| 149 |
+
"loss": 0.005801299959421158,
|
| 150 |
+
"step": 200
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 0.7526881720430108,
|
| 154 |
+
"grad_norm": 0.053828418254852295,
|
| 155 |
+
"learning_rate": 2.09e-05,
|
| 156 |
+
"loss": 0.01323639452457428,
|
| 157 |
+
"step": 210
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 0.7885304659498208,
|
| 161 |
+
"grad_norm": 0.08665204793214798,
|
| 162 |
+
"learning_rate": 2.19e-05,
|
| 163 |
+
"loss": 0.06004759669303894,
|
| 164 |
+
"step": 220
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 0.8243727598566308,
|
| 168 |
+
"grad_norm": 0.04578182473778725,
|
| 169 |
+
"learning_rate": 2.29e-05,
|
| 170 |
+
"loss": 0.003989944979548454,
|
| 171 |
+
"step": 230
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 0.8602150537634409,
|
| 175 |
+
"grad_norm": 2.856813430786133,
|
| 176 |
+
"learning_rate": 2.39e-05,
|
| 177 |
+
"loss": 0.059805816411972045,
|
| 178 |
+
"step": 240
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 0.8960573476702509,
|
| 182 |
+
"grad_norm": 0.10952852666378021,
|
| 183 |
+
"learning_rate": 2.4900000000000002e-05,
|
| 184 |
+
"loss": 0.013767723739147187,
|
| 185 |
+
"step": 250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 0.931899641577061,
|
| 189 |
+
"grad_norm": 11.293957710266113,
|
| 190 |
+
"learning_rate": 2.5900000000000003e-05,
|
| 191 |
+
"loss": 0.09849974513053894,
|
| 192 |
+
"step": 260
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 0.967741935483871,
|
| 196 |
+
"grad_norm": 0.22327303886413574,
|
| 197 |
+
"learning_rate": 2.6900000000000003e-05,
|
| 198 |
+
"loss": 0.020736195147037506,
|
| 199 |
+
"step": 270
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 1.0,
|
| 203 |
+
"eval_accuracy": 0.9928251121076234,
|
| 204 |
+
"eval_f1": 0.9726027397260274,
|
| 205 |
+
"eval_loss": 0.042980484664440155,
|
| 206 |
+
"eval_precision": 0.993006993006993,
|
| 207 |
+
"eval_recall": 0.9530201342281879,
|
| 208 |
+
"eval_runtime": 39.0178,
|
| 209 |
+
"eval_samples_per_second": 28.577,
|
| 210 |
+
"eval_steps_per_second": 0.461,
|
| 211 |
+
"step": 279
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"epoch": 1.003584229390681,
|
| 215 |
+
"grad_norm": 0.08417440950870514,
|
| 216 |
+
"learning_rate": 2.7900000000000004e-05,
|
| 217 |
+
"loss": 0.08314200639724731,
|
| 218 |
+
"step": 280
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"epoch": 1.039426523297491,
|
| 222 |
+
"grad_norm": 12.878673553466797,
|
| 223 |
+
"learning_rate": 2.8899999999999998e-05,
|
| 224 |
+
"loss": 0.02687312364578247,
|
| 225 |
+
"step": 290
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"epoch": 1.075268817204301,
|
| 229 |
+
"grad_norm": 0.16290566325187683,
|
| 230 |
+
"learning_rate": 2.9900000000000002e-05,
|
| 231 |
+
"loss": 0.029204189777374268,
|
| 232 |
+
"step": 300
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"epoch": 1.1111111111111112,
|
| 236 |
+
"grad_norm": 0.09557493776082993,
|
| 237 |
+
"learning_rate": 3.09e-05,
|
| 238 |
+
"loss": 0.0412629634141922,
|
| 239 |
+
"step": 310
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"epoch": 1.146953405017921,
|
| 243 |
+
"grad_norm": 0.07874058932065964,
|
| 244 |
+
"learning_rate": 3.19e-05,
|
| 245 |
+
"loss": 0.06848401427268982,
|
| 246 |
+
"step": 320
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"epoch": 1.1827956989247312,
|
| 250 |
+
"grad_norm": 8.930388450622559,
|
| 251 |
+
"learning_rate": 3.29e-05,
|
| 252 |
+
"loss": 0.05015917420387268,
|
| 253 |
+
"step": 330
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"epoch": 1.2186379928315412,
|
| 257 |
+
"grad_norm": 0.03418437018990517,
|
| 258 |
+
"learning_rate": 3.3900000000000004e-05,
|
| 259 |
+
"loss": 0.02710776627063751,
|
| 260 |
+
"step": 340
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"epoch": 1.2544802867383513,
|
| 264 |
+
"grad_norm": 1.197632908821106,
|
| 265 |
+
"learning_rate": 3.49e-05,
|
| 266 |
+
"loss": 0.0023652609437704087,
|
| 267 |
+
"step": 350
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"epoch": 1.2903225806451613,
|
| 271 |
+
"grad_norm": 9.31840991973877,
|
| 272 |
+
"learning_rate": 3.59e-05,
|
| 273 |
+
"loss": 0.050595653057098386,
|
| 274 |
+
"step": 360
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"epoch": 1.3261648745519714,
|
| 278 |
+
"grad_norm": 4.070376873016357,
|
| 279 |
+
"learning_rate": 3.69e-05,
|
| 280 |
+
"loss": 0.0879504919052124,
|
| 281 |
+
"step": 370
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"epoch": 1.3620071684587813,
|
| 285 |
+
"grad_norm": 0.12485872954130173,
|
| 286 |
+
"learning_rate": 3.79e-05,
|
| 287 |
+
"loss": 0.04475467205047608,
|
| 288 |
+
"step": 380
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"epoch": 1.3978494623655915,
|
| 292 |
+
"grad_norm": 0.033023957163095474,
|
| 293 |
+
"learning_rate": 3.8900000000000004e-05,
|
| 294 |
+
"loss": 0.03315191864967346,
|
| 295 |
+
"step": 390
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"epoch": 1.4336917562724014,
|
| 299 |
+
"grad_norm": 0.02942063845694065,
|
| 300 |
+
"learning_rate": 3.99e-05,
|
| 301 |
+
"loss": 0.023516546189785003,
|
| 302 |
+
"step": 400
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"epoch": 1.4695340501792113,
|
| 306 |
+
"grad_norm": 0.016614006832242012,
|
| 307 |
+
"learning_rate": 4.09e-05,
|
| 308 |
+
"loss": 0.001084418874233961,
|
| 309 |
+
"step": 410
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"epoch": 1.5053763440860215,
|
| 313 |
+
"grad_norm": 0.021838072687387466,
|
| 314 |
+
"learning_rate": 4.19e-05,
|
| 315 |
+
"loss": 0.03104546368122101,
|
| 316 |
+
"step": 420
|
| 317 |
+
},
|
| 318 |
+
{
|
| 319 |
+
"epoch": 1.5412186379928317,
|
| 320 |
+
"grad_norm": 0.07978533208370209,
|
| 321 |
+
"learning_rate": 4.29e-05,
|
| 322 |
+
"loss": 0.0017994396388530732,
|
| 323 |
+
"step": 430
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"epoch": 1.5770609318996416,
|
| 327 |
+
"grad_norm": 0.014919363893568516,
|
| 328 |
+
"learning_rate": 4.39e-05,
|
| 329 |
+
"loss": 0.05550175905227661,
|
| 330 |
+
"step": 440
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"epoch": 1.6129032258064515,
|
| 334 |
+
"grad_norm": 0.037673186510801315,
|
| 335 |
+
"learning_rate": 4.49e-05,
|
| 336 |
+
"loss": 0.09091965556144714,
|
| 337 |
+
"step": 450
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"epoch": 1.6487455197132617,
|
| 341 |
+
"grad_norm": 17.163318634033203,
|
| 342 |
+
"learning_rate": 4.5900000000000004e-05,
|
| 343 |
+
"loss": 0.010417895019054412,
|
| 344 |
+
"step": 460
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"epoch": 1.6845878136200718,
|
| 348 |
+
"grad_norm": 6.812087535858154,
|
| 349 |
+
"learning_rate": 4.69e-05,
|
| 350 |
+
"loss": 0.045432651042938234,
|
| 351 |
+
"step": 470
|
| 352 |
+
},
|
| 353 |
+
{
|
| 354 |
+
"epoch": 1.7204301075268817,
|
| 355 |
+
"grad_norm": 0.0256840568035841,
|
| 356 |
+
"learning_rate": 4.79e-05,
|
| 357 |
+
"loss": 0.06190497279167175,
|
| 358 |
+
"step": 480
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"epoch": 1.7562724014336917,
|
| 362 |
+
"grad_norm": 0.08079337328672409,
|
| 363 |
+
"learning_rate": 4.89e-05,
|
| 364 |
+
"loss": 0.07536581158638,
|
| 365 |
+
"step": 490
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"epoch": 1.7921146953405018,
|
| 369 |
+
"grad_norm": 0.020881259813904762,
|
| 370 |
+
"learning_rate": 4.99e-05,
|
| 371 |
+
"loss": 0.050325697660446166,
|
| 372 |
+
"step": 500
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"epoch": 1.827956989247312,
|
| 376 |
+
"grad_norm": 0.011713648214936256,
|
| 377 |
+
"learning_rate": 4.8664688427299705e-05,
|
| 378 |
+
"loss": 0.020131270587444305,
|
| 379 |
+
"step": 510
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"epoch": 1.863799283154122,
|
| 383 |
+
"grad_norm": 0.028018169105052948,
|
| 384 |
+
"learning_rate": 4.7181008902077156e-05,
|
| 385 |
+
"loss": 0.02937682271003723,
|
| 386 |
+
"step": 520
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"epoch": 1.8996415770609318,
|
| 390 |
+
"grad_norm": 0.018673259764909744,
|
| 391 |
+
"learning_rate": 4.56973293768546e-05,
|
| 392 |
+
"loss": 0.04715012311935425,
|
| 393 |
+
"step": 530
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"epoch": 1.935483870967742,
|
| 397 |
+
"grad_norm": 3.4265527725219727,
|
| 398 |
+
"learning_rate": 4.421364985163205e-05,
|
| 399 |
+
"loss": 0.00353032723069191,
|
| 400 |
+
"step": 540
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"epoch": 1.971326164874552,
|
| 404 |
+
"grad_norm": 0.08214850723743439,
|
| 405 |
+
"learning_rate": 4.2729970326409497e-05,
|
| 406 |
+
"loss": 0.03979413509368897,
|
| 407 |
+
"step": 550
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"epoch": 2.0,
|
| 411 |
+
"eval_accuracy": 0.9856502242152466,
|
| 412 |
+
"eval_f1": 0.9477124183006536,
|
| 413 |
+
"eval_loss": 0.06383560597896576,
|
| 414 |
+
"eval_precision": 0.9235668789808917,
|
| 415 |
+
"eval_recall": 0.9731543624161074,
|
| 416 |
+
"eval_runtime": 40.1147,
|
| 417 |
+
"eval_samples_per_second": 27.795,
|
| 418 |
+
"eval_steps_per_second": 0.449,
|
| 419 |
+
"step": 558
|
| 420 |
+
},
|
| 421 |
+
{
|
| 422 |
+
"epoch": 2.007168458781362,
|
| 423 |
+
"grad_norm": 0.01592390425503254,
|
| 424 |
+
"learning_rate": 4.124629080118694e-05,
|
| 425 |
+
"loss": 0.010815779119729996,
|
| 426 |
+
"step": 560
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"epoch": 2.043010752688172,
|
| 430 |
+
"grad_norm": 0.011524681001901627,
|
| 431 |
+
"learning_rate": 3.976261127596439e-05,
|
| 432 |
+
"loss": 0.0023574797436594964,
|
| 433 |
+
"step": 570
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"epoch": 2.078853046594982,
|
| 437 |
+
"grad_norm": 0.009996078908443451,
|
| 438 |
+
"learning_rate": 3.8278931750741844e-05,
|
| 439 |
+
"loss": 0.0014362875372171402,
|
| 440 |
+
"step": 580
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"epoch": 2.1146953405017923,
|
| 444 |
+
"grad_norm": 0.01068540196865797,
|
| 445 |
+
"learning_rate": 3.679525222551929e-05,
|
| 446 |
+
"loss": 0.0005183162167668343,
|
| 447 |
+
"step": 590
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"epoch": 2.150537634408602,
|
| 451 |
+
"grad_norm": 0.007895471528172493,
|
| 452 |
+
"learning_rate": 3.531157270029673e-05,
|
| 453 |
+
"loss": 0.000452942680567503,
|
| 454 |
+
"step": 600
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"epoch": 2.186379928315412,
|
| 458 |
+
"grad_norm": 0.011167576536536217,
|
| 459 |
+
"learning_rate": 3.382789317507419e-05,
|
| 460 |
+
"loss": 0.04565889239311218,
|
| 461 |
+
"step": 610
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"epoch": 2.2222222222222223,
|
| 465 |
+
"grad_norm": 0.021645231172442436,
|
| 466 |
+
"learning_rate": 3.2344213649851636e-05,
|
| 467 |
+
"loss": 0.04359175860881805,
|
| 468 |
+
"step": 620
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"epoch": 2.258064516129032,
|
| 472 |
+
"grad_norm": 0.020626947283744812,
|
| 473 |
+
"learning_rate": 3.086053412462908e-05,
|
| 474 |
+
"loss": 0.0012028517201542854,
|
| 475 |
+
"step": 630
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"epoch": 2.293906810035842,
|
| 479 |
+
"grad_norm": 2.1626150608062744,
|
| 480 |
+
"learning_rate": 2.937685459940653e-05,
|
| 481 |
+
"loss": 0.04995043277740478,
|
| 482 |
+
"step": 640
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"epoch": 2.3297491039426523,
|
| 486 |
+
"grad_norm": 0.024094557389616966,
|
| 487 |
+
"learning_rate": 2.789317507418398e-05,
|
| 488 |
+
"loss": 0.0011813377030193805,
|
| 489 |
+
"step": 650
|
| 490 |
+
},
|
| 491 |
+
{
|
| 492 |
+
"epoch": 2.3655913978494625,
|
| 493 |
+
"grad_norm": 0.024276426061987877,
|
| 494 |
+
"learning_rate": 2.6409495548961428e-05,
|
| 495 |
+
"loss": 0.000971246138215065,
|
| 496 |
+
"step": 660
|
| 497 |
+
},
|
| 498 |
+
{
|
| 499 |
+
"epoch": 2.4014336917562726,
|
| 500 |
+
"grad_norm": 0.013462797738611698,
|
| 501 |
+
"learning_rate": 2.4925816023738872e-05,
|
| 502 |
+
"loss": 0.001258693542331457,
|
| 503 |
+
"step": 670
|
| 504 |
+
},
|
| 505 |
+
{
|
| 506 |
+
"epoch": 2.4372759856630823,
|
| 507 |
+
"grad_norm": 0.011892343871295452,
|
| 508 |
+
"learning_rate": 2.344213649851632e-05,
|
| 509 |
+
"loss": 0.000927103403955698,
|
| 510 |
+
"step": 680
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"epoch": 2.4731182795698925,
|
| 514 |
+
"grad_norm": 0.020849108695983887,
|
| 515 |
+
"learning_rate": 2.195845697329377e-05,
|
| 516 |
+
"loss": 0.048611697554588315,
|
| 517 |
+
"step": 690
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"epoch": 2.5089605734767026,
|
| 521 |
+
"grad_norm": 0.027146149426698685,
|
| 522 |
+
"learning_rate": 2.0474777448071216e-05,
|
| 523 |
+
"loss": 0.0010239629074931145,
|
| 524 |
+
"step": 700
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"epoch": 2.5448028673835124,
|
| 528 |
+
"grad_norm": 0.02077840454876423,
|
| 529 |
+
"learning_rate": 1.8991097922848668e-05,
|
| 530 |
+
"loss": 0.011852725595235824,
|
| 531 |
+
"step": 710
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"epoch": 2.5806451612903225,
|
| 535 |
+
"grad_norm": 0.013969900086522102,
|
| 536 |
+
"learning_rate": 1.7507418397626112e-05,
|
| 537 |
+
"loss": 0.0008809153921902179,
|
| 538 |
+
"step": 720
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"epoch": 2.6164874551971327,
|
| 542 |
+
"grad_norm": 0.01932111196219921,
|
| 543 |
+
"learning_rate": 1.6023738872403564e-05,
|
| 544 |
+
"loss": 0.0008839547634124755,
|
| 545 |
+
"step": 730
|
| 546 |
+
},
|
| 547 |
+
{
|
| 548 |
+
"epoch": 2.652329749103943,
|
| 549 |
+
"grad_norm": 0.011346139945089817,
|
| 550 |
+
"learning_rate": 1.454005934718101e-05,
|
| 551 |
+
"loss": 0.0007063972763717175,
|
| 552 |
+
"step": 740
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"epoch": 2.688172043010753,
|
| 556 |
+
"grad_norm": 0.02365295961499214,
|
| 557 |
+
"learning_rate": 1.3056379821958458e-05,
|
| 558 |
+
"loss": 0.03284582793712616,
|
| 559 |
+
"step": 750
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"epoch": 2.7240143369175627,
|
| 563 |
+
"grad_norm": 0.019500477239489555,
|
| 564 |
+
"learning_rate": 1.1572700296735906e-05,
|
| 565 |
+
"loss": 0.009724142402410508,
|
| 566 |
+
"step": 760
|
| 567 |
+
},
|
| 568 |
+
{
|
| 569 |
+
"epoch": 2.759856630824373,
|
| 570 |
+
"grad_norm": 0.013806493952870369,
|
| 571 |
+
"learning_rate": 1.0089020771513354e-05,
|
| 572 |
+
"loss": 0.000749332644045353,
|
| 573 |
+
"step": 770
|
| 574 |
+
},
|
| 575 |
+
{
|
| 576 |
+
"epoch": 2.795698924731183,
|
| 577 |
+
"grad_norm": 0.04070394113659859,
|
| 578 |
+
"learning_rate": 8.605341246290802e-06,
|
| 579 |
+
"loss": 0.0005624637007713318,
|
| 580 |
+
"step": 780
|
| 581 |
+
},
|
| 582 |
+
{
|
| 583 |
+
"epoch": 2.8315412186379927,
|
| 584 |
+
"grad_norm": 0.014284063130617142,
|
| 585 |
+
"learning_rate": 7.12166172106825e-06,
|
| 586 |
+
"loss": 0.0005397583357989788,
|
| 587 |
+
"step": 790
|
| 588 |
+
},
|
| 589 |
+
{
|
| 590 |
+
"epoch": 2.867383512544803,
|
| 591 |
+
"grad_norm": 0.01228868868201971,
|
| 592 |
+
"learning_rate": 5.637982195845697e-06,
|
| 593 |
+
"loss": 0.0026672758162021638,
|
| 594 |
+
"step": 800
|
| 595 |
+
},
|
| 596 |
+
{
|
| 597 |
+
"epoch": 2.903225806451613,
|
| 598 |
+
"grad_norm": 0.007240036968141794,
|
| 599 |
+
"learning_rate": 4.154302670623145e-06,
|
| 600 |
+
"loss": 0.0004312596283853054,
|
| 601 |
+
"step": 810
|
| 602 |
+
},
|
| 603 |
+
{
|
| 604 |
+
"epoch": 2.9390681003584227,
|
| 605 |
+
"grad_norm": 0.010902538895606995,
|
| 606 |
+
"learning_rate": 2.6706231454005935e-06,
|
| 607 |
+
"loss": 0.0004907793365418911,
|
| 608 |
+
"step": 820
|
| 609 |
+
},
|
| 610 |
+
{
|
| 611 |
+
"epoch": 2.974910394265233,
|
| 612 |
+
"grad_norm": 0.011084447614848614,
|
| 613 |
+
"learning_rate": 1.1869436201780417e-06,
|
| 614 |
+
"loss": 0.01800166815519333,
|
| 615 |
+
"step": 830
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
"epoch": 3.0,
|
| 619 |
+
"eval_accuracy": 0.9928251121076234,
|
| 620 |
+
"eval_f1": 0.972972972972973,
|
| 621 |
+
"eval_loss": 0.04282991588115692,
|
| 622 |
+
"eval_precision": 0.9795918367346939,
|
| 623 |
+
"eval_recall": 0.9664429530201343,
|
| 624 |
+
"eval_runtime": 42.4854,
|
| 625 |
+
"eval_samples_per_second": 26.244,
|
| 626 |
+
"eval_steps_per_second": 0.424,
|
| 627 |
+
"step": 837
|
| 628 |
+
}
|
| 629 |
+
],
|
| 630 |
+
"logging_steps": 10,
|
| 631 |
+
"max_steps": 837,
|
| 632 |
+
"num_input_tokens_seen": 0,
|
| 633 |
+
"num_train_epochs": 3,
|
| 634 |
+
"save_steps": 500,
|
| 635 |
+
"stateful_callbacks": {
|
| 636 |
+
"TrainerControl": {
|
| 637 |
+
"args": {
|
| 638 |
+
"should_epoch_stop": false,
|
| 639 |
+
"should_evaluate": false,
|
| 640 |
+
"should_log": false,
|
| 641 |
+
"should_save": true,
|
| 642 |
+
"should_training_stop": true
|
| 643 |
+
},
|
| 644 |
+
"attributes": {}
|
| 645 |
+
}
|
| 646 |
+
},
|
| 647 |
+
"total_flos": 442805396857344.0,
|
| 648 |
+
"train_batch_size": 16,
|
| 649 |
+
"trial_name": null,
|
| 650 |
+
"trial_params": null
|
| 651 |
+
}
|
Deep_Learning_Project/results/checkpoint-837/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fa1c04a8eee5e3e20cd86241a1b9ef2f6932b1f87ed49b5c7e1a4b6bf1a7ad0
|
| 3 |
+
size 5201
|
Deep_Learning_Project/save_tokenizer.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import DistilBertTokenizer
|
| 2 |
+
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
| 3 |
+
tokenizer.save_pretrained('saved_model')
|
| 4 |
+
|
| 5 |
+
print("Tokenizer saved to saved_model")
|
Deep_Learning_Project/saved_model/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"problem_type": "single_label_classification",
|
| 20 |
+
"qa_dropout": 0.1,
|
| 21 |
+
"seq_classif_dropout": 0.2,
|
| 22 |
+
"sinusoidal_pos_embds": false,
|
| 23 |
+
"tie_weights_": true,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.3.0",
|
| 26 |
+
"use_cache": false,
|
| 27 |
+
"vocab_size": 30522
|
| 28 |
+
}
|
Deep_Learning_Project/saved_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c07d77d4bad466798bc928c36a02d91cadbef0c2047def7e445928f1825d959
|
| 3 |
+
size 267832560
|
Deep_Learning_Project/saved_model/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b616fb2661c2ad4dbc6c0ee5c4c1caefd532989e8c399005744f5442e4116ed
|
| 3 |
+
size 535724875
|
Deep_Learning_Project/saved_model/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c0f6bef14c378cf6dadcce9566f94c89a1de5a87f6dab300dc2cfb3620974e5
|
| 3 |
+
size 14455
|
Deep_Learning_Project/saved_model/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0fd612b91f9c87518d4d69e324f23842dc16c0f0824a97f416850db8d65d6175
|
| 3 |
+
size 1465
|
Deep_Learning_Project/saved_model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Deep_Learning_Project/saved_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"do_lower_case": true,
|
| 5 |
+
"is_local": false,
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"model_max_length": 512,
|
| 8 |
+
"pad_token": "[PAD]",
|
| 9 |
+
"sep_token": "[SEP]",
|
| 10 |
+
"strip_accents": null,
|
| 11 |
+
"tokenize_chinese_chars": true,
|
| 12 |
+
"tokenizer_class": "DistilBertTokenizer",
|
| 13 |
+
"unk_token": "[UNK]"
|
| 14 |
+
}
|
Deep_Learning_Project/saved_model/trainer_state.json
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 279,
|
| 3 |
+
"best_metric": 0.04009857773780823,
|
| 4 |
+
"best_model_checkpoint": "./results/checkpoint-279",
|
| 5 |
+
"epoch": 1.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 279,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.035842293906810034,
|
| 14 |
+
"grad_norm": 4.057015895843506,
|
| 15 |
+
"learning_rate": 9e-07,
|
| 16 |
+
"loss": 0.6395920753479004,
|
| 17 |
+
"step": 10
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.07168458781362007,
|
| 21 |
+
"grad_norm": 3.033071279525757,
|
| 22 |
+
"learning_rate": 1.9e-06,
|
| 23 |
+
"loss": 0.6280024528503418,
|
| 24 |
+
"step": 20
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.10752688172043011,
|
| 28 |
+
"grad_norm": 3.566011905670166,
|
| 29 |
+
"learning_rate": 2.9e-06,
|
| 30 |
+
"loss": 0.6011054992675782,
|
| 31 |
+
"step": 30
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.14336917562724014,
|
| 35 |
+
"grad_norm": 1.8344322443008423,
|
| 36 |
+
"learning_rate": 3.9e-06,
|
| 37 |
+
"loss": 0.5574448108673096,
|
| 38 |
+
"step": 40
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.17921146953405018,
|
| 42 |
+
"grad_norm": 1.37172269821167,
|
| 43 |
+
"learning_rate": 4.9000000000000005e-06,
|
| 44 |
+
"loss": 0.5156032562255859,
|
| 45 |
+
"step": 50
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.21505376344086022,
|
| 49 |
+
"grad_norm": 1.646365761756897,
|
| 50 |
+
"learning_rate": 5.9e-06,
|
| 51 |
+
"loss": 0.40119266510009766,
|
| 52 |
+
"step": 60
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.25089605734767023,
|
| 56 |
+
"grad_norm": 1.6370402574539185,
|
| 57 |
+
"learning_rate": 6.900000000000001e-06,
|
| 58 |
+
"loss": 0.2945317268371582,
|
| 59 |
+
"step": 70
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.2867383512544803,
|
| 63 |
+
"grad_norm": 1.153666377067566,
|
| 64 |
+
"learning_rate": 7.9e-06,
|
| 65 |
+
"loss": 0.2097024440765381,
|
| 66 |
+
"step": 80
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.3225806451612903,
|
| 70 |
+
"grad_norm": 4.979272365570068,
|
| 71 |
+
"learning_rate": 8.9e-06,
|
| 72 |
+
"loss": 0.1338476538658142,
|
| 73 |
+
"step": 90
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.35842293906810035,
|
| 77 |
+
"grad_norm": 1.4230221509933472,
|
| 78 |
+
"learning_rate": 9.900000000000002e-06,
|
| 79 |
+
"loss": 0.07898266315460205,
|
| 80 |
+
"step": 100
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.3942652329749104,
|
| 84 |
+
"grad_norm": 0.4248649477958679,
|
| 85 |
+
"learning_rate": 1.09e-05,
|
| 86 |
+
"loss": 0.09649826884269715,
|
| 87 |
+
"step": 110
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.43010752688172044,
|
| 91 |
+
"grad_norm": 5.597842216491699,
|
| 92 |
+
"learning_rate": 1.19e-05,
|
| 93 |
+
"loss": 0.10944361686706543,
|
| 94 |
+
"step": 120
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 0.4659498207885305,
|
| 98 |
+
"grad_norm": 0.2298722118139267,
|
| 99 |
+
"learning_rate": 1.29e-05,
|
| 100 |
+
"loss": 0.02345200479030609,
|
| 101 |
+
"step": 130
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 0.5017921146953405,
|
| 105 |
+
"grad_norm": 0.16554060578346252,
|
| 106 |
+
"learning_rate": 1.3900000000000002e-05,
|
| 107 |
+
"loss": 0.10124502182006836,
|
| 108 |
+
"step": 140
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 0.5376344086021505,
|
| 112 |
+
"grad_norm": 0.20511625707149506,
|
| 113 |
+
"learning_rate": 1.49e-05,
|
| 114 |
+
"loss": 0.03878947794437408,
|
| 115 |
+
"step": 150
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 0.5734767025089605,
|
| 119 |
+
"grad_norm": 0.22746935486793518,
|
| 120 |
+
"learning_rate": 1.59e-05,
|
| 121 |
+
"loss": 0.015351028740406036,
|
| 122 |
+
"step": 160
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 0.6093189964157706,
|
| 126 |
+
"grad_norm": 0.09672143310308456,
|
| 127 |
+
"learning_rate": 1.69e-05,
|
| 128 |
+
"loss": 0.08337704539299011,
|
| 129 |
+
"step": 170
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 0.6451612903225806,
|
| 133 |
+
"grad_norm": 0.09403681755065918,
|
| 134 |
+
"learning_rate": 1.79e-05,
|
| 135 |
+
"loss": 0.04203044176101685,
|
| 136 |
+
"step": 180
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 0.6810035842293907,
|
| 140 |
+
"grad_norm": 0.10677973181009293,
|
| 141 |
+
"learning_rate": 1.8900000000000002e-05,
|
| 142 |
+
"loss": 0.0367545485496521,
|
| 143 |
+
"step": 190
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 0.7168458781362007,
|
| 147 |
+
"grad_norm": 0.09498517960309982,
|
| 148 |
+
"learning_rate": 1.9900000000000003e-05,
|
| 149 |
+
"loss": 0.006626572459936142,
|
| 150 |
+
"step": 200
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 0.7526881720430108,
|
| 154 |
+
"grad_norm": 0.058950286358594894,
|
| 155 |
+
"learning_rate": 2.09e-05,
|
| 156 |
+
"loss": 0.004921931400895118,
|
| 157 |
+
"step": 210
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 0.7885304659498208,
|
| 161 |
+
"grad_norm": 0.05829274281859398,
|
| 162 |
+
"learning_rate": 2.19e-05,
|
| 163 |
+
"loss": 0.06366445422172547,
|
| 164 |
+
"step": 220
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 0.8243727598566308,
|
| 168 |
+
"grad_norm": 0.0513160303235054,
|
| 169 |
+
"learning_rate": 2.29e-05,
|
| 170 |
+
"loss": 0.0034723080694675445,
|
| 171 |
+
"step": 230
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 0.8602150537634409,
|
| 175 |
+
"grad_norm": 8.491065979003906,
|
| 176 |
+
"learning_rate": 2.39e-05,
|
| 177 |
+
"loss": 0.06077749729156494,
|
| 178 |
+
"step": 240
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 0.8960573476702509,
|
| 182 |
+
"grad_norm": 0.05915055796504021,
|
| 183 |
+
"learning_rate": 2.4900000000000002e-05,
|
| 184 |
+
"loss": 0.009128168970346451,
|
| 185 |
+
"step": 250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 0.931899641577061,
|
| 189 |
+
"grad_norm": 10.863852500915527,
|
| 190 |
+
"learning_rate": 2.5900000000000003e-05,
|
| 191 |
+
"loss": 0.09072564840316773,
|
| 192 |
+
"step": 260
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"epoch": 0.967741935483871,
|
| 196 |
+
"grad_norm": 0.42785295844078064,
|
| 197 |
+
"learning_rate": 2.6900000000000003e-05,
|
| 198 |
+
"loss": 0.03498620092868805,
|
| 199 |
+
"step": 270
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 1.0,
|
| 203 |
+
"eval_accuracy": 0.9910313901345291,
|
| 204 |
+
"eval_f1": 0.9657534246575342,
|
| 205 |
+
"eval_loss": 0.04009857773780823,
|
| 206 |
+
"eval_precision": 0.986013986013986,
|
| 207 |
+
"eval_recall": 0.9463087248322147,
|
| 208 |
+
"eval_runtime": 32.5434,
|
| 209 |
+
"eval_samples_per_second": 34.262,
|
| 210 |
+
"eval_steps_per_second": 0.553,
|
| 211 |
+
"step": 279
|
| 212 |
+
}
|
| 213 |
+
],
|
| 214 |
+
"logging_steps": 10,
|
| 215 |
+
"max_steps": 837,
|
| 216 |
+
"num_input_tokens_seen": 0,
|
| 217 |
+
"num_train_epochs": 3,
|
| 218 |
+
"save_steps": 500,
|
| 219 |
+
"stateful_callbacks": {
|
| 220 |
+
"TrainerControl": {
|
| 221 |
+
"args": {
|
| 222 |
+
"should_epoch_stop": false,
|
| 223 |
+
"should_evaluate": false,
|
| 224 |
+
"should_log": false,
|
| 225 |
+
"should_save": true,
|
| 226 |
+
"should_training_stop": false
|
| 227 |
+
},
|
| 228 |
+
"attributes": {}
|
| 229 |
+
}
|
| 230 |
+
},
|
| 231 |
+
"total_flos": 147601798952448.0,
|
| 232 |
+
"train_batch_size": 16,
|
| 233 |
+
"trial_name": null,
|
| 234 |
+
"trial_params": null
|
| 235 |
+
}
|
Deep_Learning_Project/saved_model/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab8b54e46b16083bb42102640d659199bfe96644c6fe4869e1135f8028f75f5c
|
| 3 |
+
size 5201
|
Deep_Learning_Project/train_model.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, get_linear_schedule_with_warmup
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
|
| 8 |
+
import numpy as np
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# 1. Load and Preprocess Data
# The CSV's own first row is skipped and fixed column names are supplied instead.
df = pd.read_csv(
    'mail_data.csv',
    names=['Category', 'Message'],
    header=None,
    skiprows=1,
)
# Encode the class names as integers: ham -> 0, spam -> 1.
df['label'] = df['Category'].map({'ham': 0, 'spam': 1})

# 80/20 split with a fixed seed, stratified on the label column.
all_texts = df['Message'].values
all_labels = df['label'].values
train_texts, test_texts, train_labels, test_labels = train_test_split(
    all_texts,
    all_labels,
    test_size=0.2,
    random_state=42,
    stratify=all_labels,
)
|
| 19 |
+
|
| 20 |
+
# 2. Dataset Class
|
| 21 |
+
class EmailDataset(Dataset):
    """Dataset that tokenizes one message each time an item is requested.

    Args:
        texts: Indexable collection of raw messages.
        labels: Indexable collection of integer class ids aligned with
            ``texts``.
        tokenizer: Callable tokenizer; it is invoked with the text plus the
            keyword arguments shown in ``__getitem__`` and must return a
            mapping holding ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_len: Length every message is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of messages."""
        return len(self.texts)

    def __getitem__(self, item):
        """Return the text, encoded tensors and label for position ``item``."""
        text = str(self.texts[item])
        label = self.labels[item]
        # Call the tokenizer through its public interface. The private
        # ``_encode_plus`` hook is an implementation detail and is not
        # guaranteed to accept ``padding=`` / ``truncation=`` keywords.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            # flatten() collapses each encoded tensor to one dimension.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }
|
| 50 |
+
|
| 51 |
+
# 3. Setup Training
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

PRE_TRAINED_MODEL_NAME = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

# Only the training batches are shuffled; the test order stays fixed.
train_data_loader = DataLoader(
    EmailDataset(train_texts, train_labels, tokenizer),
    batch_size=16,
    shuffle=True,
)
test_data_loader = DataLoader(
    EmailDataset(test_texts, test_labels, tokenizer),
    batch_size=16,
    shuffle=False,
)

# Two output classes (ham / spam), placed on the selected device.
model = DistilBertForSequenceClassification.from_pretrained(
    PRE_TRAINED_MODEL_NAME,
    num_labels=2,
).to(device)

EPOCHS = 3
optimizer = AdamW(model.parameters(), lr=2e-5)
# The schedule spans (batches per epoch) * (number of epochs) steps.
total_steps = EPOCHS * len(train_data_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)
loss_fn = torch.nn.CrossEntropyLoss().to(device)
|
| 69 |
+
|
| 70 |
+
# 4. Training Loop
|
| 71 |
+
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        data_loader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        loss_fn: Not used here, because the loss is taken from the model
            output. The parameter is kept so existing call sites keep working.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scheduler: Learning-rate scheduler stepped once per batch.
        n_examples: Denominator used for the accuracy.

    Returns:
        ``(accuracy, mean_loss)``: accuracy is a float64 tensor, mean_loss is
        the mean of the per-batch losses (NaN when the loader is empty).
    """
    model = model.train()
    losses = []
    # Start from a tensor rather than the int 0 so that ``.double()`` below
    # also works when the loader yields no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        _, preds = torch.max(outputs.logits, dim=1)
        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())

        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Avoid np.mean([]), which warns and returns NaN anyway.
    mean_loss = np.mean(losses) if losses else float("nan")
    return correct_predictions.double() / n_examples, mean_loss
|
| 91 |
+
|
| 92 |
+
def eval_model(model, data_loader, loss_fn, device, n_examples):
    """Score ``model`` on ``data_loader`` without updating its parameters.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        data_loader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        loss_fn: Not used here, because the loss is taken from the model
            output. The parameter is kept so existing call sites keep working.
        device: Device the batch tensors are moved to.
        n_examples: Denominator used for the accuracy.

    Returns:
        ``(accuracy, mean_loss)``: accuracy is a float64 tensor, mean_loss is
        the mean of the per-batch losses (NaN when the loader is empty).
    """
    model = model.eval()
    losses = []
    # Start from a tensor rather than the int 0 so that ``.double()`` below
    # also works when the loader yields no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            labels = d["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            _, preds = torch.max(outputs.logits, dim=1)
            correct_predictions += torch.sum(preds == labels)
            losses.append(outputs.loss.item())

    # Avoid np.mean([]), which warns and returns NaN anyway.
    mean_loss = np.mean(losses) if losses else float("nan")
    return correct_predictions.double() / n_examples, mean_loss
|
| 108 |
+
|
| 109 |
+
print("Starting training...")
# The split sizes are the accuracy denominators and never change.
n_train = len(train_texts)
n_test = len(test_texts)
for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')

    train_acc, train_loss = train_epoch(
        model=model,
        data_loader=train_data_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
        scheduler=scheduler,
        n_examples=n_train,
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')

    # NOTE(review): the "Val" figures are computed on the held-out test split;
    # there is no separate validation set in this script.
    val_acc, val_loss = eval_model(
        model=model,
        data_loader=test_data_loader,
        loss_fn=loss_fn,
        device=device,
        n_examples=n_test,
    )
    print(f'Val loss {val_loss} accuracy {val_acc}')
|
| 116 |
+
|
| 117 |
+
# 5. Final Evaluation
|
| 118 |
+
def get_predictions(model, data_loader, device=None):
    """Collect texts, predicted classes and true labels for every batch.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        data_loader: Iterable of batches; each batch is a mapping with
            ``"text"``, ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        device: Device the inputs are moved to. When omitted, the device of
            the model's first parameter is used, so the function no longer
            depends on a module-level ``device`` variable.

    Returns:
        ``(messages, predictions, real_values)``: a list of texts plus two
        one-dimensional CPU tensors. Both tensors are empty when the loader
        yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    model = model.eval()
    messages = []
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs.logits, dim=1)

            messages.extend(d["text"])
            # Move results to the CPU per batch; the logits themselves are
            # not retained because nothing downstream reads them.
            pred_batches.append(preds.cpu())
            label_batches.append(d["labels"].cpu())

    if not pred_batches:
        # torch.cat / torch.stack reject an empty list, so return empties.
        return messages, torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)

    predictions = torch.cat(pred_batches)
    real_values = torch.cat(label_batches)
    return messages, predictions, real_values
|
| 140 |
+
|
| 141 |
+
y_review_texts, y_pred, y_test = get_predictions(model, test_data_loader)

# Build the per-class report once and reuse it for the console and the file.
class_names = ['ham', 'spam']
report_text = classification_report(y_test, y_pred, target_names=class_names)
print("\nClassification Report:\n", report_text)

# Save results for report
with open('results.txt', 'w') as f:
    sections = (
        f"Accuracy: {accuracy_score(y_test, y_pred)}\n",
        "\nClassification Report:\n",
        report_text,
        "\nConfusion Matrix:\n",
        str(confusion_matrix(y_test, y_pred)),
    )
    f.writelines(sections)

print("Training complete. Results saved to results.txt")
|
Deep_Learning_Project/train_model_hf.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
|
| 5 |
+
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
# 1. Load and Preprocess Data
# The CSV's own first row is skipped and fixed column names are supplied instead.
df = pd.read_csv(
    'mail_data.csv',
    names=['Category', 'Message'],
    header=None,
    skiprows=1,
)
# Encode the class names as integers: ham -> 0, spam -> 1.
df['label'] = df['Category'].map({'ham': 0, 'spam': 1})

# 80/20 split with a fixed seed, stratified on the label column.
label_array = df['label'].values
train_texts, test_texts, train_labels, test_labels = train_test_split(
    df['Message'].values.tolist(),
    label_array.tolist(),
    test_size=0.2,
    random_state=42,
    stratify=label_array,
)

# 2. Tokenization
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Both splits are encoded with the same truncation, padding and length cap.
encode_options = {'truncation': True, 'padding': True, 'max_length': 128}
train_encodings = tokenizer(train_texts, **encode_options)
test_encodings = tokenizer(test_texts, **encode_options)
|
| 21 |
+
|
| 22 |
+
class EmailDataset(torch.utils.data.Dataset):
    """Dataset over pre-tokenized encodings and their labels.

    Args:
        encodings: Mapping from field name to an indexable collection holding
            one entry per example.
        labels: Indexable collection of labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return every encoding field plus ``'labels'`` for ``idx`` as tensors."""
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
|
| 34 |
+
|
| 35 |
+
# Pair each split's encodings with its labels.
train_dataset = EmailDataset(encodings=train_encodings, labels=train_labels)
test_dataset = EmailDataset(encodings=test_encodings, labels=test_labels)

# 3. Model and Metrics
# Classification head with two outputs, matching the 0/1 label encoding.
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=2,
)
|
| 40 |
+
|
| 41 |
+
def compute_metrics(pred):
    """Compute accuracy, F1, precision and recall for one set of predictions.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the argmax over the last axis is taken as the
            predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    y_true = pred.label_ids
    y_hat = pred.predictions.argmax(-1)
    # The fourth returned value (support) is not needed here.
    precision, recall, f1 = precision_recall_fscore_support(
        y_true, y_hat, average='binary'
    )[:3]
    return dict(
        accuracy=accuracy_score(y_true, y_hat),
        f1=f1,
        precision=precision,
        recall=recall,
    )
|
| 52 |
+
|
| 53 |
+
# 4. Training Arguments
# Evaluation and checkpointing both happen once per epoch, and the best
# checkpoint is reloaded when training finishes.
training_config = {
    'output_dir': './results',
    'num_train_epochs': 3,
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 64,
    'warmup_steps': 500,
    'weight_decay': 0.01,
    'logging_dir': './logs',
    'logging_steps': 10,
    'eval_strategy': "epoch",
    'save_strategy': "epoch",
    'load_best_model_at_end': True,
}
training_args = TrainingArguments(**training_config)

# 5. Trainer
# NOTE(review): the test split doubles as the evaluation set during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

print("Starting training with HF Trainer...")
trainer.train()
|
| 79 |
+
|
| 80 |
+
# 6. Evaluation
from sklearn.metrics import classification_report

print("Evaluating...")
eval_results = trainer.evaluate()
print(eval_results)

# Final predictions for detailed report
predictions = trainer.predict(test_dataset)
labels = predictions.label_ids
preds = predictions.predictions.argmax(-1)

report = classification_report(labels, preds, target_names=['ham', 'spam'])
cm = confusion_matrix(labels, preds)

# Write the three sections in a single call; the file content is unchanged.
with open('results.txt', 'w') as f:
    f.write(
        f"Final Evaluation Results:\n{eval_results}\n"
        f"\nClassification Report:\n{report}\n"
        f"\nConfusion Matrix:\n{cm}\n"
    )

print("Training complete. Results saved to results.txt")
|
Deep_Learning_Project/~$ep Learning Project Report.docx
ADDED
|
Binary file (162 Bytes). View file
|
|
|