import math
import time
from typing import Any, Optional, Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from logger.logger import TrainerLogger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
# Device configuration (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class BaseTrainer:
def __init__(
self,
model: PreTrainedModel,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
tokenizer: Any,
train_loader: DataLoader,
test_loader: Optional[DataLoader] = None,
        logger_config: Optional[Dict[str, Any]] = None,
use_amp: bool = True,
):
self.model = model.to(device)
self.optimizer = optimizer
self.scheduler = scheduler
self.tokenizer = tokenizer
self.train_loader = train_loader
self.test_loader = test_loader
self.use_amp = use_amp
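        # Gradient scaler for mixed-precision training; only created when AMP is enabled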
self.scaler = torch.amp.GradScaler('cuda') if use_amp else None
self.train_step = 0
self._best_perplexity = float('inf')
self._epochs_no_improve = 0
total_params = sum(p.numel() for p in model.parameters())
self.logger = TrainerLogger(
tracking_uri=logger_config["tracking_uri"],
experiment=logger_config["experiment"],
run_name=logger_config["model_name"],
model_name=logger_config["model_name"],
total_params=total_params,
tags={"version": "1.0", "environment": "development"},
)
    def _generate_sample(self, sample_prompts: Optional[List[str]] = None):
        if sample_prompts is None:
            sample_prompts = []
self.model.eval()
samples_html = ""
for prompt in sample_prompts:
try:
# sample_text = generate_text_sample(self.model, self.tokenizer, prompt)
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to(self.model.device)
                # Generate a continuation for the prompt
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
generated_ids = self.model.generate(
input_ids=input_ids,
                        max_length=100,          # total length (prompt + continuation)
                        num_beams=5,             # number of beam-search hypotheses
                        do_sample=True,          # enable sampling (instead of pure greedy decoding)
                        top_k=50,                # restrict sampling to the top-50 tokens
                        top_p=0.95,              # nucleus sampling (cumulative p ≤ 0.95)
                        temperature=0.7,         # controls "creativity"
                        repetition_penalty=1.2,  # penalizes exact repetitions
                        use_cache=True,          # reuse past_key_values (default)
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
)
                # Decode to a string
generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
except Exception as e:
generated_text = f"Erro: {e}"
            samples_html += f"<h4><b>prompt:</b> {prompt}</h4><p><b>Response:</b> {generated_text}</p>"
self.model.train()
return samples_html
def _calc_loss_batch(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Calcula apenas a entropia cruzada para um batch de input_ids,
desativando o cache de chaves/valores durante o treinamento.
"""
ignore_idx = -100
        # validate that every token is within the vocabulary or is an ignore token
valid = ((inputs >= 0) | (inputs == ignore_idx)) & (inputs < self.tokenizer.vocab_size)
        assert valid.all(), f"Invalid labels found: min={inputs.min().item()}, max={inputs.max().item()}"
inputs = inputs.to(device)
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.use_amp):
outputs = self.model(
input_ids=inputs,
labels=inputs,
                use_cache=False,   # disable the KV cache during training
                return_dict=True   # guarantee access via .loss and .logits
)
loss = outputs.loss
logits = outputs.logits
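        # Guard against numerical instability (NaN/Inf) in the forward pass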
if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise RuntimeError("Invalid logits detected")
return loss
def _train_epoch(self, epoch: int, sample_prompts: Optional[List[str]] = None) -> List[float]:
if sample_prompts is None:
sample_prompts = []
self.model.train()
losses = []
size_dataset = len(self.train_loader)
pbar = tqdm(
self.train_loader,
total=size_dataset,
desc=f"Epoch {epoch + 1}",
unit="batch",
leave=False,
)
for i, batch in enumerate(pbar):
start_time = time.time()
self.optimizer.zero_grad()
loss = self._calc_loss_batch(batch['input_ids'])
losses.append(loss.item())
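            # AMP path: scale the loss for backward, unscale before gradient clipping, then step via the scaler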
if self.use_amp:
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
self.scheduler.step()
perplexity = math.exp(loss.item())
current_lr = self.optimizer.param_groups[0].get('lr', 0.0)
elapsed_time = time.time() - start_time
pbar.set_postfix({
"loss": f"{loss.item():.4f}",
"perplexity": f"{perplexity:.4f}",
"lr": f"{current_lr:.2e}",
"elapsed_time": f"{elapsed_time:.2f}s",
})
            # Log metrics every 100 batches
if (i + 1) % 100 == 0:
self.train_step += 1
                avg_loss = sum(losses[-100:]) / 100
                avg_perplexity = math.exp(avg_loss)
self.logger.log_metrics(
{
"train_loss": avg_loss,
"train_perplexity": avg_perplexity,
"lr": current_lr,
},
step=self.train_step,
)
            # Generate samples every 500 batches
if (i + 1) % 500 == 0:
samples_html = self._generate_sample(sample_prompts)
self.logger.log_html(f"<html><head><meta charset='utf-8'></head><body>{samples_html}</body></html>",
step=self.train_step)
            # Checkpoint every 1000 batches
if (i + 1) % 1000 == 0:
                avg_loss = sum(losses[-1000:]) / 1000
                avg_perplexity = math.exp(avg_loss)
self.logger.log_checkpoint_table(current_lr, avg_loss, avg_perplexity, i + 1)
self.logger.checkpoint_model(self.model)
                self.model.save_pretrained("../")
return losses
def train(self, num_epochs: int = 500, sample_prompts: Optional[List[str]] = None):
for epoch in range(num_epochs):
train_losses = self._train_epoch(epoch, sample_prompts)
mean_train_loss = sum(train_losses) / len(train_losses)
self.logger.log_metrics(
{"mean_train_loss": mean_train_loss},
step=epoch,
)
print(f"Epoch {epoch + 1} | Train Loss: {mean_train_loss:.4f}")
self.logger.finish()
print("Treinamento concluído!")
# Example usage for fine-tuning:
class TuningTrainer(BaseTrainer):
pass
# Example usage for pre-training:
class PreTrainer(BaseTrainer):
pass
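# Minimal usage sketch (illustrative only, not part of the original setup): it wires a
# small Hugging Face causal LM into PreTrainer. The model name ("gpt2"), prompts,
# tracking URI, and hyperparameters below are placeholder assumptions chosen for this
# example; TrainerLogger is expected to reach a live tracking server at that URI.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    # Toy dataset: a few tokenized sentences; the default collate_fn stacks the
    # per-example "input_ids" tensors into a [batch, seq_len] tensor.
    texts = ["Hello world.", "Training a tiny language model.", "This is a toy example."]
    encodings = tokenizer(texts, padding="max_length", truncation=True, max_length=32, return_tensors="pt")
    dataset = [{"input_ids": ids} for ids in encodings["input_ids"]]
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    optim = torch.optim.AdamW(lm.parameters(), lr=5e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)

    trainer = PreTrainer(
        model=lm,
        optimizer=optim,
        scheduler=sched,
        tokenizer=tokenizer,
        train_loader=loader,
        logger_config={
            "tracking_uri": "http://localhost:5000",  # placeholder
            "experiment": "toy-experiment",
            "model_name": "toy-run",
        },
        use_amp=torch.cuda.is_available(),
    )
    trainer.train(num_epochs=1, sample_prompts=["Once upon a time"])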