import math
import time
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedModel

from logger.logger import TrainerLogger

# Device selection: use CUDA when it is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BaseTrainer:
    """Causal-LM training loop with optional mixed precision and periodic logging.

    The trainer moves the model to the module-level ``device``, runs the
    optimisation loop over ``train_loader`` and reports metrics, generated
    samples and checkpoints through a ``TrainerLogger``.

    Subclasses can tune the cadence of side effects by overriding the class
    constants below.
    """

    # Number of batches between metric logs, sample generations and checkpoints.
    LOG_EVERY = 100
    SAMPLE_EVERY = 500
    CHECKPOINT_EVERY = 1000
    # Maximum global gradient norm used for clipping.
    MAX_GRAD_NORM = 1.0
    # Id accepted by the batch validation in addition to real vocabulary ids.
    IGNORE_INDEX = -100

    def __init__(
        self,
        model: PreTrainedModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        tokenizer: Any,
        train_loader: DataLoader,
        test_loader: Optional[DataLoader] = None,
        logger_config: Optional[Dict[str, Any]] = None,
        use_amp: bool = True,
        checkpoint_dir: str = "../",
    ):
        """Store the training components and open the experiment logger.

        Args:
            model: Model to train; it is moved to the module-level ``device``.
            optimizer: Optimizer stepped once per batch.
            scheduler: Learning-rate scheduler stepped once per batch.
            tokenizer: Tokenizer used for sample generation and for the
                vocabulary-size check on incoming batches.
            train_loader: Loader whose batches expose an ``'input_ids'`` entry.
            test_loader: Stored on the instance; not used by this class.
            logger_config: Mapping with the keys ``"tracking_uri"``,
                ``"experiment"`` and ``"model_name"``. Required despite the
                ``None`` default, which is kept for signature compatibility.
            use_amp: Enable fp16 autocast and gradient scaling. Only takes
                effect when CUDA is available.
            checkpoint_dir: Directory passed to ``model.save_pretrained`` at
                every checkpoint. Defaults to the previously hard-coded
                ``"../"``.

        Raises:
            TypeError: If ``logger_config`` is not provided.
            KeyError: If ``logger_config`` lacks one of the required keys.
        """
        if logger_config is None:
            # Fail early with a readable message instead of the opaque
            # "'NoneType' object is not subscriptable" raised further down.
            raise TypeError(
                "logger_config is required and must provide the keys "
                "'tracking_uri', 'experiment' and 'model_name'"
            )

        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.use_amp = use_amp
        self.checkpoint_dir = checkpoint_dir
        # The scaler is created disabled on CPU-only hosts so that it becomes a
        # transparent pass-through instead of emitting a warning.
        self.scaler = (
            torch.amp.GradScaler("cuda", enabled=device.type == "cuda")
            if use_amp
            else None
        )
        self.train_step = 0
        self._best_perplexity = float("inf")
        self._epochs_no_improve = 0

        total_params = sum(p.numel() for p in model.parameters())
        self.logger = TrainerLogger(
            tracking_uri=logger_config["tracking_uri"],
            experiment=logger_config["experiment"],
            run_name=logger_config["model_name"],
            model_name=logger_config["model_name"],
            total_params=total_params,
            tags={"version": "1.0", "environment": "development"},
        )

    def _autocast(self) -> torch.autocast:
        """Return the fp16 autocast context, active only with AMP on CUDA."""
        return torch.autocast(
            device_type="cuda",
            dtype=torch.float16,
            enabled=bool(self.use_amp) and device.type == "cuda",
        )

    @staticmethod
    def _perplexity(loss_value: float) -> float:
        """Return ``exp(loss_value)``, or ``inf`` when the result overflows."""
        try:
            return math.exp(loss_value)
        except OverflowError:
            # A diverged loss must not crash progress reporting.
            return float("inf")

    def _generate_sample(self, sample_prompts: Optional[List[str]] = None) -> str:
        """Generate one continuation per prompt and return them as one string.

        The model is put in eval mode for generation and switched back to
        train mode before returning. A failure on a single prompt is reported
        inline in the returned text instead of being raised, so that sampling
        never interrupts training.

        Args:
            sample_prompts: Prompts to complete; ``None`` or an empty list
                yields an empty string.

        Returns:
            The concatenated prompt/response fragments.
        """
        fragments: List[str] = []
        self.model.eval()
        try:
            for prompt in sample_prompts or []:
                try:
                    inputs = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = inputs.input_ids.to(self.model.device)

                    with torch.no_grad(), self._autocast():
                        generated_ids = self.model.generate(
                            input_ids=input_ids,
                            max_length=100,  # total length (prompt + continuation)
                            num_beams=5,  # number of beam-search hypotheses
                            do_sample=True,  # sample instead of pure greedy decoding
                            top_k=50,  # restrict sampling to the 50 most likely tokens
                            top_p=0.95,  # nucleus sampling (cumulative p <= 0.95)
                            temperature=0.7,  # controls how "creative" the output is
                            repetition_penalty=1.2,  # penalise exact repetitions
                            use_cache=True,  # reuse past_key_values
                            eos_token_id=self.tokenizer.eos_token_id,
                            pad_token_id=self.tokenizer.pad_token_id,
                        )

                    generated_text = self.tokenizer.decode(
                        generated_ids[0], skip_special_tokens=True
                    )
                except Exception as e:
                    # Deliberate best effort: surface the error in the sample.
                    generated_text = f"Erro: {e}"

                # NOTE(review): this text is handed to ``log_html`` by the
                # caller but carries no markup and is not HTML-escaped --
                # confirm whether tags were lost from this template.
                fragments.append(
                    f"\n\nprompt: {prompt}\n\nResposta: {generated_text}\n\n"
                )
        finally:
            # Always hand the model back in training mode.
            self.model.train()
        return "".join(fragments)

    def _calc_loss_batch(self, inputs: torch.Tensor) -> torch.Tensor:
        """Compute the language-modelling loss for one batch of token ids.

        The same tensor is passed to the model as both ``input_ids`` and
        ``labels``, and the key/value cache is disabled for the forward pass.

        Args:
            inputs: Token-id tensor for the batch.

        Returns:
            The loss reported by the model (``outputs.loss``).

        Raises:
            ValueError: If any id is outside ``[0, tokenizer.vocab_size)`` or
                is not below ``tokenizer.vocab_size``.
            RuntimeError: If the logits contain NaN or infinite values.
        """
        ignore_idx = self.IGNORE_INDEX

        # Every id must be non-negative (or the ignore index) and below the
        # vocabulary size.
        # NOTE(review): the ignore index is accepted here although the tensor
        # is also used as ``input_ids`` -- confirm batches never contain it.
        in_range = (inputs >= 0) | (inputs == ignore_idx)
        valid = in_range & (inputs < self.tokenizer.vocab_size)
        if not valid.all():
            # Explicit raise: an ``assert`` would be stripped under ``python -O``.
            raise ValueError(
                f"Há labels inválidos: min={inputs.min().item()}, max={inputs.max().item()}"
            )

        inputs = inputs.to(device)

        with self._autocast():
            outputs = self.model(
                input_ids=inputs,
                labels=inputs,
                use_cache=False,  # no KV cache while training
                return_dict=True,  # guarantees access through .loss and .logits
            )

        # A single reduction covers both the NaN and the Inf case.
        if not torch.isfinite(outputs.logits).all():
            raise RuntimeError("Logits inválidos detectados")

        return outputs.loss

    def _train_epoch(
        self, epoch: int, sample_prompts: Optional[List[str]] = None
    ) -> List[float]:
        """Run one pass over ``train_loader`` and return the per-batch losses.

        Every ``LOG_EVERY`` batches the averaged loss, perplexity and learning
        rate are logged; every ``SAMPLE_EVERY`` batches text samples are
        generated and logged; every ``CHECKPOINT_EVERY`` batches the model is
        checkpointed through the logger and saved to ``checkpoint_dir``.

        Args:
            epoch: Zero-based epoch index, used for the progress-bar label.
            sample_prompts: Prompts used for periodic sample generation.

        Returns:
            The loss of every batch processed in this epoch, in order.
        """
        if sample_prompts is None:
            sample_prompts = []

        self.model.train()
        losses: List[float] = []

        pbar = tqdm(
            self.train_loader,
            total=len(self.train_loader),
            desc=f"Epoch {epoch + 1}",
            unit="batch",
            leave=False,
        )

        for i, batch in enumerate(pbar):
            start_time = time.perf_counter()
            batches_done = i + 1

            self.optimizer.zero_grad()
            loss = self._calc_loss_batch(batch['input_ids'])

            if self.use_amp:
                self.scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping by norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=self.MAX_GRAD_NORM
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=self.MAX_GRAD_NORM
                )
                self.optimizer.step()

            self.scheduler.step()

            # Read the loss back once; each .item() forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)

            perplexity = self._perplexity(loss_value)
            current_lr = self.optimizer.param_groups[0].get('lr', 0.0)
            elapsed_time = time.perf_counter() - start_time

            pbar.set_postfix({
                "loss": f"{loss_value:.4f}",
                "perplexity": f"{perplexity:.4f}",
                "lr": f"{current_lr:.2e}",
                "elapsed_time": f"{elapsed_time:.2f}s",
            })

            # Periodic metric logging over the most recent window of batches.
            if batches_done % self.LOG_EVERY == 0:
                self.train_step += 1
                window = losses[-self.LOG_EVERY:]
                avg_loss = sum(window) / len(window)
                self.logger.log_metrics(
                    {
                        "train_loss": avg_loss,
                        "train_perplexity": self._perplexity(avg_loss),
                        "lr": current_lr,
                    },
                    step=self.train_step,
                )

            # Periodic sample generation.
            if batches_done % self.SAMPLE_EVERY == 0:
                samples_html = self._generate_sample(sample_prompts)
                self.logger.log_html(samples_html, step=self.train_step)

            # Periodic checkpoint.
            if batches_done % self.CHECKPOINT_EVERY == 0:
                window = losses[-self.CHECKPOINT_EVERY:]
                avg_loss = sum(window) / len(window)
                self.logger.log_checkpoint_table(
                    current_lr, avg_loss, self._perplexity(avg_loss), batches_done
                )
                self.logger.checkpoint_model(self.model)
                self.model.save_pretrained(self.checkpoint_dir)

        return losses

    def train(self, num_epochs: int = 500, sample_prompts: Optional[List[str]] = None):
        """Train for ``num_epochs`` epochs, logging the mean loss of each one.

        The logger is finished even when training raises, so the run is not
        left open.
        NOTE(review): assumes ``TrainerLogger.finish`` is safe to call after a
        failed run -- confirm against the logger implementation.

        Args:
            num_epochs: Number of passes over ``train_loader``.
            sample_prompts: Prompts forwarded to the periodic sample generation.

        Raises:
            RuntimeError: If an epoch produced no batches.
        """
        try:
            for epoch in range(num_epochs):
                train_losses = self._train_epoch(epoch, sample_prompts)
                if not train_losses:
                    raise RuntimeError(
                        "train_loader yielded no batches; cannot compute the mean loss"
                    )
                mean_train_loss = sum(train_losses) / len(train_losses)

                self.logger.log_metrics(
                    {"mean_train_loss": mean_train_loss},
                    step=epoch,
                )

                print(f"Epoch {epoch + 1} | Train Loss: {mean_train_loss:.4f}")
        finally:
            self.logger.finish()

        print("Treinamento concluído!")


# Example specialisation for fine-tuning:
class TuningTrainer(BaseTrainer):
    """Trainer variant for fine-tuning; currently identical to ``BaseTrainer``."""

    pass


# Example specialisation for pre-training:
class PreTrainer(BaseTrainer):
    """Trainer variant for pre-training; currently identical to ``BaseTrainer``."""

    pass