|
|
"""
|
|
|
Small Language Model (SLM) architecture for document text extraction.
|
|
|
Uses DistilBERT with transfer learning for Named Entity Recognition.
|
|
|
"""
|
|
|
|
|
|
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    DistilBertConfig,
    DistilBertForTokenClassification,
    DistilBertTokenizer,
    get_linear_schedule_with_warmup,
)
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration for the SLM model.

    Holds the pretrained checkpoint name, the training hyperparameters and the
    NER tag vocabulary. When ``entity_labels`` is left as ``None`` a default BIO
    tag set (NAME, DATE, INVOICE_NO, AMOUNT, ADDRESS, PHONE, EMAIL plus ``'O'``)
    is installed by ``__post_init__``.
    """

    model_name: str = "distilbert-base-uncased"
    max_length: int = 512
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_epochs: int = 3
    warmup_steps: int = 500
    weight_decay: float = 0.01
    dropout_rate: float = 0.3

    # ``None`` is a sentinel meaning "use the default tag set"; a fresh list is
    # created per instance in __post_init__, so instances never share one list.
    entity_labels: Optional[List[str]] = None

    def __post_init__(self):
        """Install the default BIO tag set when no labels were supplied."""
        if self.entity_labels is None:
            self.entity_labels = [
                'O', 'B-NAME', 'I-NAME', 'B-DATE', 'I-DATE',
                'B-INVOICE_NO', 'I-INVOICE_NO', 'B-AMOUNT', 'I-AMOUNT',
                'B-ADDRESS', 'I-ADDRESS', 'B-PHONE', 'I-PHONE',
                'B-EMAIL', 'I-EMAIL',
            ]

    @property
    def num_labels(self) -> int:
        """Number of entries in ``entity_labels``."""
        return len(self.entity_labels)

    @property
    def label2id(self) -> Dict[str, int]:
        """Mapping from tag string to its index in ``entity_labels``.

        The dict is rebuilt on every access; callers in tight loops should
        read it once and reuse the result.
        """
        return {label: i for i, label in enumerate(self.entity_labels)}

    @property
    def id2label(self) -> Dict[int, str]:
        """Mapping from index in ``entity_labels`` to its tag string."""
        return {i: label for i, label in enumerate(self.entity_labels)}
|
|
|
|
|
|
|
|
|
class NERDataset(Dataset):
    """PyTorch Dataset for NER training.

    Every example is tokenized eagerly in the constructor and its word-level
    tags are expanded to one label id per sub-token.

    Args:
        dataset: Examples, each a dict with a ``'tokens'`` entry (pre-split
            words) and a ``'labels'`` entry (one tag string per word).
        tokenizer: Tokenizer called with ``is_split_into_words=True``. Label
            alignment relies on ``BatchEncoding.word_ids()``, which Hugging
            Face only provides for fast tokenizers.
        config: Supplies ``max_length`` and the ``label2id`` mapping.
        mode: Free-form split name; stored but not used by this class.
    """

    def __init__(self, dataset: List[Dict], tokenizer: DistilBertTokenizer,
                 config: ModelConfig, mode: str = 'train'):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.config = config
        self.mode = mode

        self.tokenized_data = self._tokenize_and_align_labels()

    def _tokenize_and_align_labels(self) -> List[Dict]:
        """Tokenize text and align labels with subword tokens.

        Label rules per sub-token:
          * special/padding tokens (no word id) and words without a tag -> -100
          * first sub-token of a word -> the word's tag
          * continuation sub-tokens -> the word's tag, with a ``B-`` prefix
            turned into ``I-`` so an entity has a single beginning marker
          * tags missing from ``label2id`` fall back to id 0

        Returns:
            One dict per example with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors plus the untouched ``original_tokens`` and
            ``original_labels``.
        """
        tokenized_data = []

        # label2id is a property that rebuilds its dict on every access, so
        # resolve it once instead of once per sub-token.
        label2id = self.config.label2id

        for example in self.dataset:
            tokens = example['tokens']
            labels = example['labels']
            num_labels = len(labels)

            tokenized_inputs = self.tokenizer(
                tokens,
                is_split_into_words=True,
                padding='max_length',
                truncation=True,
                max_length=self.config.max_length,
                return_tensors='pt'
            )

            word_ids = tokenized_inputs.word_ids()
            aligned_labels = []
            previous_word_idx = None

            for word_idx in word_ids:
                if word_idx is None or word_idx >= num_labels:
                    # -100 marks positions excluded from accuracy; it is also
                    # the default ignore_index of PyTorch's cross-entropy loss.
                    aligned_labels.append(-100)
                else:
                    label = labels[word_idx]
                    if word_idx == previous_word_idx and label.startswith('B-'):
                        # Swap only the prefix: str.replace would also rewrite
                        # any later "B-" inside the entity name itself.
                        label = 'I-' + label[2:]
                    aligned_labels.append(label2id.get(label, 0))

                previous_word_idx = word_idx

            tokenized_data.append({
                # squeeze(0) drops only the batch axis added by
                # return_tensors='pt', never the sequence axis.
                'input_ids': tokenized_inputs['input_ids'].squeeze(0),
                'attention_mask': tokenized_inputs['attention_mask'].squeeze(0),
                'labels': torch.tensor(aligned_labels, dtype=torch.long),
                'original_tokens': tokens,
                'original_labels': labels
            })

        return tokenized_data

    def __len__(self) -> int:
        """Number of tokenized examples."""
        return len(self.tokenized_data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return only the tensor fields of example ``idx``."""
        item = self.tokenized_data[idx]
        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'labels': item['labels']
        }
|
|
|
|
|
|
|
|
|
class DocumentNERModel(nn.Module):
    """DistilBERT-based model for document NER.

    Thin wrapper around ``DistilBertForTokenClassification`` whose label
    mappings and dropout rates come from the supplied ``ModelConfig``.
    """

    def __init__(self, config: ModelConfig):
        """Load the pretrained checkpoint named by ``config.model_name``.

        Args:
            config: Provides the checkpoint name, label mappings and the
                dropout rate used for both hidden and attention dropout.
        """
        super().__init__()
        self.config = config

        bert_config = DistilBertConfig.from_pretrained(
            config.model_name,
            num_labels=config.num_labels,
            id2label=config.id2label,
            label2id=config.label2id,
            dropout=config.dropout_rate,
            attention_dropout=config.dropout_rate
        )

        self.model = DistilBertForTokenClassification.from_pretrained(
            config.model_name,
            config=bert_config
        )

        # NOTE: this layer is not applied in forward() or predict(); it is kept
        # so the attribute stays available to existing callers.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Forward pass through the model.

        Arguments are passed straight to the wrapped model and its output
        object is returned unchanged.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs

    def predict(self, input_ids, attention_mask=None):
        """Make predictions without computing loss.

        Runs with gradients disabled and with the wrapped model in eval mode,
        so dropout cannot perturb the result. The model's previous
        train/eval mode is restored afterwards.

        Returns:
            ``(predictions, probabilities)``: the argmax over the last logits
            axis and the softmax over that same axis.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                predictions = torch.argmax(outputs.logits, dim=-1)
                probabilities = torch.softmax(outputs.logits, dim=-1)
        finally:
            # Put the model back the way the caller left it (e.g. mid-training).
            self.model.train(was_training)

        return predictions, probabilities
|
|
|
|
|
|
|
|
|
class NERTrainer:
    """Trainer class for the NER model.

    Moves the model to CUDA when available (CPU otherwise), builds the
    train/validation dataloaders, runs the optimisation loop and saves or
    reloads the model, tokenizer and training configuration.
    """

    def __init__(self, model: DocumentNERModel, config: ModelConfig):
        """Bind the trainer to ``model`` and load the matching tokenizer."""
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # NERDataset aligns labels via BatchEncoding.word_ids(), which only fast
        # tokenizers implement; AutoTokenizer returns the fast variant by default.
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    def prepare_dataloaders(self, dataset: List[Dict],
                            test_size: float = 0.2) -> Tuple[DataLoader, DataLoader]:
        """Prepare training and validation dataloaders.

        Args:
            dataset: Examples in the format expected by ``NERDataset``.
            test_size: Fraction of examples held out for validation.

        Returns:
            ``(train_dataloader, val_dataloader)``; only the training loader
            shuffles. The split uses a fixed ``random_state`` of 42.
        """
        train_data, val_data = train_test_split(
            dataset, test_size=test_size, random_state=42
        )

        train_dataset = NERDataset(train_data, self.tokenizer, self.config, 'train')
        val_dataset = NERDataset(val_data, self.tokenizer, self.config, 'val')

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )

        return train_dataloader, val_dataloader

    def train(self, train_dataloader: DataLoader,
              val_dataloader: DataLoader) -> Dict[str, List[float]]:
        """Train the NER model.

        Runs ``config.num_epochs`` epochs of AdamW with a linear
        warmup/decay schedule, validating after every epoch.

        Returns:
            History dict with per-epoch ``train_loss``, ``val_loss`` and
            ``val_accuracy`` lists.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        total_steps = len(train_dataloader) * self.config.num_epochs
        # NOTE: when config.warmup_steps exceeds total_steps the learning rate
        # never reaches its configured peak during this run.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )

        history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': []
        }

        print(f"Training on device: {self.device}")
        print(f"Total training steps: {total_steps}")

        for epoch in range(self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")
            print("-" * 50)

            train_loss = self._train_epoch(train_dataloader, optimizer, scheduler)
            history['train_loss'].append(train_loss)

            val_loss, val_accuracy = self._validate_epoch(val_dataloader)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.4f}")

        return history

    def _train_epoch(self, dataloader: DataLoader, optimizer, scheduler) -> float:
        """Train for one epoch.

        Returns:
            Mean batch loss, or 0.0 when the dataloader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(dataloader)

        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            outputs = self.model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()

            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            # Read the scalar once; it is needed for both the sum and the log.
            loss_value = loss.item()
            total_loss += loss_value

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}/{num_batches}, Loss: {loss_value:.4f}")

        # An empty dataloader would otherwise raise ZeroDivisionError.
        return total_loss / num_batches if num_batches > 0 else 0.0

    def _validate_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Validate for one epoch.

        Returns:
            ``(avg_loss, accuracy)``. Accuracy is token-level and ignores
            positions labelled -100. Both are 0 when there is nothing to
            average over.
        """
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_tokens = 0
        num_batches = len(dataloader)

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(**batch)
                total_loss += outputs.loss.item()

                predictions = torch.argmax(outputs.logits, dim=-1)
                labels = batch['labels']

                # -100 marks special/padding/unlabelled positions.
                valid_mask = labels != -100
                correct = (predictions == labels) & valid_mask

                total_correct += correct.sum().item()
                total_tokens += valid_mask.sum().item()

        # An empty dataloader would otherwise raise ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        accuracy = total_correct / total_tokens if total_tokens > 0 else 0

        return avg_loss, accuracy

    def save_model(self, save_path: str):
        """Save the trained model and tokenizer.

        Also writes the trainer's configuration to
        ``<save_path>/training_config.json``.
        """
        self.model.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        config_path = f"{save_path}/training_config.json"
        # Explicit encoding keeps the file identical across platforms.
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(vars(self.config), f, indent=2)

        print(f"Model saved to {save_path}")

    def load_model(self, model_path: str):
        """Load a pre-trained model.

        Replaces the wrapped token-classification model and the tokenizer
        with the ones stored at ``model_path``, then moves the model to the
        trainer's device. ``self.config`` is left untouched.
        """
        self.model.model = DistilBertForTokenClassification.from_pretrained(model_path)
        # Fast tokenizer for the same reason as in __init__ (word_ids support).
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(self.device)
        print(f"Model loaded from {model_path}")
|
|
|
|
|
|
|
|
|
def create_model_and_trainer(config: Optional[ModelConfig] = None) -> Tuple[DocumentNERModel, NERTrainer]:
    """Build a ``DocumentNERModel`` and an ``NERTrainer`` bound to it.

    Args:
        config: Configuration shared by both objects; a default
            ``ModelConfig()`` is created when this is ``None``.

    Returns:
        The ``(model, trainer)`` pair.
    """
    effective_config = ModelConfig() if config is None else config

    ner_model = DocumentNERModel(effective_config)
    return ner_model, NERTrainer(ner_model, effective_config)
|
|
|
|
|
|
|
|
|
def main():
    """Demonstrate model creation and setup.

    Prints a small demo configuration, builds the model and trainer from it,
    reports total and trainable parameter counts, and returns the
    ``(model, trainer)`` pair.
    """
    config = ModelConfig(batch_size=8, num_epochs=2, learning_rate=3e-5)

    # One write instead of seven; the emitted text is line-for-line the same.
    report = [
        "Model Configuration:",
        f"Model: {config.model_name}",
        f"Max Length: {config.max_length}",
        f"Batch Size: {config.batch_size}",
        f"Learning Rate: {config.learning_rate}",
        f"Number of Labels: {config.num_labels}",
        f"Entity Labels: {config.entity_labels}",
    ]
    print("\n".join(report))

    model, trainer = create_model_and_trainer(config)

    # Tally both counts in a single pass over the parameters.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    print(f"\nModel created successfully!")
    print(f"Model parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model, trainer
|
|
|
|
|
|
|
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()