|
|
import json
|
|
|
import shutil
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
from transformers import (
|
|
|
DebertaV2Model,
|
|
|
DebertaV2TokenizerFast,
|
|
|
DebertaV2Config,
|
|
|
get_linear_schedule_with_warmup,
|
|
|
set_seed
|
|
|
)
|
|
|
from torch.cuda.amp import autocast
|
|
|
from tqdm import tqdm
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
import logging
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Optional, Dict, List, Tuple
|
|
|
import wandb
|
|
|
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
|
|
import functools
|
|
|
import re
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time so every module logger in this
# process shares the same timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Training configuration for link token classification"""

    # --- Model ---
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2  # binary: non-link (0) vs link (1)

    # --- Data (pre-tokenized JSONL windows) ---
    train_file: str = "train_windows.jsonl"
    val_file: str = "val_windows.jsonl"
    max_length: int = 512  # fixed padded sequence length used by collate_fn

    # --- Optimization ---
    batch_size: int = 8
    gradient_accumulation_steps: int = 8  # effective batch = batch_size * this
    num_epochs: int = 3
    learning_rate: float = 1e-6
    warmup_ratio: float = 0.1  # fraction of total optimizer steps spent warming up
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0  # gradient clipping threshold
    label_smoothing: float = 0.0  # NOTE(review): declared but never read by the loss — confirm intent

    # --- Runtime ---
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once at class-definition time
    num_workers: int = 0
    seed: int = 42
    bf16: bool = True  # use bfloat16 autocast for forward passes

    # --- Logging / evaluation cadence (in optimizer steps, not micro-batches) ---
    logging_steps: int = 1
    eval_steps: int = 5000
    save_steps: int = 10000  # NOTE(review): not referenced by the Trainer — confirm intent
    output_dir: str = "./deberta_link_output"

    # --- Weights & Biases tracking ---
    wandb_project: str = "deberta-link-classification"
    wandb_name: str = "deberta-v3-large-link-tokens"

    # --- Early stopping on validation loss ---
    patience: int = 2
    min_delta: float = 0.0001  # improvement must exceed this to reset patience

    # --- Checkpoint retention ---
    max_checkpoints: int = 5
    protect_latest_epoch_step: bool = True  # never delete the newest epoch-best / step-best checkpoints
|
|
|
|
|
|
|
|
|
class LinkTokenDataset(Dataset):
    """Dataset of pre-tokenized windows for link token classification.

    Each line of the JSONL file must contain 'input_ids', 'attention_mask'
    and 'labels' lists of equal length; labels use -100 for positions that
    the loss should ignore.
    """

    def __init__(self, file_path: str, max_samples: Optional[int] = None):
        """Load samples from *file_path*, optionally capped at *max_samples*.

        Args:
            file_path: Path to a JSONL file of tokenized windows.
            max_samples: If given (and non-zero), stop after this many lines.
        """
        self.data = []

        logger.info(f"Loading data from {file_path}")
        seq_lengths = []

        with open(file_path, 'r') as f:
            for i, line in enumerate(f):
                if max_samples and i >= max_samples:
                    break
                sample = json.loads(line)

                seq_lengths.append(len(sample['input_ids']))

                # Convert to tensors once at load time so __getitem__ is cheap.
                sample['input_ids'] = torch.tensor(sample['input_ids'], dtype=torch.long)
                sample['attention_mask'] = torch.tensor(sample['attention_mask'], dtype=torch.long)
                sample['labels'] = torch.tensor(sample['labels'], dtype=torch.long)

                self.data.append(sample)

        logger.info(f"Loaded {len(self.data)} samples")
        # Guard the stats: min()/max() raise ValueError on an empty file.
        if seq_lengths:
            logger.info(f"Sequence lengths - Min: {min(seq_lengths)}, Max: {max(seq_lengths)}, Avg: {np.mean(seq_lengths):.1f}")

        # Log the label distribution, excluding ignored (-100) positions.
        total_labels = [s['labels'][s['labels'] != -100] for s in self.data]

        if total_labels:
            total_labels = torch.cat(total_labels)
            num_link_tokens = (total_labels == 1).sum().item()
            num_non_link = (total_labels == 0).sum().item()

            logger.info(f"Label distribution - Non-link: {num_non_link}, Link: {num_link_tokens}")
            if (num_link_tokens + num_non_link) > 0:
                logger.info(f"Link token ratio: {num_link_tokens / (num_link_tokens + num_non_link):.4%}")
            else:
                logger.info("No valid labels found in the dataset.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
|
|
|
|
|
|
|
|
|
def collate_fn(batch: List[Dict], max_seq_length: int) -> Dict[str, torch.Tensor]:
    """
    Custom collate function for batching with padding to a fixed max_seq_length.

    Sequences longer than max_seq_length are truncated; shorter ones are
    padded with 0 (input_ids / attention_mask) and -100 (labels, so padded
    positions are ignored by the loss).

    Args:
        batch (List[Dict]): A list of samples from the dataset.
        max_seq_length (int): The maximum sequence length to pad all samples to.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing stacked and padded tensors.
    """
    input_ids = []
    attention_mask = []
    labels = []

    for x in batch:
        # Truncate via local views instead of writing back into the sample
        # dict: the samples are cached by the Dataset and must not be
        # mutated by the collate step.
        ids = x['input_ids'][:max_seq_length]
        mask = x['attention_mask'][:max_seq_length]
        labs = x['labels'][:max_seq_length]

        padding_len = max_seq_length - len(ids)

        input_ids.append(torch.cat([ids, torch.zeros(padding_len, dtype=torch.long)]))
        attention_mask.append(torch.cat([mask, torch.zeros(padding_len, dtype=torch.long)]))
        # -100 so the padded positions are ignored by CrossEntropyLoss.
        labels.append(torch.cat([labs, torch.full((padding_len,), -100, dtype=torch.long)]))

    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels)
    }
|
|
|
|
|
|
|
|
|
class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa encoder with a linear token-classification head.

    The loss is a class-weighted cross-entropy that ignores positions
    labelled -100 (padding / special tokens).
    """

    def __init__(
        self,
        model_name: str,
        num_labels: int,
        dropout_rate: float = 0.1,
        class_weights: Optional[List[float]] = None,
    ):
        """
        Args:
            model_name: HuggingFace model id or local path of the encoder.
            num_labels: Number of output classes.
            dropout_rate: Dropout applied to the encoder output.
            class_weights: Per-class loss weights. Defaults to [1.0, 25.0]
                in the binary case (up-weighting the rare link class) and
                to uniform weights otherwise.
        """
        super().__init__()

        self.config = DebertaV2Config.from_pretrained(model_name)
        # Keep the config consistent with the actual head size: the original
        # code reshaped logits with self.config.num_labels (the pretrained
        # default), which silently disagrees with the constructor argument
        # whenever num_labels != 2.
        self.config.num_labels = num_labels
        self.num_labels = num_labels
        self.deberta = DebertaV2Model.from_pretrained(model_name)

        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        if class_weights is None:
            # Historical default: up-weight the positive (link) class 25x.
            class_weights = [1.0, 25.0] if num_labels == 2 else [1.0] * num_labels
        # Register as a buffer so the weight follows the module across
        # .to(device) instead of being re-created on every forward pass.
        self.register_buffer('class_weight', torch.tensor(class_weights, dtype=torch.float))

        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Run the encoder and head.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: (batch, seq) mask, 1 for real tokens.
            labels: Optional (batch, seq) gold labels with -100 for
                ignored positions; when given, a loss is computed.

        Returns:
            Dict with 'loss' (Tensor or None) and 'logits'
            of shape (batch, seq, num_labels).
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=-100)
            # Flatten (batch, seq) so every token is one classification example.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            'loss': loss,
            'logits': logits
        }
|
|
|
|
|
|
|
|
|
def compute_metrics(predictions: np.ndarray, labels: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
    """Compute token-level classification metrics.

    Args:
        predictions: Predicted class ids, any shape.
        labels: Gold labels with the same shape; -100 marks ignored positions.
        mask: Attention mask with the same shape; 0 marks padding.

    Returns:
        Dict with overall accuracy, binary precision/recall/f1 (positive
        class = 1), per-class f1, link-class precision/recall, and the
        number of tokens actually scored. All zeros when no token is valid.
    """
    predictions_flat = predictions.flatten()
    labels_flat = labels.flatten()
    mask_flat = mask.flatten()

    # Score only real tokens: attended positions that carry a gold label.
    valid_indices = (labels_flat != -100) & (mask_flat == 1)

    preds_filtered = predictions_flat[valid_indices]
    labels_filtered = labels_flat[valid_indices]

    if len(labels_filtered) == 0:
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'f1_non_link': 0.0,
            'f1_link': 0.0,
            'precision_link': 0.0,
            'recall_link': 0.0,
            'num_valid_tokens': 0
        }

    accuracy = accuracy_score(labels_filtered, preds_filtered)

    # Binary metrics with the link class (1) as positive.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_filtered, preds_filtered, average='binary', pos_label=1, zero_division=0
    )

    # One per-class pass (average=None) replaces the original three separate
    # calls; classes absent from the data get 0.0 via zero_division.
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        labels_filtered, preds_filtered, labels=[0, 1], average=None, zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'f1_non_link': f1_per_class[0],
        'f1_link': f1_per_class[1],
        'precision_link': precision_per_class[1],
        'recall_link': recall_per_class[1],
        'num_valid_tokens': len(labels_filtered)
    }
|
|
|
|
|
|
|
|
|
class Trainer:
    """Trainer class for DeBERTa token classification.

    Owns the datasets, loaders, model, optimizer and LR schedule; runs the
    epoch loop with gradient accumulation, optional bf16 autocast, mid-epoch
    evaluation, early stopping on validation loss, and checkpoint retention
    under ``config.output_dir``.
    """

    def __init__(self, config: TrainingConfig):
        """Build all training state from *config* and start a wandb run."""
        self.config = config
        set_seed(config.seed)

        # Experiment tracking; the full config is logged for reproducibility.
        wandb.init(
            project=config.wandb_project,
            name=config.wandb_name,
            config=vars(config)
        )

        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

        self.train_dataset = LinkTokenDataset(config.train_file)
        self.val_dataset = LinkTokenDataset(config.val_file)

        # NOTE(review): shuffle=False on the *training* loader — presumably
        # the windows are pre-shuffled on disk; confirm, otherwise every
        # epoch sees the identical order.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
            pin_memory=True
        )

        # Validation can afford a larger batch: no gradients are kept.
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
            pin_memory=True
        )

        self.model = DeBERTaForTokenClassification(
            config.model_name,
            config.num_labels
        ).to(config.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")

        # Standard transformer fine-tuning split: no weight decay on biases
        # and LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': config.weight_decay
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        self.optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=config.learning_rate,
            eps=1e-6
        )

        # Optimizer steps, not micro-batches. Floor division: the partial
        # accumulation window at each epoch end is dropped from the count.
        total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Early-stopping / checkpoint bookkeeping.
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch.

        Returns the mean per-micro-batch (accumulation-scaled) loss. May
        return early if mid-epoch evaluation triggers early stopping.
        """
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        early_stop_triggered = False

        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            # Forward in bf16 autocast when enabled; the loss is pre-divided
            # so accumulated gradients average over the accumulation window.
            if self.config.bf16:
                with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = self.model(**batch)
                    loss = outputs['loss'] / self.config.gradient_accumulation_steps
            else:
                outputs = self.model(**batch)
                loss = outputs['loss'] / self.config.gradient_accumulation_steps

            # Skip non-finite losses entirely.
            # NOTE(review): zero_grad() here also discards gradients already
            # accumulated earlier in this window — confirm that is intended.
            if torch.isnan(loss) or torch.isinf(loss):
                logger.warning(f"NaN or Inf loss encountered at step {self.global_step}. Skipping backward pass.")
                self.optimizer.zero_grad()
                continue

            loss.backward()
            total_loss += loss.item()

            # One optimizer update per accumulation window.
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                if self.global_step % self.config.logging_steps == 0:
                    # Undo the accumulation scaling for human-readable logging.
                    current_loss = loss.item() * self.config.gradient_accumulation_steps
                    wandb.log({
                        'train/loss': current_loss,
                        'train/learning_rate': self.scheduler.get_last_lr()[0],
                        'train/global_step': self.global_step,
                        'train/epoch': epoch
                    })
                    progress_bar.set_postfix({'loss': f'{current_loss:.4f}'})

                # Periodic mid-epoch evaluation with early stopping.
                if self.global_step % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate()
                    logger.info(f"Step {self.global_step} - Eval metrics: {eval_metrics}")

                    # Improvement must beat min_delta to reset patience.
                    current_val_loss = eval_metrics['loss']
                    if current_val_loss < self.best_val_loss - self.config.min_delta:
                        self.best_val_loss = current_val_loss
                        self.patience_counter = 0
                        self.save_model(f"best_model_step_{self.global_step}")
                        logger.info(f"New best validation loss: {self.best_val_loss:.4f}")
                    else:
                        self.patience_counter += 1
                        logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")
                        if self.patience_counter >= self.config.patience:
                            logger.info("Early stopping triggered mid-epoch!")
                            early_stop_triggered = True
                            break

            if early_stop_triggered:
                break

        return total_loss / len(self.train_loader) if len(self.train_loader) > 0 else 0.0

    def evaluate(self) -> Dict[str, float]:
        """Evaluate on validation set; returns metrics plus the mean loss."""
        self.model.eval()

        all_predictions = []
        all_labels = []
        all_masks = []
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.config.device) for k, v in batch.items()}

                # Mirror the training-time autocast setting.
                if self.config.bf16:
                    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                        outputs = self.model(**batch)
                else:
                    outputs = self.model(**batch)

                if outputs['loss'] is not None:
                    total_loss += outputs['loss'].item()
                    num_batches += 1

                predictions = torch.argmax(outputs['logits'], dim=-1)

                all_predictions.append(predictions.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
                all_masks.append(batch['attention_mask'].cpu().numpy())

        # All windows are padded to the same max_length by collate_fn, so
        # concatenating along the batch axis is safe.
        all_predictions = np.concatenate(all_predictions, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        all_masks = np.concatenate(all_masks, axis=0)

        metrics = compute_metrics(all_predictions, all_labels, all_masks)
        metrics['loss'] = total_loss / num_batches if num_batches > 0 else 0.0

        wandb.log({f'eval/{k}': v for k, v in metrics.items()}, step=self.global_step)

        # Restore training mode for the caller.
        self.model.train()
        return metrics

    def _enforce_checkpoint_limit(self):
        """
        Enforce checkpoint retention:
        - Count all subdirectories in output_dir except 'final_model'
        - Keep at most config.max_checkpoints
        - Delete oldest by modification time
        - Always protect:
            * 'final_model'
            * latest 'best_model_epoch_*'
            * latest 'best_model_step_*'
        """
        output_dir = Path(self.config.output_dir)
        if not output_dir.exists():
            return

        subdirs = [p for p in output_dir.iterdir() if p.is_dir()]
        if not subdirs:
            return

        # Resolved paths of directories that must never be deleted.
        protected = set()

        final_dir = output_dir / "final_model"
        if final_dir.exists() and final_dir.is_dir():
            protected.add(final_dir.resolve())

        if self.config.protect_latest_epoch_step:
            # Most recent epoch-best checkpoint (by mtime).
            epoch_dirs = [d for d in subdirs if re.match(r"best_model_epoch_\d+$", d.name)]
            if epoch_dirs:
                latest_epoch = max(epoch_dirs, key=lambda d: d.stat().st_mtime)
                protected.add(latest_epoch.resolve())

            # Most recent step-best checkpoint (by mtime).
            step_dirs = [d for d in subdirs if re.match(r"best_model_step_\d+$", d.name)]
            if step_dirs:
                latest_step = max(step_dirs, key=lambda d: d.stat().st_mtime)
                protected.add(latest_step.resolve())

        # 'final_model' never counts toward the retention limit.
        counted = [d for d in subdirs if d.resolve() != final_dir.resolve()]

        if len(counted) <= self.config.max_checkpoints:
            return

        # Oldest first, so deletions start with the stalest checkpoints.
        counted_sorted = sorted(counted, key=lambda d: d.stat().st_mtime)

        to_delete = []
        current = len(counted)
        for d in counted_sorted:
            if current <= self.config.max_checkpoints:
                break
            if d.resolve() in protected:
                continue
            to_delete.append(d)
            current -= 1

        # NOTE(review): this fallback pass applies the same protected-dir
        # filter as the loop above, so it cannot select anything new —
        # effectively a no-op; confirm whether protected checkpoints were
        # meant to become deletable here.
        if current > self.config.max_checkpoints:

            extras = [d for d in counted_sorted if d.resolve() != final_dir.resolve() and d not in to_delete]
            for d in extras:
                if current <= self.config.max_checkpoints:
                    break

                if d.resolve() in protected:
                    continue
                to_delete.append(d)
                current -= 1

        # Best-effort deletion: failing to remove one checkpoint should not
        # abort training.
        for d in to_delete:
            try:
                shutil.rmtree(d)
                logger.info(f"Deleted old checkpoint: {d}")
            except Exception as e:
                logger.warning(f"Failed to delete {d}: {e}")

    def save_model(self, name: str):
        """Save model checkpoint under output_dir/name, then prune old ones."""
        save_path = Path(self.config.output_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)

        # Raw state_dict only (no optimizer/scheduler state).
        torch.save(self.model.state_dict(), save_path / 'pytorch_model.bin')

        # Persist the config used for this run alongside the weights.
        with open(save_path / 'training_config.json', 'w') as f:
            json.dump(vars(self.config), f, indent=4)

        logger.info(f"Model saved to {save_path}")

        # Keep disk usage bounded after every save.
        self._enforce_checkpoint_limit()

    def train(self):
        """Main training loop with epoch-end evaluation and early stopping."""
        logger.info("Starting training...")
        logger.info(f"Training samples: {len(self.train_dataset)}")
        logger.info(f"Validation samples: {len(self.val_dataset)}")

        # NOTE(review): this ceil-based count differs from the floor-based
        # total_steps given to the scheduler in __init__ — used only for
        # logging, but confirm which is intended.
        total_optimization_steps = (len(self.train_loader) + self.config.gradient_accumulation_steps - 1) // self.config.gradient_accumulation_steps * self.config.num_epochs
        logger.info(f"Total optimization steps: {total_optimization_steps}")
        logger.info(f"Early stopping: monitoring validation loss with patience={self.config.patience}")

        for epoch in range(self.config.num_epochs):
            logger.info(f"\n{'='*50}")
            logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            avg_train_loss = self.train_epoch(epoch + 1)
            logger.info(f"Average training loss: {avg_train_loss:.4f}")

            # Stop now if mid-epoch evaluation already exhausted patience.
            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping during epoch.")
                break

            # Epoch-end evaluation.
            eval_metrics = self.evaluate()
            logger.info(f"Epoch {epoch + 1} - Eval metrics:")
            for key, value in eval_metrics.items():
                logger.info(f"  {key}: {value:.4f}")

            # Same improvement rule as the mid-epoch check.
            current_val_loss = eval_metrics['loss']
            if current_val_loss < self.best_val_loss - self.config.min_delta:
                self.best_val_loss = current_val_loss
                self.patience_counter = 0
                self.save_model(f"best_model_epoch_{epoch + 1}")
                logger.info(f"New best validation loss at epoch end: {self.best_val_loss:.4f}")
            else:
                self.patience_counter += 1
                logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")

            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping")
                break

        # Always persist the final weights, regardless of early stopping.
        self.save_model("final_model")

        logger.info("Training completed!")
        logger.info(f"Best validation loss: {self.best_val_loss:.4f}")
        wandb.finish()
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: build the default configuration and run training."""
    trainer = Trainer(TrainingConfig())
    trainer.train()
|
|
|
|
|
|
|
|
|
# Run training when executed as a script (no effect on import).
if __name__ == "__main__":
    main()
|
|
|
|