| | import math |
| | import time |
| | from typing import Any, Optional, Dict, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| | from logger.logger import TrainerLogger |
| | from torch.utils.data import DataLoader |
| | from transformers import PreTrainedModel |
| |
|
| | |
# Module-wide compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| |
|
class BaseTrainer:
    """Training loop for a causal language model with experiment logging.

    Iterates over ``train_loader``, optimising the model with optional CUDA
    mixed precision, gradient-norm clipping and a per-batch scheduler step.
    Metrics, generated samples and checkpoints are reported through
    ``TrainerLogger`` at fixed batch intervals.

    Class attributes (override in subclasses to tune the cadence):
        LOG_EVERY: Batches between averaged metric reports.
        SAMPLE_EVERY: Batches between sample-generation reports.
        CHECKPOINT_EVERY: Batches between checkpoints.
        MAX_GRAD_NORM: Gradient-norm clipping threshold.
    """

    LOG_EVERY = 100
    SAMPLE_EVERY = 500
    CHECKPOINT_EVERY = 1000
    MAX_GRAD_NORM = 1.0

    def __init__(
        self,
        model: PreTrainedModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        tokenizer: Any,
        train_loader: DataLoader,
        test_loader: Optional[DataLoader] = None,
        logger_config: Optional[Dict[str, Any]] = None,
        use_amp: bool = True,
        checkpoint_dir: str = "../",
    ):
        """Move the model to the module-level ``device`` and open a logger run.

        Args:
            model: Model to train; moved to ``device``.
            optimizer: Optimizer already bound to the model parameters.
            scheduler: LR scheduler, stepped once per batch.
            tokenizer: Tokenizer used for input validation and sampling.
            train_loader: Loader whose batches expose ``batch['input_ids']``.
            test_loader: Optional evaluation loader (stored, not used here).
            logger_config: Mapping with the keys ``tracking_uri``,
                ``experiment`` and ``model_name``. Required despite the
                ``None`` default kept for signature compatibility.
            use_amp: Enable gradient scaling and fp16 autocast.
            checkpoint_dir: Directory passed to ``save_pretrained`` at each
                checkpoint.

        Raises:
            TypeError: If ``logger_config`` is not provided.
        """
        # Fail fast with a readable message instead of the opaque
        # "'NoneType' object is not subscriptable" raised further down.
        if logger_config is None:
            raise TypeError(
                "logger_config is required: expected a dict with the keys "
                "'tracking_uri', 'experiment' and 'model_name'"
            )

        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.use_amp = use_amp
        self.checkpoint_dir = checkpoint_dir
        self.scaler = torch.amp.GradScaler('cuda') if use_amp else None
        self.train_step = 0
        self._best_perplexity = float('inf')
        self._epochs_no_improve = 0

        total_params = sum(p.numel() for p in model.parameters())
        self.logger = TrainerLogger(
            tracking_uri=logger_config["tracking_uri"],
            experiment=logger_config["experiment"],
            run_name=logger_config["model_name"],
            model_name=logger_config["model_name"],
            total_params=total_params,
            tags={"version": "1.0", "environment": "development"},
        )

    def _autocast(self):
        """Return the fp16 autocast context used for forward passes.

        Autocast is only enabled when AMP was requested and CUDA is
        available, so an fp32 run (``use_amp=False``) never executes fp16
        forward passes without a gradient scaler.
        """
        enabled = bool(self.use_amp) and torch.cuda.is_available()
        return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=enabled)

    @staticmethod
    def _perplexity(mean_loss: float) -> float:
        """Return ``exp(mean_loss)``, or ``inf`` when the result overflows."""
        try:
            return math.exp(mean_loss)
        except OverflowError:
            # A diverged loss should show up in the metrics, not abort the run.
            return float("inf")

    @staticmethod
    def _recent_mean(losses: List[float], window: int) -> float:
        """Return the mean of the last ``window`` entries of ``losses``."""
        recent = losses[-window:]
        return sum(recent) / len(recent)

    def _generate_sample(self, sample_prompts: Optional[List[str]] = None) -> str:
        """Generate one completion per prompt and render them as HTML.

        Generation is best-effort: a failure for one prompt is reported in
        that prompt's section instead of interrupting training. The model is
        put in eval mode for generation and returned to the mode it was in.

        Args:
            sample_prompts: Prompts to complete; ``None`` means no prompts.

        Returns:
            Concatenated HTML fragments, one per prompt (empty string when
            there are no prompts).
        """
        import html  # function-scope import keeps this block self-contained

        prompts = sample_prompts or []
        was_training = self.model.training
        self.model.eval()
        sections = []
        try:
            for prompt in prompts:
                try:
                    inputs = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = inputs.input_ids.to(self.model.device)

                    with torch.no_grad(), self._autocast():
                        generated_ids = self.model.generate(
                            input_ids=input_ids,
                            max_length=100,
                            num_beams=5,
                            do_sample=True,
                            top_k=50,
                            top_p=0.95,
                            temperature=0.7,
                            repetition_penalty=1.2,
                            use_cache=True,
                            eos_token_id=self.tokenizer.eos_token_id,
                            pad_token_id=self.tokenizer.pad_token_id,
                        )

                    generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                except Exception as e:
                    generated_text = f"Erro: {e}"
                # Escape both values: prompts and model output are free text
                # and may contain '<' or '&', which would corrupt the report.
                sections.append(
                    f"<h4><b>prompt:</b> {html.escape(str(prompt))}</h4>"
                    f"<p><b>Resposta:</b> {html.escape(str(generated_text))}</p>"
                )
        finally:
            # Restore the previous mode rather than forcing train mode.
            self.model.train(was_training)
        return "".join(sections)

    def _calc_loss_batch(self, inputs: torch.Tensor) -> torch.Tensor:
        """Compute the model's language-modelling loss for a batch of token ids.

        The batch is used both as ``input_ids`` and as ``labels``; the
        key/value cache is disabled for the forward pass.

        Args:
            inputs: Integer tensor of token ids.

        Returns:
            The loss tensor reported by the model.

        Raises:
            ValueError: If any id falls outside ``[0, vocab_size)`` and is
                not the ignore index (-100).
            RuntimeError: If the logits contain NaN or infinite values.
        """
        ignore_idx = -100
        # NOTE(review): -100 is accepted here although the same tensor is fed
        # as input_ids, and vocab_size may not count added tokens - confirm.
        in_range = ((inputs >= 0) | (inputs == ignore_idx)) & (inputs < self.tokenizer.vocab_size)
        if not in_range.all():
            # Explicit raise: an assert would be stripped under ``python -O``.
            raise ValueError(
                f"Há labels inválidos: min={inputs.min().item()}, max={inputs.max().item()}"
            )

        inputs = inputs.to(self.model.device)
        with self._autocast():
            outputs = self.model(
                input_ids=inputs,
                labels=inputs,
                use_cache=False,
                return_dict=True,
            )
        # A single isfinite pass covers both the NaN and the +/-inf checks.
        if not torch.isfinite(outputs.logits).all():
            raise RuntimeError("Logits inválidos detectados")
        return outputs.loss

    def _optimizer_step(self, loss: torch.Tensor) -> None:
        """Backpropagate ``loss``, clip gradients and step the optimizer."""
        if self.use_amp:
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.MAX_GRAD_NORM)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.MAX_GRAD_NORM)
            self.optimizer.step()

    def _train_epoch(self, epoch: int, sample_prompts: Optional[List[str]] = None) -> List[float]:
        """Run one pass over ``train_loader``.

        Every ``LOG_EVERY`` batches the averaged loss, perplexity and LR are
        logged; every ``SAMPLE_EVERY`` batches samples are generated and
        logged as HTML; every ``CHECKPOINT_EVERY`` batches the model is
        checkpointed through the logger and saved to ``checkpoint_dir``.

        Args:
            epoch: Zero-based epoch index (shown one-based in the progress bar).
            sample_prompts: Prompts used for periodic sample generation.

        Returns:
            The per-batch loss values of this epoch, in order.
        """
        if sample_prompts is None:
            sample_prompts = []

        self.model.train()
        losses: List[float] = []
        pbar = tqdm(
            self.train_loader,
            total=len(self.train_loader),
            desc=f"Epoch {epoch + 1}",
            unit="batch",
            leave=False,
        )

        for i, batch in enumerate(pbar):
            start_time = time.time()
            self.optimizer.zero_grad()
            loss = self._calc_loss_batch(batch['input_ids'])
            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)

            self._optimizer_step(loss)
            self.scheduler.step()

            perplexity = self._perplexity(loss_value)
            current_lr = self.optimizer.param_groups[0].get('lr', 0.0)
            elapsed_time = time.time() - start_time

            pbar.set_postfix({
                "loss": f"{loss_value:.4f}",
                "perplexity": f"{perplexity:.4f}",
                "lr": f"{current_lr:.2e}",
                "elapsed_time": f"{elapsed_time:.2f}s",
            })

            batches_done = i + 1

            if batches_done % self.LOG_EVERY == 0:
                self.train_step += 1
                avg_loss = self._recent_mean(losses, self.LOG_EVERY)
                self.logger.log_metrics(
                    {
                        "train_loss": avg_loss,
                        "train_perplexity": self._perplexity(avg_loss),
                        "lr": current_lr,
                    },
                    step=self.train_step,
                )

            if batches_done % self.SAMPLE_EVERY == 0:
                samples_html = self._generate_sample(sample_prompts)
                self.logger.log_html(
                    f"<html><head><meta charset='utf-8'></head><body>{samples_html}</body></html>",
                    step=self.train_step,
                )

            if batches_done % self.CHECKPOINT_EVERY == 0:
                avg_loss = self._recent_mean(losses, self.CHECKPOINT_EVERY)
                self.logger.log_checkpoint_table(
                    current_lr, avg_loss, self._perplexity(avg_loss), batches_done
                )
                self.logger.checkpoint_model(self.model)
                self.model.save_pretrained(self.checkpoint_dir)

        return losses

    def train(self, num_epochs: int = 500, sample_prompts: Optional[List[str]] = None):
        """Train for ``num_epochs`` epochs, logging the mean loss of each.

        The logger run is always closed, including when an epoch raises or
        the run is interrupted.

        Args:
            num_epochs: Number of passes over ``train_loader``.
            sample_prompts: Prompts forwarded to periodic sample generation.

        Raises:
            ValueError: If an epoch produced no batches.
        """
        try:
            for epoch in range(num_epochs):
                train_losses = self._train_epoch(epoch, sample_prompts)
                if not train_losses:
                    raise ValueError(
                        "train_loader yielded no batches; cannot compute the epoch mean loss"
                    )
                mean_train_loss = sum(train_losses) / len(train_losses)
                self.logger.log_metrics(
                    {"mean_train_loss": mean_train_loss},
                    step=epoch,
                )
                print(f"Epoch {epoch + 1} | Train Loss: {mean_train_loss:.4f}")
        finally:
            # Close the tracking run even when training fails part-way.
            self.logger.finish()
        print("Treinamento concluído!")
|
| |
|
| | |
class TuningTrainer(BaseTrainer):
    """Subclass of BaseTrainer that currently adds no behaviour of its own."""
| |
|
| | |
class PreTrainer(BaseTrainer):
    """Subclass of BaseTrainer that currently adds no behaviour of its own."""
| |
|