"""Training orchestration for Précis.""" import logging from pathlib import Path from typing import Optional from transformers import ( Trainer, TrainingArguments, PreTrainedModel, PreTrainedTokenizer, DataCollatorForLanguageModeling, ) from torch.utils.data import Dataset from src.config import TrainingConfig logger = logging.getLogger(__name__) class PrecisTrainer: """Wrapper around HuggingFace Trainer for summarization fine-tuning.""" def __init__( self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, train_dataset: Dataset, eval_dataset: Optional[Dataset] = None, config: Optional[TrainingConfig] = None, ): self.model = model self.tokenizer = tokenizer self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.config = config or TrainingConfig() self.training_args = self._create_training_args() self.trainer = self._create_trainer() def _create_training_args(self) -> TrainingArguments: """Create HuggingFace TrainingArguments from config.""" return TrainingArguments( output_dir=self.config.output_dir, num_train_epochs=self.config.num_epochs, per_device_train_batch_size=self.config.batch_size, gradient_accumulation_steps=self.config.gradient_accumulation_steps, learning_rate=self.config.learning_rate, warmup_ratio=self.config.warmup_ratio, weight_decay=self.config.weight_decay, max_grad_norm=self.config.max_grad_norm, optim=self.config.optim, logging_steps=self.config.logging_steps, save_steps=self.config.save_steps, eval_steps=self.config.eval_steps if self.eval_dataset else None, evaluation_strategy="steps" if self.eval_dataset else "no", save_total_limit=3, load_best_model_at_end=bool(self.eval_dataset), seed=self.config.seed, fp16=True, report_to="none", ) def _create_trainer(self) -> Trainer: """Create HuggingFace Trainer instance.""" data_collator = DataCollatorForLanguageModeling( tokenizer=self.tokenizer, mlm=False, ) return Trainer( model=self.model, args=self.training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, data_collator=data_collator, ) def train(self) -> None: """Execute training loop.""" logger.info("Starting training...") self.trainer.train() logger.info("Training complete.") def evaluate(self) -> dict: """Run evaluation and return metrics.""" if self.eval_dataset is None: logger.warning("No eval dataset provided") return {} logger.info("Running evaluation...") return self.trainer.evaluate() def save(self, output_path: Optional[str] = None) -> None: """Save model checkpoint.""" path = output_path or self.config.output_dir Path(path).mkdir(parents=True, exist_ok=True) self.trainer.save_model(path) self.tokenizer.save_pretrained(path) logger.info(f"Model saved to {path}")