import functools
import json
import logging
import math
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import wandb
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    DebertaV2Model,
    DebertaV2TokenizerFast,
    DebertaV2Config,
    get_linear_schedule_with_warmup,
    set_seed
)
|
| |
|
| |
|
# Configure the root logger once at import time so that every message emitted
# through the module-level `logger` carries a timestamp, level and logger name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO
)
logger = logging.getLogger(__name__)
|
| |
|
@dataclass
class TrainingConfig:
    """Training configuration for link token classification.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration. Field order is part of the positional constructor
    signature and must not be changed.

    Raises:
        ValueError: if a batch size or a step interval used as a modulus is
            smaller than 1.
    """

    # --- Model ---
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2

    # --- Data ---
    train_file: str = "train_windows.jsonl"
    val_file: str = "val_windows.jsonl"
    max_length: int = 512

    # --- Optimisation ---
    batch_size: int = 8
    gradient_accumulation_steps: int = 8
    num_epochs: int = 3
    learning_rate: float = 1e-6
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    # NOTE(review): label_smoothing is not read by any code visible in this file.
    label_smoothing: float = 0.0

    # --- Runtime ---
    # The default is resolved once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0
    seed: int = 42
    bf16: bool = True

    # --- Logging / evaluation cadence (compared against the global step) ---
    logging_steps: int = 1
    eval_steps: int = 5000
    # NOTE(review): save_steps is not read by any code visible in this file.
    save_steps: int = 10000
    output_dir: str = "./deberta_link_output"

    # --- Weights & Biases ---
    wandb_project: str = "deberta-link-classification"
    wandb_name: str = "deberta-v3-large-link-tokens"

    # --- Early stopping on validation loss ---
    patience: int = 2
    min_delta: float = 0.0001

    # --- Checkpoint retention ---
    max_checkpoints: int = 5
    protect_latest_epoch_step: bool = True

    def __post_init__(self):
        # These values are used as batch sizes or as the right-hand side of a
        # modulo; fail early with a clear message instead of a later
        # ZeroDivisionError / DataLoader error.
        for name in ("batch_size", "gradient_accumulation_steps", "logging_steps", "eval_steps"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
|
| |
|
| |
|
class LinkTokenDataset(Dataset):
    """In-memory dataset for link token classification.

    Each non-blank line of the JSONL file is parsed into a dict whose
    ``input_ids``, ``attention_mask`` and ``labels`` entries are converted to
    ``torch.long`` tensors; any other keys are kept as parsed.
    """

    def __init__(self, file_path: str, max_samples: Optional[int] = None):
        """Load samples from ``file_path`` and log basic statistics.

        Args:
            file_path: Path to a JSONL file, one sample per line.
            max_samples: Stop after this many samples have been loaded.
                ``None`` (or ``0``) means no limit.
        """
        self.data = []
        seq_lengths = []

        logger.info(f"Loading data from {file_path}")

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if max_samples and len(self.data) >= max_samples:
                    break
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue

                sample = json.loads(line)
                seq_lengths.append(len(sample['input_ids']))

                for key in ('input_ids', 'attention_mask', 'labels'):
                    sample[key] = torch.tensor(sample[key], dtype=torch.long)

                self.data.append(sample)

        logger.info(f"Loaded {len(self.data)} samples")

        if not self.data:
            # min()/max() over an empty list would raise; nothing to summarise.
            logger.warning(f"No samples found in {file_path}")
            return

        logger.info(f"Sequence lengths - Min: {min(seq_lengths)}, Max: {max(seq_lengths)}, Avg: {np.mean(seq_lengths):.1f}")

        # Label statistics over positions that are not masked out with -100.
        valid_labels = torch.cat([s['labels'][s['labels'] != -100] for s in self.data])
        num_link_tokens = (valid_labels == 1).sum().item()
        num_non_link = (valid_labels == 0).sum().item()

        logger.info(f"Label distribution - Non-link: {num_non_link}, Link: {num_link_tokens}")
        if (num_link_tokens + num_non_link) > 0:
            logger.info(f"Link token ratio: {num_link_tokens / (num_link_tokens + num_non_link):.4%}")
        else:
            logger.info("No valid labels found in the dataset.")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.data[idx]
|
| |
|
| |
|
def collate_fn(batch: List[Dict], max_seq_length: int, pad_token_id: int = 0) -> Dict[str, torch.Tensor]:
    """
    Custom collate function for batching with padding to a fixed max_seq_length.

    Sequences longer than ``max_seq_length`` are truncated; shorter ones are
    right-padded. The samples in ``batch`` are never modified, so the dataset
    that owns them keeps its original tensors.

    Args:
        batch (List[Dict]): A list of samples, each holding 1-D tensors under
            'input_ids', 'attention_mask' and 'labels'.
        max_seq_length (int): The length every sequence is padded/truncated to.
        pad_token_id (int): Fill value for padded 'input_ids' positions.

    Returns:
        Dict[str, torch.Tensor]: Stacked 'input_ids', 'attention_mask' and
        'labels' tensors. Padding uses ``pad_token_id`` for input ids, 0 for
        the attention mask and -100 for labels.
    """
    # (sample key, fill value used for padding)
    fields = (
        ('input_ids', pad_token_id),
        ('attention_mask', 0),
        ('labels', -100),
    )
    columns = {key: [] for key, _ in fields}

    for sample in batch:
        for key, fill_value in fields:
            # Slicing yields a new tensor object; the sample dict is untouched.
            seq = sample[key][:max_seq_length]
            padding = torch.full((max_seq_length - len(seq),), fill_value, dtype=torch.long)
            columns[key].append(torch.cat([seq, padding]))

    return {key: torch.stack(rows) for key, rows in columns.items()}
|
| |
|
| |
|
class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa encoder with a dropout + linear head for token classification."""

    def __init__(
        self,
        model_name: str,
        num_labels: int,
        dropout_rate: float = 0.1,
        class_weights: Optional[List[float]] = None
    ):
        """Build the encoder and the classification head.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
            num_labels: Number of output classes of the linear head.
            dropout_rate: Dropout applied to the encoder output.
            class_weights: Per-class weights for the cross-entropy loss. When
                omitted, a two-class model uses ``[1.0, 25.0]`` and any other
                model is unweighted.

        Raises:
            ValueError: if ``class_weights`` does not have ``num_labels`` entries.
        """
        super().__init__()

        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name)
        # Keep the head size explicitly: the pretrained config carries its own
        # num_labels, which need not match the head built here.
        self.num_labels = num_labels

        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

        if class_weights is None and num_labels == 2:
            class_weights = [1.0, 25.0]
        if class_weights is not None and len(class_weights) != num_labels:
            raise ValueError(
                f"class_weights has {len(class_weights)} entries, expected {num_labels}"
            )
        weight_tensor = None
        if class_weights is not None:
            weight_tensor = torch.tensor(class_weights, dtype=torch.float)
        # Non-persistent: follows the module across devices but is not written
        # to the state dict, so checkpoint keys are unchanged.
        self.register_buffer('class_weights', weight_tensor, persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Run the encoder and head.

        Returns:
            A dict with 'logits' (one ``num_labels``-sized vector per token) and
            'loss' (weighted cross-entropy ignoring label -100, or ``None``
            when ``labels`` is not given).
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            weight = self.class_weights
            if weight is not None:
                weight = weight.to(logits.device)
            loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
            loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))

        return {
            'loss': loss,
            'logits': logits
        }
|
| |
|
| |
|
def compute_metrics(predictions: np.ndarray, labels: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
    """Compute metrics for token classification.

    Only positions whose label is not -100 and whose mask value is 1 are
    scored. Class 1 ("link") is the positive class.

    Args:
        predictions: Predicted class ids.
        labels: Gold class ids, with -100 marking ignored positions.
        mask: Attention mask with the same number of elements as ``labels``.

    Returns:
        Dict with 'accuracy'; 'precision', 'recall', 'f1' for the link class
        (also exposed as 'precision_link', 'recall_link', 'f1_link');
        'f1_non_link' for class 0; and 'num_valid_tokens'. All scores are 0.0
        when no position is scored.
    """
    predictions_flat = np.asarray(predictions).reshape(-1)
    labels_flat = np.asarray(labels).reshape(-1)
    mask_flat = np.asarray(mask).reshape(-1)

    valid_indices = (labels_flat != -100) & (mask_flat == 1)
    preds_filtered = predictions_flat[valid_indices]
    labels_filtered = labels_flat[valid_indices]
    num_valid_tokens = int(labels_filtered.size)

    if num_valid_tokens == 0:
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'f1_non_link': 0.0,
            'f1_link': 0.0,
            'precision_link': 0.0,
            'recall_link': 0.0,
            'num_valid_tokens': 0
        }

    accuracy = float(accuracy_score(labels_filtered, preds_filtered))

    # One per-class call covers both classes. With zero_division=0 a class
    # that never occurs in the gold labels scores 0 on every metric.
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        labels_filtered, preds_filtered, labels=[0, 1], average=None, zero_division=0
    )

    precision_link = float(precision_per_class[1])
    recall_link = float(recall_per_class[1])
    f1_link = float(f1_per_class[1])

    return {
        'accuracy': accuracy,
        'precision': precision_link,
        'recall': recall_link,
        'f1': f1_link,
        'f1_non_link': float(f1_per_class[0]),
        'f1_link': f1_link,
        'precision_link': precision_link,
        'recall_link': recall_link,
        'num_valid_tokens': num_valid_tokens
    }
|
| |
|
| |
|
class Trainer:
    """Trainer class for DeBERTa token classification.

    Owns the datasets, data loaders, model, optimizer and LR scheduler, runs
    the epoch loop with gradient accumulation, evaluates on the validation
    set, applies early stopping on validation loss and prunes old checkpoints.
    """

    def __init__(self, config: "TrainingConfig"):
        """Set up logging to wandb, data, model, optimizer and scheduler."""
        self.config = config
        set_seed(config.seed)

        wandb.init(
            project=config.wandb_project,
            name=config.wandb_name,
            config=vars(config)
        )

        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

        self.train_dataset = LinkTokenDataset(config.train_file)
        self.val_dataset = LinkTokenDataset(config.val_file)

        # Derive the device type from the config instead of assuming CUDA.
        # bf16 autocast is only enabled on CUDA, which matches what a
        # hard-coded 'cuda' autocast effectively does on a CPU-only machine.
        self._device_type = torch.device(config.device).type
        on_cuda = self._device_type == 'cuda'
        self._use_amp = bool(config.bf16) and on_cuda

        collate = functools.partial(collate_fn, max_seq_length=config.max_length)

        # NOTE(review): training batches are drawn in file order
        # (shuffle=False) — confirm that the input file is already shuffled.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=collate,
            pin_memory=on_cuda
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=collate,
            pin_memory=on_cuda
        )

        self.model = DeBERTaForTokenClassification(
            config.model_name,
            config.num_labels
        ).to(config.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        decay_params = []
        no_decay_params = []
        for param_name, param in self.model.named_parameters():
            if any(nd in param_name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        self.optimizer = torch.optim.AdamW(
            [
                {'params': decay_params, 'weight_decay': config.weight_decay},
                {'params': no_decay_params, 'weight_decay': 0.0},
            ],
            lr=config.learning_rate,
            eps=1e-6
        )

        # One optimizer step per full accumulation group, plus one for the
        # leftover micro-batches at the end of each epoch.
        steps_per_epoch = math.ceil(len(self.train_loader) / config.gradient_accumulation_steps)
        total_steps = steps_per_epoch * config.num_epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        self.global_step = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def _autocast(self):
        """Return the autocast context used for forward passes."""
        return torch.amp.autocast(
            device_type=self._device_type,
            dtype=torch.bfloat16,
            enabled=self._use_amp
        )

    def _record_validation_loss(self, val_loss: float, checkpoint_name: str, context: str = "") -> bool:
        """Update early-stopping state for one validation result.

        Saves a checkpoint named ``checkpoint_name`` and resets the patience
        counter when ``val_loss`` beats the best loss by more than
        ``min_delta``; otherwise increments the patience counter.

        Returns:
            True if the loss improved, False otherwise.
        """
        if val_loss < self.best_val_loss - self.config.min_delta:
            self.best_val_loss = val_loss
            self.patience_counter = 0
            self.save_model(checkpoint_name)
            logger.info(f"New best validation loss{context}: {self.best_val_loss:.4f}")
            return True

        self.patience_counter += 1
        logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")
        return False

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch.

        Returns:
            The mean unscaled loss over the micro-batches that were actually
            back-propagated (0.0 if there were none). The epoch ends early
            when mid-epoch evaluation exhausts the patience budget.
        """
        self.model.train()
        accum_steps = self.config.gradient_accumulation_steps
        num_batches = len(self.train_loader)
        loss_sum = 0.0
        batches_used = 0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            with self._autocast():
                outputs = self.model(**batch)
            raw_loss = outputs['loss']

            if torch.isnan(raw_loss) or torch.isinf(raw_loss):
                logger.warning(f"NaN or Inf loss encountered at step {self.global_step}. Skipping backward pass.")
                self.optimizer.zero_grad()
                continue

            # Scale so that the accumulated gradient averages over the group.
            (raw_loss / accum_steps).backward()
            current_loss = raw_loss.item()
            loss_sum += current_loss
            batches_used += 1

            # Step on every full group and on the final batch, so leftover
            # gradients never carry over into the next epoch.
            is_last_batch = (step + 1) == num_batches
            if (step + 1) % accum_steps != 0 and not is_last_batch:
                continue

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.global_step += 1

            if self.global_step % self.config.logging_steps == 0:
                wandb.log({
                    'train/loss': current_loss,
                    'train/learning_rate': self.scheduler.get_last_lr()[0],
                    'train/global_step': self.global_step,
                    'train/epoch': epoch
                }, step=self.global_step)
                progress_bar.set_postfix({'loss': f'{current_loss:.4f}'})

            if self.global_step % self.config.eval_steps == 0:
                eval_metrics = self.evaluate()
                logger.info(f"Step {self.global_step} - Eval metrics: {eval_metrics}")

                improved = self._record_validation_loss(
                    eval_metrics['loss'], f"best_model_step_{self.global_step}"
                )
                if not improved and self.patience_counter >= self.config.patience:
                    logger.info("Early stopping triggered mid-epoch!")
                    break

        return loss_sum / batches_used if batches_used > 0 else 0.0

    def evaluate(self) -> Dict[str, float]:
        """Evaluate on validation set.

        Returns:
            The dict produced by ``compute_metrics`` plus 'loss', the mean of
            the per-batch losses. The model is put back in train mode before
            returning.
        """
        self.model.eval()

        all_predictions = []
        all_labels = []
        all_masks = []
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.config.device) for k, v in batch.items()}

                with self._autocast():
                    outputs = self.model(**batch)

                if outputs['loss'] is not None:
                    total_loss += outputs['loss'].item()
                    num_batches += 1

                predictions = torch.argmax(outputs['logits'], dim=-1)

                all_predictions.append(predictions.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
                all_masks.append(batch['attention_mask'].cpu().numpy())

        if all_predictions:
            metrics = compute_metrics(
                np.concatenate(all_predictions, axis=0),
                np.concatenate(all_labels, axis=0),
                np.concatenate(all_masks, axis=0)
            )
        else:
            # np.concatenate rejects an empty list; score an empty set instead.
            nothing = np.empty((0,), dtype=np.int64)
            metrics = compute_metrics(nothing, nothing, nothing)
        metrics['loss'] = total_loss / num_batches if num_batches > 0 else 0.0

        wandb.log({f'eval/{k}': v for k, v in metrics.items()}, step=self.global_step)

        self.model.train()
        return metrics

    def _enforce_checkpoint_limit(self):
        """
        Enforce checkpoint retention:
        - Count all subdirectories in output_dir except 'final_model'
        - Keep at most config.max_checkpoints
        - Delete oldest by modification time
        - Always protect:
            * 'final_model'
            * latest 'best_model_epoch_*'
            * latest 'best_model_step_*'
          (the latter two only when config.protect_latest_epoch_step is set)

        Protected directories are never deleted, so the limit can be exceeded
        when too few unprotected directories exist.
        """
        output_dir = Path(self.config.output_dir)
        if not output_dir.exists():
            return

        subdirs = [p for p in output_dir.iterdir() if p.is_dir()]
        if not subdirs:
            return

        def mtime(directory):
            return directory.stat().st_mtime

        final_dir = (output_dir / "final_model").resolve()
        protected = {final_dir}

        if self.config.protect_latest_epoch_step:
            for pattern in (r"best_model_epoch_\d+$", r"best_model_step_\d+$"):
                matching = [d for d in subdirs if re.match(pattern, d.name)]
                if matching:
                    protected.add(max(matching, key=mtime).resolve())

        counted = [d for d in subdirs if d.resolve() != final_dir]
        excess = len(counted) - self.config.max_checkpoints
        if excess <= 0:
            return

        # Oldest unprotected directories go first.
        deletable = sorted(
            (d for d in counted if d.resolve() not in protected),
            key=mtime
        )
        for d in deletable[:excess]:
            try:
                shutil.rmtree(d)
            except OSError as e:
                logger.warning(f"Failed to delete {d}: {e}")
            else:
                logger.info(f"Deleted old checkpoint: {d}")

    def save_model(self, name: str):
        """Save model weights and the training config under output_dir/name.

        Afterwards the checkpoint retention policy is applied.
        """
        save_path = Path(self.config.output_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), save_path / 'pytorch_model.bin')

        with open(save_path / 'training_config.json', 'w') as f:
            json.dump(vars(self.config), f, indent=4)

        logger.info(f"Model saved to {save_path}")

        self._enforce_checkpoint_limit()

    def train(self):
        """Main training loop.

        Runs up to ``num_epochs`` epochs, evaluating after each one, stops
        early once the patience budget is exhausted, then saves the final
        weights as 'final_model' and closes the wandb run.
        """
        logger.info("Starting training...")
        logger.info(f"Training samples: {len(self.train_dataset)}")
        logger.info(f"Validation samples: {len(self.val_dataset)}")

        steps_per_epoch = math.ceil(len(self.train_loader) / self.config.gradient_accumulation_steps)
        total_optimization_steps = steps_per_epoch * self.config.num_epochs
        logger.info(f"Total optimization steps: {total_optimization_steps}")
        logger.info(f"Early stopping: monitoring validation loss with patience={self.config.patience}")

        for epoch in range(self.config.num_epochs):
            logger.info(f"\n{'='*50}")
            logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            avg_train_loss = self.train_epoch(epoch + 1)
            logger.info(f"Average training loss: {avg_train_loss:.4f}")

            # train_epoch returns early when mid-epoch evaluation ran out of
            # patience; do not evaluate again in that case.
            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping during epoch.")
                break

            eval_metrics = self.evaluate()
            logger.info(f"Epoch {epoch + 1} - Eval metrics:")
            for key, value in eval_metrics.items():
                logger.info(f" {key}: {value:.4f}")

            self._record_validation_loss(
                eval_metrics['loss'], f"best_model_epoch_{epoch + 1}", " at epoch end"
            )

            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping")
                break

        self.save_model("final_model")

        logger.info("Training completed!")
        logger.info(f"Best validation loss: {self.best_val_loss:.4f}")
        wandb.finish()
|
| |
|
| |
|
def main():
    """Build the default training configuration and run a full training session."""
    config = TrainingConfig()

    trainer = Trainer(config)
    trainer.train()
|
| |
|
| |
|
# Run training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
| |
|