jesse-tong committed
Commit da89f1c · 0 Parent(s)

First commit

Files changed (8)
  1. README.md +32 -0
  2. config.py +57 -0
  3. dataset.py +104 -0
  4. model.py +48 -0
  5. requirements.txt +6 -0
  6. run.py +86 -0
  7. train.py +129 -0
  8. trainer.py +211 -0
README.md ADDED
@@ -0,0 +1,32 @@
+ # DocBERT - Improved Document Classification with BERT
+
+ This repository contains an improved implementation of BERT for document classification, combining techniques from [jesse-tong/docbert](https://github.com/jesse-tong/docbert) and [castorini/hedwig](https://github.com/castorini/hedwig).
+
+ ## Key Improvements
+
+ 1. **Advanced Regularization Techniques**:
+    - Dropout in multiple layers
+    - Layer normalization
+    - Gradient clipping
+    - Weight decay optimization
+
+ 2. **Training Stability Enhancements**:
+    - Learning rate scheduling with ReduceLROnPlateau
+    - Gradient accumulation for larger effective batch sizes
+    - Label smoothing to improve generalization
+    - Best-model checkpointing based on validation F1 score
+
+ 3. **Architectural Changes**:
+    - Layer-normalized [CLS] pooling before classification
+    - More robust tokenization with attention masks
+    - Configurable hyperparameter presets for different document types
+
+ ## Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/yourusername/docbert-improved.git
+ cd docbert-improved
+
+ # Install dependencies
+ pip install -r requirements.txt
+ ```
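
A quick-start sketch to complement the installation steps; it assumes a hypothetical CSV with `text` and `label` columns and four classes (see `run.py` below for the flag definitions):

```bash
# Hypothetical dataset path and class count -- adjust to your data
python run.py --data_path data/my_dataset.csv --num_classes 4 \
    --config short_text --output_dir ./output
```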
config.py ADDED
@@ -0,0 +1,57 @@
+ """
+ Configuration module for DocBERT
+ Contains hyperparameter presets for different dataset types
+ """
+
+ class BaseConfig:
+     # Model params
+     bert_model = "bert-base-uncased"
+     max_seq_length = 512
+     dropout = 0.1
+
+     # Training params
+     batch_size = 16
+     learning_rate = 2e-5
+     weight_decay = 0.01
+     epochs = 10
+     grad_accum_steps = 1
+
+     # Data params
+     val_split = 0.1
+     test_split = 0.1
+     seed = 42
+
+ class ShortTextConfig(BaseConfig):
+     """Config for short text classification (tweets, comments, etc.)"""
+     max_seq_length = 128
+     batch_size = 32
+     learning_rate = 3e-5
+
+ class LongDocumentConfig(BaseConfig):
+     """Config for long document classification"""
+     bert_model = "bert-large-uncased"
+     max_seq_length = 512
+     batch_size = 8
+     grad_accum_steps = 2
+     weight_decay = 0.02
+
+ class FinetuningConfig(BaseConfig):
+     """Config for fine-tuning on a small dataset"""
+     learning_rate = 1e-5
+     batch_size = 8
+     epochs = 15
+     weight_decay = 0.03
+     dropout = 0.2
+
+ CONFIG_PRESETS = {
+     "default": BaseConfig,
+     "short_text": ShortTextConfig,
+     "long_document": LongDocumentConfig,
+     "fine_tuning": FinetuningConfig
+ }
+
+ def get_config(preset_name="default"):
+     """Get a configuration preset by name"""
+     if preset_name not in CONFIG_PRESETS:
+         raise ValueError(f"Config preset '{preset_name}' not found. Available presets: {list(CONFIG_PRESETS.keys())}")
+     return CONFIG_PRESETS[preset_name]
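
A minimal sketch of how these presets are looked up; `get_config` returns the class itself, which `run.py` then instantiates:

```python
from config import get_config

# Look up a preset class and instantiate it, as run.py does
config = get_config("long_document")()
print(config.bert_model)        # bert-large-uncased
print(config.batch_size)        # 8
print(config.grad_accum_steps)  # 2 -> effective batch size of 16 with accumulation
```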
dataset.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import BertTokenizer
+ import pandas as pd
+ import numpy as np
+
+ class DocumentDataset(Dataset):
+     """
+     Dataset class for document classification
+     with improved preprocessing and batching
+     """
+     def __init__(self, texts, labels, tokenizer_name='bert-base-uncased', max_length=512):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = str(self.texts[idx])
+         label = self.labels[idx]
+
+         # Tokenize the text with attention mask and truncation
+         encoding = self.tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=self.max_length,
+             return_token_type_ids=True,
+             padding='max_length',
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt'
+         )
+
+         return {
+             'input_ids': encoding['input_ids'].flatten(),
+             'attention_mask': encoding['attention_mask'].flatten(),
+             'token_type_ids': encoding['token_type_ids'].flatten(),
+             'label': torch.tensor(label, dtype=torch.long)
+         }
+
+ def load_data(data_path, text_col='text', label_col='label', validation_split=0.1, test_split=0.1, seed=42):
+     """
+     Load data from CSV/TSV and split into train, validation and test sets
+     """
+     # Determine file format based on extension
+     if data_path.endswith('.csv'):
+         df = pd.read_csv(data_path)
+     elif data_path.endswith('.tsv'):
+         df = pd.read_csv(data_path, sep='\t')
+     else:
+         raise ValueError("Unsupported file format. Please provide a CSV or TSV file.")
+
+     # Convert labels to numeric if they aren't already
+     if not np.issubdtype(df[label_col].dtype, np.number):
+         label_map = {label: idx for idx, label in enumerate(df[label_col].unique())}
+         df['label_numeric'] = df[label_col].map(label_map)
+         labels = df['label_numeric'].values
+     else:
+         labels = df[label_col].values
+
+     texts = df[text_col].values
+
+     # Shuffle and split the data
+     np.random.seed(seed)
+     indices = np.random.permutation(len(texts))
+
+     test_size = int(test_split * len(texts))
+     val_size = int(validation_split * len(texts))
+     train_size = len(texts) - test_size - val_size
+
+     train_indices = indices[:train_size]
+     val_indices = indices[train_size:train_size + val_size]
+     test_indices = indices[train_size + val_size:]
+
+     train_texts, train_labels = texts[train_indices], labels[train_indices]
+     val_texts, val_labels = texts[val_indices], labels[val_indices]
+     test_texts, test_labels = texts[test_indices], labels[test_indices]
+
+     return (train_texts, train_labels), (val_texts, val_labels), (test_texts, test_labels)
+
+ def create_data_loaders(train_data, val_data, test_data, tokenizer_name='bert-base-uncased',
+                         max_length=512, batch_size=16):
+     """
+     Create DataLoader objects for training, validation and testing
+     """
+     train_texts, train_labels = train_data
+     val_texts, val_labels = val_data
+     test_texts, test_labels = test_data
+
+     # Create datasets
+     train_dataset = DocumentDataset(train_texts, train_labels, tokenizer_name, max_length)
+     val_dataset = DocumentDataset(val_texts, val_labels, tokenizer_name, max_length)
+     test_dataset = DocumentDataset(test_texts, test_labels, tokenizer_name, max_length)
+
+     # Create data loaders
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size)
+
+     return train_loader, val_loader, test_loader
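
A small sketch of the split arithmetic in `load_data`, assuming a hypothetical corpus of 1,000 rows and the default 0.1/0.1 splits; it reproduces the same permutation-and-slice logic without touching pandas or the tokenizer:

```python
import numpy as np

n = 1000  # hypothetical corpus size
np.random.seed(42)
indices = np.random.permutation(n)

test_size = int(0.1 * n)               # 100
val_size = int(0.1 * n)                # 100
train_size = n - test_size - val_size  # 800

train_idx = indices[:train_size]
val_idx = indices[train_size:train_size + val_size]
test_idx = indices[train_size + val_size:]

# The three slices partition the shuffled indices exactly
assert len(train_idx) + len(val_idx) + len(test_idx) == n
```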
model.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import torch.nn as nn
+ from transformers import BertModel
+
+ class DocBERT(nn.Module):
+     """
+     Document classification using BERT with improved architecture
+     based on Hedwig implementation patterns.
+     """
+     def __init__(self, num_classes, bert_model_name='bert-base-uncased', dropout_prob=0.1):
+         super(DocBERT, self).__init__()
+
+         # Load pre-trained BERT model
+         self.bert = BertModel.from_pretrained(bert_model_name)
+         self.config = self.bert.config
+
+         # Dropout layer for regularization (helps prevent overfitting)
+         self.dropout = nn.Dropout(dropout_prob)
+
+         # Linear classification head (inspired by Hedwig)
+         self.hidden_size = self.config.hidden_size
+         self.classifier = nn.Linear(self.hidden_size, num_classes)
+
+         # Layer normalization before classification (helps stabilize training)
+         self.layer_norm = nn.LayerNorm(self.hidden_size)
+
+     def forward(self, input_ids, attention_mask=None, token_type_ids=None):
+         """
+         Forward pass through the model
+         """
+         # Get BERT outputs
+         outputs = self.bert(input_ids=input_ids,
+                             attention_mask=attention_mask,
+                             token_type_ids=token_type_ids)
+
+         # Pooled [CLS] representation (first token, passed through BERT's pooler)
+         pooled_output = outputs.pooler_output
+
+         # Apply layer normalization
+         normalized_output = self.layer_norm(pooled_output)
+
+         # Apply dropout for regularization
+         dropped_output = self.dropout(normalized_output)
+
+         # Pass through the classifier
+         logits = self.classifier(dropped_output)
+
+         return logits
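
A minimal shape check for the model above, assuming the `bert-base-uncased` weights can be downloaded; the token IDs are random, so only the output shape is meaningful:

```python
import torch
from model import DocBERT

model = DocBERT(num_classes=4)
model.eval()

# Random token IDs: batch of 2 sequences, 16 tokens each
input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)
print(logits.shape)  # torch.Size([2, 4]) -- one logit per class
```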
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ scikit-learn
+ numpy
+ pandas
+ torch
+ transformers
+ datasets
run.py ADDED
@@ -0,0 +1,86 @@
+ """
+ Simple script to run the DocBERT model with predefined config presets
+ """
+ import argparse
+ import logging
+ import os
+ from config import get_config
+ from model import DocBERT
+ from dataset import load_data, create_data_loaders
+ from trainer import Trainer
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run DocBERT with a predefined config")
+
+     parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
+     parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
+     parser.add_argument("--label_column", type=str, default="label", help="Name of the label column")
+     parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
+     parser.add_argument("--config", type=str, default="default",
+                         choices=["default", "short_text", "long_document", "fine_tuning"],
+                         help="Configuration preset to use")
+     parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save outputs")
+
+     args = parser.parse_args()
+
+     # Get config
+     config_class = get_config(args.config)
+     config = config_class()
+
+     logger.info(f"Using '{args.config}' config preset")
+
+     # Create output directory
+     if not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+
+     # Load and prepare data
+     logger.info("Loading data...")
+     train_data, val_data, test_data = load_data(
+         args.data_path,
+         text_col=args.text_column,
+         label_col=args.label_column,
+         validation_split=config.val_split,
+         test_split=config.test_split,
+         seed=config.seed
+     )
+
+     train_loader, val_loader, test_loader = create_data_loaders(
+         train_data,
+         val_data,
+         test_data,
+         tokenizer_name=config.bert_model,
+         max_length=config.max_seq_length,
+         batch_size=config.batch_size
+     )
+
+     # Initialize model
+     logger.info(f"Initializing model with {config.bert_model}...")
+     model = DocBERT(
+         num_classes=args.num_classes,
+         bert_model_name=config.bert_model,
+         dropout_prob=config.dropout
+     )
+
+     # Initialize trainer
+     trainer = Trainer(
+         model=model,
+         train_loader=train_loader,
+         val_loader=val_loader,
+         test_loader=test_loader,
+         lr=config.learning_rate,
+         weight_decay=config.weight_decay,
+         gradient_accumulation_steps=config.grad_accum_steps
+     )
+
+     # Train model
+     logger.info("Starting training...")
+     save_path = os.path.join(args.output_dir, "best_model.pth")
+     trainer.train(epochs=config.epochs, save_path=save_path)
+
+     logger.info("Training completed!")
+
+ if __name__ == "__main__":
+     main()
train.py ADDED
@@ -0,0 +1,129 @@
+ import argparse
+ import os
+ import logging
+ import torch
+ import random
+ import numpy as np
+ from model import DocBERT
+ from dataset import load_data, create_data_loaders
+ from trainer import Trainer
+
+ # Setup logging
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     level=logging.INFO,
+     datefmt="%Y-%m-%d %H:%M:%S",
+ )
+ logger = logging.getLogger(__name__)
+
+ def set_seed(seed):
+     """Set all seeds for reproducibility"""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train a document classification model with BERT")
+
+     # Data arguments
+     parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
+     parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
+     parser.add_argument("--label_column", type=str, default="label", help="Name of the label column")
+     parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio")
+     parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio")
+
+     # Model arguments
+     parser.add_argument("--bert_model", type=str, default="bert-base-uncased",
+                         help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
+     parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
+     parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
+     parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")
+
+     # Training arguments
+     parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
+     parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
+     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization")
+     parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
+     parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps")
+     parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training for LR warmup")
+
+     # Other arguments
+     parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+     parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs")
+
+     args = parser.parse_args()
+
+     # Set seed for reproducibility
+     set_seed(args.seed)
+
+     # Create output directory if it doesn't exist
+     if not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+
+     # Log args for debugging
+     logger.info(f"Running with arguments: {args}")
+
+     # Load and prepare data
+     logger.info("Loading and preparing data...")
+     train_data, val_data, test_data = load_data(
+         args.data_path,
+         text_col=args.text_column,
+         label_col=args.label_column,
+         validation_split=args.val_split,
+         test_split=args.test_split,
+         seed=args.seed
+     )
+
+     # Create data loaders
+     train_loader, val_loader, test_loader = create_data_loaders(
+         train_data,
+         val_data,
+         test_data,
+         tokenizer_name=args.bert_model,
+         max_length=args.max_length,
+         batch_size=args.batch_size
+     )
+
+     logger.info(f"Train samples: {len(train_data[0])}, "
+                 f"Validation samples: {len(val_data[0])}, "
+                 f"Test samples: {len(test_data[0])}")
+
+     # Initialize model
+     logger.info(f"Initializing DocBERT model with {args.bert_model}...")
+     model = DocBERT(
+         num_classes=args.num_classes,
+         bert_model_name=args.bert_model,
+         dropout_prob=args.dropout
+     )
+
+     # Count and log model parameters
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     logger.info(f"Total parameters: {total_params:,}")
+     logger.info(f"Trainable parameters: {trainable_params:,}")
+
+     # Initialize trainer
+     trainer = Trainer(
+         model=model,
+         train_loader=train_loader,
+         val_loader=val_loader,
+         test_loader=test_loader,
+         lr=args.learning_rate,
+         weight_decay=args.weight_decay,
+         warmup_proportion=args.warmup_proportion,
+         gradient_accumulation_steps=args.grad_accum_steps
+     )
+
+     # Train the model
+     logger.info("Starting training...")
+     save_path = os.path.join(args.output_dir, "best_model.pth")
+     trainer.train(epochs=args.epochs, save_path=save_path)
+
+     logger.info("Training completed!")
+
+ if __name__ == "__main__":
+     main()
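
A hypothetical invocation of `train.py` that overrides individual hyperparameters instead of using a preset; the dataset path and class count are placeholders:

```bash
python train.py --data_path data/my_dataset.csv --num_classes 4 \
    --bert_model bert-base-uncased --max_length 256 \
    --batch_size 16 --learning_rate 2e-5 --epochs 5 \
    --grad_accum_steps 2 --output_dir ./output
```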
trainer.py ADDED
@@ -0,0 +1,211 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
+ import numpy as np
+ import time
+ from tqdm import tqdm
+ import logging
+ import os
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ class Trainer:
+     """
+     Improved trainer class with techniques from the Hedwig implementation
+     for better performance on document classification tasks
+     """
+     def __init__(
+         self,
+         model,
+         train_loader,
+         val_loader,
+         test_loader=None,
+         lr=2e-5,
+         weight_decay=0.01,
+         warmup_proportion=0.1,  # accepted for API compatibility; no warmup is applied with ReduceLROnPlateau
+         gradient_accumulation_steps=1,
+         max_grad_norm=1.0,
+         device=None
+     ):
+         self.model = model
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.test_loader = test_loader
+
+         self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         logger.info(f"Using device: {self.device}")
+
+         self.model.to(self.device)
+
+         # Optimizer updates per epoch (batches are accumulated before each step)
+         self.num_training_steps = len(train_loader) // gradient_accumulation_steps
+
+         # Optimizer with weight decay (L2 regularization);
+         # biases and LayerNorm weights are excluded from weight decay
+         no_decay = ['bias', 'LayerNorm.weight']
+         optimizer_grouped_parameters = [
+             {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+              'weight_decay': weight_decay},
+             {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+              'weight_decay': 0.0}
+         ]
+
+         self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)
+
+         # Learning rate scheduler
+         self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=2)
+
+         # Loss function with label smoothing for better generalization
+         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
+
+         # Training parameters
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+         self.max_grad_norm = max_grad_norm
+
+         # For tracking metrics
+         self.best_val_f1 = 0.0
+         self.best_model_state = None
+
+     def train(self, epochs, save_path='best_model.pth'):
+         """
+         Training loop with improved techniques
+         """
+         logger.info(f"Starting training for {epochs} epochs")
+
+         for epoch in range(epochs):
+             start_time = time.time()
+
+             # Training phase
+             self.model.train()
+             train_loss = 0
+             all_predictions = []
+             all_labels = []
+
+             # Progress bar for training
+             train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
+             for i, batch in enumerate(train_iterator):
+                 # Move batch to device
+                 input_ids = batch['input_ids'].to(self.device)
+                 attention_mask = batch['attention_mask'].to(self.device)
+                 token_type_ids = batch['token_type_ids'].to(self.device)
+                 labels = batch['label'].to(self.device)
+
+                 # Forward pass
+                 outputs = self.model(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     token_type_ids=token_type_ids
+                 )
+
+                 # Calculate loss
+                 loss = self.criterion(outputs, labels)
+
+                 # Scale loss if using gradient accumulation
+                 if self.gradient_accumulation_steps > 1:
+                     loss = loss / self.gradient_accumulation_steps
+
+                 # Backward pass
+                 loss.backward()
+
+                 # Update weights once enough gradients have accumulated
+                 if (i + 1) % self.gradient_accumulation_steps == 0:
+                     # Gradient clipping to prevent exploding gradients
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                     self.optimizer.step()
+                     self.optimizer.zero_grad()
+
+                 train_loss += loss.item() * self.gradient_accumulation_steps
+
+                 # Get predictions for metrics
+                 _, preds = torch.max(outputs, dim=1)
+                 all_predictions.extend(preds.cpu().tolist())
+                 all_labels.extend(labels.cpu().tolist())
+
+                 # Update progress bar with current loss
+                 train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})
+
+             # Calculate training metrics
+             train_loss /= len(self.train_loader)
+             train_acc = accuracy_score(all_labels, all_predictions)
+             train_f1 = f1_score(all_labels, all_predictions, average='macro')
+
+             # Validation phase
+             val_loss, val_acc, val_f1, val_precision, val_recall = self.evaluate(self.val_loader, "Validation")
+
+             # Adjust learning rate based on validation performance
+             self.scheduler.step(val_f1)
+
+             # Save best model
+             if val_f1 > self.best_val_f1:
+                 self.best_val_f1 = val_f1
+                 # Snapshot tensors (state_dict().copy() would only copy the dict,
+                 # leaving the tensor references shared with the live model)
+                 self.best_model_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
+                 torch.save(self.model.state_dict(), save_path)
+                 logger.info(f"New best model saved with validation F1: {val_f1:.4f}")
+
+             # Print epoch summary
+             epoch_time = time.time() - start_time
+             logger.info(f"Epoch {epoch+1}/{epochs} - "
+                         f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, "
+                         f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, "
+                         f"Time: {epoch_time:.2f}s")
+
+         # Load best model for final evaluation
+         if self.best_model_state is not None:
+             self.model.load_state_dict(self.best_model_state)
+             logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")
+
+         # Test evaluation if test loader provided
+         if self.test_loader:
+             test_loss, test_acc, test_f1, test_precision, test_recall = self.evaluate(self.test_loader, "Test")
+             logger.info(f"Final test results - "
+                         f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, "
+                         f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
+
+     def evaluate(self, data_loader, phase="Validation"):
+         """
+         Evaluation function for both validation and test sets
+         """
+         self.model.eval()
+         eval_loss = 0
+         all_predictions = []
+         all_labels = []
+
+         # No gradient computation during evaluation
+         with torch.no_grad():
+             # Progress bar for evaluation
+             iterator = tqdm(data_loader, desc=f"[{phase}]")
+             for batch in iterator:
+                 # Move batch to device
+                 input_ids = batch['input_ids'].to(self.device)
+                 attention_mask = batch['attention_mask'].to(self.device)
+                 token_type_ids = batch['token_type_ids'].to(self.device)
+                 labels = batch['label'].to(self.device)
+
+                 # Forward pass
+                 outputs = self.model(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     token_type_ids=token_type_ids
+                 )
+
+                 # Calculate loss
+                 loss = self.criterion(outputs, labels)
+                 eval_loss += loss.item()
+
+                 # Get predictions
+                 _, preds = torch.max(outputs, dim=1)
+                 all_predictions.extend(preds.cpu().tolist())
+                 all_labels.extend(labels.cpu().tolist())
+
+         # Calculate metrics
+         eval_loss /= len(data_loader)
+         accuracy = accuracy_score(all_labels, all_predictions)
+         f1 = f1_score(all_labels, all_predictions, average='macro')
+         precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
+         recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
+
+         return eval_loss, accuracy, f1, precision, recall
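
A standalone sketch of the label smoothing used in the loss above: with smoothing 0.1 and K classes, `CrossEntropyLoss` targets a probability of 1 - 0.1 + 0.1/K for the true class and 0.1/K for every other class (toy logits with hypothetical values):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.5, 0.3]])
labels = torch.tensor([0, 1])

plain = nn.CrossEntropyLoss()
smoothed = nn.CrossEntropyLoss(label_smoothing=0.1)

# The smoothed loss is slightly higher on confident correct predictions,
# penalizing over-confidence and improving generalization
print(plain(logits, labels).item(), smoothed(logits, labels).item())
```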