Commit da89f1c
Parent(s): First commit

Files changed:
- README.md +32 -0
- config.py +57 -0
- dataset.py +104 -0
- model.py +48 -0
- requirements.txt +6 -0
- run.py +86 -0
- train.py +129 -0
- trainer.py +211 -0
README.md
ADDED
@@ -0,0 +1,32 @@
````markdown
# DocBERT - Improved Document Classification with BERT

This repository contains an improved implementation of BERT for document classification, combining techniques from [jesse-tong/docbert](https://github.com/jesse-tong/docbert) and [castorini/hedwig](https://github.com/castorini/hedwig).

## Key Improvements

1. **Advanced Regularization Techniques**:
   - Dropout in multiple layers
   - Layer normalization
   - Gradient clipping
   - Weight decay optimization

2. **Training Stability Enhancements**:
   - Learning rate scheduling with ReduceLROnPlateau
   - Gradient accumulation for effective larger batch sizes
   - Label smoothing to improve generalization
   - Early stopping based on validation F1 score

3. **Architectural Changes**:
   - Better BERT pooling strategies
   - More robust tokenization with attention masks
   - Configurable hyperparameters for different document types

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/docbert-improved.git
cd docbert-improved

# Install dependencies
pip install -r requirements.txt
```
````
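For orientation before the code files below: once the commit's scripts are installed, training is launched through the bundled `run.py` and one of the config presets defined in `config.py`. The invocation here is only an illustrative sketch; the dataset path and the number of classes are placeholders, not values from the repository.

```bash
# Illustrative only: data/my_docs.csv and the class count are placeholders.
python run.py \
    --data_path data/my_docs.csv \
    --num_classes 2 \
    --config short_text \
    --output_dir ./output
```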
config.py
ADDED
@@ -0,0 +1,57 @@
```python
"""
Configuration module for DocBERT
Contains hyperparameter presets for different dataset types
"""

class BaseConfig:
    # Model params
    bert_model = "bert-base-uncased"
    max_seq_length = 512
    dropout = 0.1

    # Training params
    batch_size = 16
    learning_rate = 2e-5
    weight_decay = 0.01
    epochs = 10
    grad_accum_steps = 1

    # Data params
    val_split = 0.1
    test_split = 0.1
    seed = 42

class ShortTextConfig(BaseConfig):
    """Config for short text classification (tweets, comments, etc.)"""
    max_seq_length = 128
    batch_size = 32
    learning_rate = 3e-5

class LongDocumentConfig(BaseConfig):
    """Config for long document classification"""
    bert_model = "bert-large-uncased"
    max_seq_length = 512
    batch_size = 8
    grad_accum_steps = 2
    weight_decay = 0.02

class FinetuningConfig(BaseConfig):
    """Config for fine-tuning on a small dataset"""
    learning_rate = 1e-5
    batch_size = 8
    epochs = 15
    weight_decay = 0.03
    dropout = 0.2

CONFIG_PRESETS = {
    "default": BaseConfig,
    "short_text": ShortTextConfig,
    "long_document": LongDocumentConfig,
    "fine_tuning": FinetuningConfig
}

def get_config(preset_name="default"):
    """Get a configuration preset by name"""
    if preset_name not in CONFIG_PRESETS:
        raise ValueError(f"Config preset '{preset_name}' not found. Available presets: {list(CONFIG_PRESETS.keys())}")
    return CONFIG_PRESETS[preset_name]
```
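A small sketch of how these presets are meant to be consumed (this mirrors what `run.py` does later in the commit): `get_config` returns the preset class, which is then instantiated and read as plain attributes. The printed values are taken from the `LongDocumentConfig` definition above.

```python
from config import get_config

# get_config returns the preset class, not an instance.
LongCfg = get_config("long_document")
cfg = LongCfg()

print(cfg.bert_model)        # "bert-large-uncased"
print(cfg.max_seq_length)    # 512
print(cfg.grad_accum_steps)  # 2 -> effective batch of 16 with batch_size=8
```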
dataset.py
ADDED
@@ -0,0 +1,104 @@
```python
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pandas as pd
import numpy as np

class DocumentDataset(Dataset):
    """
    Dataset class for document classification
    with improved preprocessing and batching
    """
    def __init__(self, texts, labels, tokenizer_name='bert-base-uncased', max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize the text with attention mask and truncation
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

def load_data(data_path, text_col='text', label_col='label', validation_split=0.1, test_split=0.1, seed=42):
    """
    Load data from CSV/TSV and split into train, validation and test sets
    """
    # Determine file format based on extension
    if data_path.endswith('.csv'):
        df = pd.read_csv(data_path)
    elif data_path.endswith('.tsv'):
        df = pd.read_csv(data_path, sep='\t')
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or TSV file.")

    # Convert labels to numeric if they aren't already
    if not np.issubdtype(df[label_col].dtype, np.number):
        label_map = {label: idx for idx, label in enumerate(df[label_col].unique())}
        df['label_numeric'] = df[label_col].map(label_map)
        labels = df['label_numeric'].values
    else:
        labels = df[label_col].values

    # Extract the text column
    texts = df[text_col].values

    # Shuffle and split the data
    np.random.seed(seed)
    indices = np.random.permutation(len(texts))

    test_size = int(test_split * len(texts))
    val_size = int(validation_split * len(texts))
    train_size = len(texts) - test_size - val_size

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    train_texts, train_labels = texts[train_indices], labels[train_indices]
    val_texts, val_labels = texts[val_indices], labels[val_indices]
    test_texts, test_labels = texts[test_indices], labels[test_indices]

    return (train_texts, train_labels), (val_texts, val_labels), (test_texts, test_labels)

def create_data_loaders(train_data, val_data, test_data, tokenizer_name='bert-base-uncased',
                        max_length=512, batch_size=16):
    """
    Create DataLoader objects for training, validation and testing
    """
    train_texts, train_labels = train_data
    val_texts, val_labels = val_data
    test_texts, test_labels = test_data

    # Create datasets
    train_dataset = DocumentDataset(train_texts, train_labels, tokenizer_name, max_length)
    val_dataset = DocumentDataset(val_texts, val_labels, tokenizer_name, max_length)
    test_dataset = DocumentDataset(test_texts, test_labels, tokenizer_name, max_length)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader
```
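As a rough illustration of the data pipeline above, the sketch below loads a CSV, builds the three loaders, and inspects one batch. The file name `docs.csv` and the `max_length`/`batch_size` values are placeholders chosen for the example, not values mandated by the commit.

```python
from dataset import load_data, create_data_loaders

# "docs.csv" is a placeholder CSV with "text" and "label" columns.
train_data, val_data, test_data = load_data(
    "docs.csv", text_col="text", label_col="label",
    validation_split=0.1, test_split=0.1, seed=42
)

train_loader, val_loader, test_loader = create_data_loaders(
    train_data, val_data, test_data,
    tokenizer_name="bert-base-uncased", max_length=128, batch_size=16
)

batch = next(iter(train_loader))
print(batch["input_ids"].shape)       # torch.Size([16, 128])
print(batch["attention_mask"].shape)  # torch.Size([16, 128])
```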
model.py
ADDED
@@ -0,0 +1,48 @@
```python
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig

class DocBERT(nn.Module):
    """
    Document classification using BERT with improved architecture
    based on Hedwig implementation patterns.
    """
    def __init__(self, num_classes, bert_model_name='bert-base-uncased', dropout_prob=0.1):
        super(DocBERT, self).__init__()

        # Load pre-trained BERT model and its config
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.config = self.bert.config

        # Dropout layer for regularization (helps prevent overfitting)
        self.dropout = nn.Dropout(dropout_prob)

        # Classification head (inspired by Hedwig)
        self.hidden_size = self.config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, num_classes)

        # Layer normalization before classification (helps stabilize training)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Forward pass through the model
        """
        # Get BERT outputs
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)

        # Pooled [CLS] representation from BERT's pooler
        pooled_output = outputs.pooler_output

        # Apply layer normalization
        normalized_output = self.layer_norm(pooled_output)

        # Apply dropout for regularization
        dropped_output = self.dropout(normalized_output)

        # Pass through the classifier
        logits = self.classifier(dropped_output)

        return logits
```
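A minimal smoke test of the model's forward pass, assuming `bert-base-uncased` (vocab size 30522) and an arbitrary four-class setup; the random token ids are only there to check shapes.

```python
import torch
from model import DocBERT

# Instantiate with a placeholder class count and run a dummy batch.
model = DocBERT(num_classes=4, bert_model_name="bert-base-uncased", dropout_prob=0.1)
model.eval()

input_ids = torch.randint(0, 30522, (2, 32))       # batch of 2, sequence length 32
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

print(logits.shape)  # torch.Size([2, 4])
```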
requirements.txt
ADDED
@@ -0,0 +1,6 @@
```text
scikit-learn
numpy
pandas
torch
transformers
datasets
```
run.py
ADDED
@@ -0,0 +1,86 @@
```python
"""
Simple script to run the DocBERT model with predefined config presets
"""
import argparse
import logging
import os
from config import get_config
from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def main():
    parser = argparse.ArgumentParser(description="Run DocBERT with a predefined config")

    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
    parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
    parser.add_argument("--label_column", type=str, default="label", help="Name of the label column")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
    parser.add_argument("--config", type=str, default="default",
                        choices=["default", "short_text", "long_document", "fine_tuning"],
                        help="Configuration preset to use")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save outputs")

    args = parser.parse_args()

    # Get config
    config_class = get_config(args.config)
    config = config_class()

    logger.info(f"Using '{args.config}' config preset")

    # Create output directory
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load and prepare data
    logger.info("Loading data...")
    train_data, val_data, test_data = load_data(
        args.data_path,
        text_col=args.text_column,
        label_col=args.label_column,
        validation_split=config.val_split,
        test_split=config.test_split,
        seed=config.seed
    )

    train_loader, val_loader, test_loader = create_data_loaders(
        train_data,
        val_data,
        test_data,
        tokenizer_name=config.bert_model,
        max_length=config.max_seq_length,
        batch_size=config.batch_size
    )

    # Initialize model
    logger.info(f"Initializing model with {config.bert_model}...")
    model = DocBERT(
        num_classes=args.num_classes,
        bert_model_name=config.bert_model,
        dropout_prob=config.dropout
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        gradient_accumulation_steps=config.grad_accum_steps
    )

    # Train model
    logger.info("Starting training...")
    save_path = os.path.join(args.output_dir, "best_model.pth")
    trainer.train(epochs=config.epochs, save_path=save_path)

    logger.info("Training completed!")

if __name__ == "__main__":
    main()
```
train.py
ADDED
@@ -0,0 +1,129 @@
```python
import argparse
import os
import logging
import torch
import random
import numpy as np
from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Set all seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def main():
    parser = argparse.ArgumentParser(description="Train a document classification model with BERT")

    # Data arguments
    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
    parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
    parser.add_argument("--label_column", type=str, default="label", help="Name of the label column")
    parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio")
    parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio")

    # Model arguments
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased",
                        help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training for LR warmup")

    # Other arguments
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs")

    args = parser.parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    # Create output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Log args for debugging
    logger.info(f"Running with arguments: {args}")

    # Load and prepare data
    logger.info("Loading and preparing data...")
    train_data, val_data, test_data = load_data(
        args.data_path,
        text_col=args.text_column,
        label_col=args.label_column,
        validation_split=args.val_split,
        test_split=args.test_split,
        seed=args.seed
    )

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data,
        val_data,
        test_data,
        tokenizer_name=args.bert_model,
        max_length=args.max_length,
        batch_size=args.batch_size
    )

    logger.info(f"Train samples: {len(train_data[0])}, "
                f"Validation samples: {len(val_data[0])}, "
                f"Test samples: {len(test_data[0])}")

    # Initialize model
    logger.info(f"Initializing DocBERT model with {args.bert_model}...")
    model = DocBERT(
        num_classes=args.num_classes,
        bert_model_name=args.bert_model,
        dropout_prob=args.dropout
    )

    # Count and log model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # Initialize trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_proportion=args.warmup_proportion,
        gradient_accumulation_steps=args.grad_accum_steps
    )

    # Train the model
    logger.info("Starting training...")
    save_path = os.path.join(args.output_dir, "best_model.pth")
    trainer.train(epochs=args.epochs, save_path=save_path)

    logger.info("Training completed!")

if __name__ == "__main__":
    main()
```
trainer.py
ADDED
@@ -0,0 +1,211 @@
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
import time
from tqdm import tqdm
import logging
import os

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Trainer:
    """
    Improved trainer class with techniques from the Hedwig implementation
    to get better performance on document classification tasks
    """
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        test_loader=None,
        lr=2e-5,
        weight_decay=0.01,
        warmup_proportion=0.1,  # accepted for compatibility; not used below
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
        device=None
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        self.model.to(self.device)

        # Total number of training steps
        self.num_training_steps = len(train_loader) * gradient_accumulation_steps

        # Optimizer with weight decay (L2 regularization);
        # bias and LayerNorm weights are excluded from decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)

        # Learning rate scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=2, verbose=True)

        # Loss function with label smoothing for better generalization
        # (a smoothing factor of 0.1 is assumed here)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Training parameters
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm

        # For tracking metrics
        self.best_val_f1 = 0.0
        self.best_model_state = None

    def train(self, epochs, save_path='best_model.pth'):
        """
        Training loop with improved techniques
        """
        logger.info(f"Starting training for {epochs} epochs")

        for epoch in range(epochs):
            start_time = time.time()

            # Training phase
            self.model.train()
            train_loss = 0
            all_predictions = []
            all_labels = []

            # Progress bar for training
            train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
            for i, batch in enumerate(train_iterator):
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )

                # Calculate loss
                loss = self.criterion(outputs, labels)

                # Scale loss if using gradient accumulation
                if self.gradient_accumulation_steps > 1:
                    loss = loss / self.gradient_accumulation_steps

                # Backward pass
                loss.backward()

                # Update weights if we've accumulated enough gradients
                if (i + 1) % self.gradient_accumulation_steps == 0:
                    # Gradient clipping to prevent exploding gradients
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    self.optimizer.step()
                    self.optimizer.zero_grad()

                train_loss += loss.item() * self.gradient_accumulation_steps

                # Get predictions for metrics
                _, preds = torch.max(outputs, dim=1)
                all_predictions.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

                # Update progress bar with current loss
                train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})

            # Calculate training metrics
            train_loss /= len(self.train_loader)
            train_acc = accuracy_score(all_labels, all_predictions)
            train_f1 = f1_score(all_labels, all_predictions, average='macro')

            # Validation phase
            val_loss, val_acc, val_f1, val_precision, val_recall = self.evaluate(self.val_loader, "Validation")

            # Adjust learning rate based on validation performance
            self.scheduler.step(val_f1)

            # Save best model
            if val_f1 > self.best_val_f1:
                self.best_val_f1 = val_f1
                self.best_model_state = self.model.state_dict().copy()
                torch.save(self.model.state_dict(), save_path)
                logger.info(f"New best model saved with validation F1: {val_f1:.4f}")

            # Print epoch summary
            epoch_time = time.time() - start_time
            logger.info(f"Epoch {epoch+1}/{epochs} - "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, "
                        f"Time: {epoch_time:.2f}s")

        # Load best model for final evaluation
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")

        # Test evaluation if test loader provided
        if self.test_loader:
            test_loss, test_acc, test_f1, test_precision, test_recall = self.evaluate(self.test_loader, "Test")
            logger.info(f"Final test results - "
                        f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, "
                        f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")

    def evaluate(self, data_loader, phase="Validation"):
        """
        Evaluation function for both validation and test sets
        """
        self.model.eval()
        eval_loss = 0
        all_predictions = []
        all_labels = []

        # No gradient computation during evaluation
        with torch.no_grad():
            # Progress bar for evaluation
            iterator = tqdm(data_loader, desc=f"[{phase}]")
            for batch in iterator:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )

                # Calculate loss
                loss = self.criterion(outputs, labels)
                eval_loss += loss.item()

                # Get predictions
                _, preds = torch.max(outputs, dim=1)
                all_predictions.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        # Calculate metrics
        eval_loss /= len(data_loader)
        accuracy = accuracy_score(all_labels, all_predictions)
        f1 = f1_score(all_labels, all_predictions, average='macro')
        precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)

        return eval_loss, accuracy, f1, precision, recall
```
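The trainer above checkpoints the best model to `best_model.pth`, but the commit does not include an inference script. The sketch below shows one way the saved checkpoint could be used for prediction, assuming the same `DocBERT` class and tokenizer settings used during training; the checkpoint path, class count, and input text are placeholders.

```python
import torch
from transformers import BertTokenizer
from model import DocBERT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# num_classes and the checkpoint path must match the training run (placeholders here).
model = DocBERT(num_classes=2, bert_model_name="bert-base-uncased")
model.load_state_dict(torch.load("output/best_model.pth", map_location=device))
model.to(device)
model.eval()

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode_plus(
    "An example document to classify.",
    add_special_tokens=True, max_length=512, padding="max_length",
    truncation=True, return_token_type_ids=True,
    return_attention_mask=True, return_tensors="pt"
)

with torch.no_grad():
    logits = model(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        token_type_ids=encoding["token_type_ids"].to(device),
    )

predicted_class = logits.argmax(dim=1).item()
print(predicted_class)
```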