|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from src.models.fiscal_llm import ( |
| CLASSES_OBRIGACAO, |
| BASE_CONHECIMENTO_FISCAL, |
| TokenizadorFiscal, |
| TransformerFiscal, |
| ) |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class ExemploTreinamento: |
| texto: str |
| classe: str |
| peso: float = 1.0 |
|
|
|
|
| |
| EXEMPLOS_CLASSIFICACAO: list[ExemploTreinamento] = [ |
| |
| |
| |
| ExemploTreinamento("gerar EFD ICMS IPI do mês de janeiro", "EFD_ICMS_IPI"), |
| ExemploTreinamento("preciso da escrituração fiscal digital ICMS", "EFD_ICMS_IPI"), |
| ExemploTreinamento("bloco C nota fiscal SPED", "EFD_ICMS_IPI"), |
| ExemploTreinamento("apuração ICMS mensal livro fiscal", "EFD_ICMS_IPI"), |
| ExemploTreinamento("EFD fiscal digital ICMS IPI vencimento", "EFD_ICMS_IPI"), |
| ExemploTreinamento("substituição tributária ICMS escrituração", "EFD_ICMS_IPI"), |
| ExemploTreinamento("CIAP controle crédito ICMS ativo imobilizado", "EFD_ICMS_IPI"), |
| ExemploTreinamento("inventário físico bloco H SPED", "EFD_ICMS_IPI"), |
| ExemploTreinamento("quando devo entregar o sped fiscal icms", "EFD_ICMS_IPI"), |
| ExemploTreinamento("bloco E110 apuração saldo devedor credor ICMS", "EFD_ICMS_IPI"), |
| ExemploTreinamento("transmissão EFD ICMS IPI prazo 15 dia útil", "EFD_ICMS_IPI"), |
| ExemploTreinamento("controle producao estoque bloco K industria", "EFD_ICMS_IPI"), |
| |
| |
| |
| ExemploTreinamento("EFD contribuições PIS COFINS mensal", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("bloco M apuração PIS COFINS", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("créditos PIS COFINS não cumulativo", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("regime não cumulativo contribuições", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("EFD Contribuições prazo entrega transmissão", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("sped pis cofins lucro real", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("receita bruta PIS COFINS bloco F", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("M200 M600 apuração contribuições bloco M", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("gerar escrituração PIS COFINS mês", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("10 dia util segundo mes EFD contribuicoes prazo", "EFD_CONTRIBUICOES"), |
| ExemploTreinamento("CST PIS COFINS 01 02 07 49 50 tributação", "EFD_CONTRIBUICOES"), |
| |
| |
| |
| ExemploTreinamento("escrituração contábil digital ECD livro diário", "ECD"), |
| ExemploTreinamento("balancete mensal ECD razão contábil", "ECD"), |
| ExemploTreinamento("plano de contas ECD lançamentos contábeis", "ECD"), |
| ExemploTreinamento("PVA SPED contábil assinatura digital", "ECD"), |
| ExemploTreinamento("prazo ECD ultimo dia util junho", "ECD"), |
| ExemploTreinamento("livro diário razão balanço ECD transmissão", "ECD"), |
| ExemploTreinamento("SPED contábil lucro real obrigatoriedade", "ECD"), |
| ExemploTreinamento("I050 I100 I200 registros ECD lançamentos", "ECD"), |
| ExemploTreinamento("bloco I balanço patrimonial demonstração resultado", "ECD"), |
| ExemploTreinamento("contador CRC assinar ECD digital", "ECD"), |
| |
| |
| |
| ExemploTreinamento("ECF escrituração contábil fiscal DIPJ", "ECF"), |
| ExemploTreinamento("LALUR LACS lucro real ajustado", "ECF"), |
| ExemploTreinamento("lucro presumido ECF preenchimento", "ECF"), |
| ExemploTreinamento("IRPJ CSLL ECF apuração anual", "ECF"), |
| ExemploTreinamento("prazo ECF julho ano seguinte", "ECF"), |
| ExemploTreinamento("bloco N620 N630 ECF IRPJ CSLL", "ECF"), |
| ExemploTreinamento("bloco P lucro presumido receita trimestral ECF", "ECF"), |
| ExemploTreinamento("substituição DIPJ declaração ECF", "ECF"), |
| ExemploTreinamento("adições exclusões compensações LALUR ECF lucro real", "ECF"), |
| ExemploTreinamento("quadro de sócios Y600 ECF participação capital", "ECF"), |
| |
| |
| |
| ExemploTreinamento("emitir nota fiscal eletrônica NF-e", "NFe"), |
| ExemploTreinamento("XML nota fiscal SEFAZ autorização", "NFe"), |
| ExemploTreinamento("cancelar NF-e protocolo SEFAZ", "NFe"), |
| ExemploTreinamento("chave acesso NF-e danfe", "NFe"), |
| ExemploTreinamento("NF-e modelo 55 emissão produto", "NFe"), |
| ExemploTreinamento("autorizar nota fiscal SEFAZ SP online", "NFe"), |
| ExemploTreinamento("status servico sefaz disponivel", "NFe"), |
| ExemploTreinamento("gerar XML NF-e versao 4.00", "NFe"), |
| ExemploTreinamento("inutilizar numeração NF-e série", "NFe"), |
| ExemploTreinamento("carta correcao NF-e CC-e evento", "NFe"), |
| ExemploTreinamento("NFC-e modelo 65 consumidor balcão PDV", "NFe"), |
| ExemploTreinamento("emitir nota consumidor NFC-e cupom fiscal", "NFe"), |
| |
| |
| |
| ExemploTreinamento("nota fiscal serviços NFS-e ISS prefeitura", "NFSe"), |
| ExemploTreinamento("RPS nota fiscal serviço eletrônica", "NFSe"), |
| ExemploTreinamento("emitir NFS-e prestação serviço municipal", "NFSe"), |
| ExemploTreinamento("nota fiscal servico prefeitura sao paulo", "NFSe"), |
| ExemploTreinamento("ISS nota servico eletronico municipal", "NFSe"), |
| ExemploTreinamento("converção RPS lote NFS-e", "NFSe"), |
| ExemploTreinamento("cancelar nota fiscal servico eletronica", "NFSe"), |
| ExemploTreinamento("webservice prefeitura NFS-e ABRASF padrão", "NFSe"), |
| ExemploTreinamento("tomador servico retenção ISS nota servico", "NFSe"), |
| ExemploTreinamento("codigo servico LC 116 NFS-e item lista", "NFSe"), |
| |
| |
| |
| ExemploTreinamento("CT-e conhecimento transporte eletrônico", "CTe"), |
| ExemploTreinamento("documento fiscal transporte carga CTe", "CTe"), |
| ExemploTreinamento("gerar CTe frete mercadoria transportadora", "CTe"), |
| ExemploTreinamento("modal rodoviário RNTRC transporte CT-e", "CTe"), |
| ExemploTreinamento("ICMS transporte interestadual alíquota CTe", "CTe"), |
| ExemploTreinamento("MDF-e manifesto documentos fiscais transporte", "CTe"), |
| ExemploTreinamento("cancelar CTe evento conhecimento transporte", "CTe"), |
| ExemploTreinamento("chave CTe 44 digitos modelo 57", "CTe"), |
| ExemploTreinamento("seguro carga averbação CTe transportadora", "CTe"), |
| ExemploTreinamento("valor frete tomador remetente destinatario CTe", "CTe"), |
| |
| |
| |
| ExemploTreinamento("e-Social folha pagamento admissão", "eSocial"), |
| ExemploTreinamento("S-2200 admissão empregado e-Social", "eSocial"), |
| ExemploTreinamento("rescisão contrato trabalho e-Social S-2299", "eSocial"), |
| ExemploTreinamento("remuneração S-1200 e-Social", "eSocial"), |
| ExemploTreinamento("GFIP SEFIP substituído e-Social", "eSocial"), |
| ExemploTreinamento("evento S-1000 empregador cadastro e-Social", "eSocial"), |
| ExemploTreinamento("rubrica folha pagamento S-1010 tabela", "eSocial"), |
| ExemploTreinamento("S-1299 fechamento folha mensal esocial", "eSocial"), |
| ExemploTreinamento("enviar eventos trabalhistas esocial", "eSocial"), |
| ExemploTreinamento("aviso previo ferias CLT esocial", "eSocial"), |
| ExemploTreinamento("CAGED admissão demissão substituído esocial", "eSocial"), |
| ExemploTreinamento("RAIS DIRF obrigação trabalhista esocial", "eSocial"), |
| |
| |
| |
| ExemploTreinamento("EFD Reinf retenções serviços PJ", "EFD_REINF"), |
| ExemploTreinamento("R-2010 serviços tomados retenção CSLL IRRF", "EFD_REINF"), |
| ExemploTreinamento("R-2099 fechamento EFD-Reinf", "EFD_REINF"), |
| ExemploTreinamento("CPRB contribuição previdenciária receita bruta", "EFD_REINF"), |
| ExemploTreinamento("R-1000 cadastro empregador EFD-Reinf", "EFD_REINF"), |
| ExemploTreinamento("R-2020 serviços prestados retenção previdenciária", "EFD_REINF"), |
| ExemploTreinamento("retenção INSS PJ serviços prestados 11%", "EFD_REINF"), |
| ExemploTreinamento("R-2060 CPRB desoneração folha pagamento", "EFD_REINF"), |
| ExemploTreinamento("declaração digital reinf prazo dia 15", "EFD_REINF"), |
| ExemploTreinamento("GFIP substituição EFD reinf retenções", "EFD_REINF"), |
| |
| |
| |
| ExemploTreinamento("DCTF declaração débitos tributários federais", "DCTF"), |
| ExemploTreinamento("DCTF mensal PGD débitos créditos", "DCTF"), |
| ExemploTreinamento("DARF IRPJ CSLL PIS COFINS declarar DCTF", "DCTF"), |
| ExemploTreinamento("prazo DCTF 15 dia util segundo mês", "DCTF"), |
| ExemploTreinamento("código receita DARF 6912 IRPJ estimativa DCTF", "DCTF"), |
| ExemploTreinamento("compensar crédito PER COMP DCTF", "DCTF"), |
| ExemploTreinamento("suspensão débito decisão judicial DCTF", "DCTF"), |
| ExemploTreinamento("retificadora DCTF corrigir débito declarado", "DCTF"), |
| ExemploTreinamento("DCTF inativa empresa sem movimento", "DCTF"), |
| ExemploTreinamento("importar XML DCTF web PGD transmitir", "DCTF"), |
| |
| |
| |
| ExemploTreinamento("PGDAS Simples Nacional DAS mensal", "PGDAS"), |
| ExemploTreinamento("MEI microempreendedor DAS simples", "PGDAS"), |
| ExemploTreinamento("alíquota simples nacional receita bruta", "PGDAS"), |
| ExemploTreinamento("DEFIS declaração simples anual", "PGDAS"), |
| ExemploTreinamento("anexo I II III IV V simples nacional", "PGDAS"), |
| ExemploTreinamento("fator R simples nacional anexo III V serviço", "PGDAS"), |
| ExemploTreinamento("RBT12 receita bruta acumulada 12 meses simples", "PGDAS"), |
| ExemploTreinamento("calcular DAS simples nacional faixa aliquota efetiva", "PGDAS"), |
| ExemploTreinamento("partilha DAS ICMS ISS CPP IRPJ simples", "PGDAS"), |
| ExemploTreinamento("vencimento DAS dia 20 simples nacional", "PGDAS"), |
| ExemploTreinamento("DeSTDA ICMS ST diferencial alíquota simples", "PGDAS"), |
| ExemploTreinamento("sublimite simples nacional estado ICMS ISS", "PGDAS"), |
| |
| |
| |
| ExemploTreinamento("calcular ICMS nota fiscal saída", "calculo_icms"), |
| ExemploTreinamento("base cálculo ICMS mercadoria frete", "calculo_icms"), |
| ExemploTreinamento("alíquota ICMS interestadual SP MG", "calculo_icms"), |
| ExemploTreinamento("DIFAL diferencial alíquota operação interestadual", "calculo_icms"), |
| ExemploTreinamento("substituição tributária MVA base ST", "calculo_icms"), |
| ExemploTreinamento("FCP fundo combate pobreza ICMS", "calculo_icms"), |
| ExemploTreinamento("quanto é o ICMS de uma venda de R$ 10000", "calculo_icms"), |
| ExemploTreinamento("redução base de cálculo ICMS benefício fiscal", "calculo_icms"), |
| ExemploTreinamento("CST 000 020 040 041 ICMS tributado isento", "calculo_icms"), |
| ExemploTreinamento("calculo icms st com mva 40 porcento", "calculo_icms"), |
| ExemploTreinamento("alíquota interna ICMS 18% estado SP", "calculo_icms"), |
| ExemploTreinamento("ICMS de entrada crédito compra mercadoria", "calculo_icms"), |
| |
| |
| |
| ExemploTreinamento("calcular IPI produto industrializado", "calculo_ipi"), |
| ExemploTreinamento("NCM alíquota IPI TIPI tabela", "calculo_ipi"), |
| ExemploTreinamento("base cálculo IPI saída estabelecimento", "calculo_ipi"), |
| ExemploTreinamento("CST IPI 50 tributado 52 isento", "calculo_ipi"), |
| ExemploTreinamento("IPI frete seguro compõe base calculo", "calculo_ipi"), |
| ExemploTreinamento("alíquota IPI automóvel NCM capítulo 87", "calculo_ipi"), |
| ExemploTreinamento("isenção IPI produto farmacêutico NCM 30", "calculo_ipi"), |
| ExemploTreinamento("saída indústria IPI débito apuração", "calculo_ipi"), |
| ExemploTreinamento("quanto IPI pago produto industrializado saida", "calculo_ipi"), |
| ExemploTreinamento("IPI credito entrada materia prima insumo", "calculo_ipi"), |
| |
| |
| |
| ExemploTreinamento("calcular PIS COFINS receita bruta", "calculo_pis_cofins"), |
| ExemploTreinamento("regime cumulativo PIS 0,65% COFINS 3%", "calculo_pis_cofins"), |
| ExemploTreinamento("créditos PIS COFINS lucro real insumos", "calculo_pis_cofins"), |
| ExemploTreinamento("retenção PIS COFINS serviços prestados PJ", "calculo_pis_cofins"), |
| ExemploTreinamento("lucro real PIS 1,65 COFINS 7,6 não cumulativo", "calculo_pis_cofins"), |
| ExemploTreinamento("crédito energia eletrica aluguel PIS COFINS", "calculo_pis_cofins"), |
| ExemploTreinamento("quanto PIS COFINS pago receita 100000 lucro presumido", "calculo_pis_cofins"), |
| ExemploTreinamento("exportação PIS COFINS alíquota zero imunidade", "calculo_pis_cofins"), |
| ExemploTreinamento("CST PIS 01 70 73 receita tributada", "calculo_pis_cofins"), |
| ExemploTreinamento("PIS COFINS monofásico combustivel farmácia", "calculo_pis_cofins"), |
| ExemploTreinamento("retenção minima PIS COFINS CSLL 10 reais serviço", "calculo_pis_cofins"), |
| |
| |
| |
| ExemploTreinamento("calcular IRPJ lucro presumido trimestral", "calculo_irpj_csll"), |
| ExemploTreinamento("LALUR adicionar exclusão lucro real IRPJ", "calculo_irpj_csll"), |
| ExemploTreinamento("CSLL base cálculo percentual presunção", "calculo_irpj_csll"), |
| ExemploTreinamento("estimativa mensal IRPJ lucro real", "calculo_irpj_csll"), |
| ExemploTreinamento("adicional 10% IRPJ lucro excedente", "calculo_irpj_csll"), |
| ExemploTreinamento("IRPJ 15% alíquota base cálculo lucro", "calculo_irpj_csll"), |
| ExemploTreinamento("quanto imposto renda empresa lucro 500000", "calculo_irpj_csll"), |
| ExemploTreinamento("CSLL 9% lucro real contribuição social", "calculo_irpj_csll"), |
| ExemploTreinamento("presunção lucro 8% comércio 32% serviço IRPJ", "calculo_irpj_csll"), |
| ExemploTreinamento("prejuízo fiscal compensação 30% lucro real", "calculo_irpj_csll"), |
| ExemploTreinamento("IRPJ CSLL apuração trimestral lucro presumido", "calculo_irpj_csll"), |
| |
| |
| |
| ExemploTreinamento("calcular ISS imposto sobre serviços", "calculo_iss"), |
| ExemploTreinamento("alíquota ISS prestação serviço município", "calculo_iss"), |
| ExemploTreinamento("retenção ISS tomador serviço nota", "calculo_iss"), |
| ExemploTreinamento("ISS simples nacional anexo III IV V", "calculo_iss"), |
| ExemploTreinamento("alíquota minima ISS 2% maxima 5%", "calculo_iss"), |
| ExemploTreinamento("LC 116 2003 lista serviços ISS municipio", "calculo_iss"), |
| ExemploTreinamento("ISS retido fonte tomador serviço obrigatorio", "calculo_iss"), |
| ExemploTreinamento("quanto ISS servico consultoria 10000 reais", "calculo_iss"), |
| ExemploTreinamento("ISS desenvolvimento software TI tecnologia", "calculo_iss"), |
| ExemploTreinamento("base calculo ISS dedução material construção", "calculo_iss"), |
| ] |
|
|
|
|
| class DatasetClassificacaoFiscal(Dataset): |
| """Dataset PyTorch para treinamento do classificador de obrigações fiscais.""" |
|
|
| def __init__( |
| self, |
| exemplos: list[ExemploTreinamento], |
| tokenizador: TokenizadorFiscal, |
| max_len: int = 256, |
| ): |
| self.exemplos = exemplos |
| self.tokenizador = tokenizador |
| self.max_len = max_len |
| self.classe2idx = {c: i for i, c in enumerate(CLASSES_OBRIGACAO)} |
|
|
| def __len__(self) -> int: |
| return len(self.exemplos) |
|
|
| def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: |
| ex = self.exemplos[idx] |
| ids = self.tokenizador.encode(ex.texto, max_len=self.max_len) |
| ids_t = torch.zeros(self.max_len, dtype=torch.long) |
| mask = torch.zeros(self.max_len, dtype=torch.long) |
| n = min(len(ids), self.max_len) |
| ids_t[:n] = torch.tensor(ids[:n]) |
| mask[:n] = 1 |
| label = self.classe2idx.get(ex.classe, 0) |
| return { |
| "input_ids": ids_t, |
| "attention_mask": mask, |
| "labels": torch.tensor(label, dtype=torch.long), |
| "weight": torch.tensor(ex.peso, dtype=torch.float), |
| } |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class ConfigTreinamento: |
| epochs: int = 50 |
| batch_size: int = 16 |
| lr: float = 3e-4 |
| weight_decay: float = 1e-4 |
| grad_clip: float = 1.0 |
| device: str = "auto" |
| checkpoint_dir: str = "./checkpoints" |
| log_interval: int = 10 |
| val_split: float = 0.15 |
| seed: int = 42 |
| early_stopping_patience: int = 8 |
| use_class_weights: bool = True |
|
|
|
|
| |
| |
| |
|
|
| class TrainerFiscal: |
| """Treinador do TransformerFiscal com classificação de intenção fiscal.""" |
|
|
| def __init__( |
| self, |
| modelo: TransformerFiscal, |
| tokenizador: TokenizadorFiscal, |
| cfg: Optional[ConfigTreinamento] = None, |
| ): |
| self.modelo = modelo |
| self.tokenizador = tokenizador |
| self.cfg = cfg or ConfigTreinamento() |
| self.device = self._resolver_device() |
| self.modelo.to(self.device) |
| self._historico: list[dict] = [] |
|
|
| def _resolver_device(self) -> str: |
| if self.cfg.device == "auto": |
| return "cuda" if torch.cuda.is_available() else "cpu" |
| return self.cfg.device |
|
|
| def _split_dataset( |
| self, exemplos: list[ExemploTreinamento] |
| ) -> tuple[list[ExemploTreinamento], list[ExemploTreinamento]]: |
| torch.manual_seed(self.cfg.seed) |
| n = len(exemplos) |
| n_val = max(1, int(n * self.cfg.val_split)) |
| indices = torch.randperm(n).tolist() |
| val_idx = set(indices[:n_val]) |
| treino = [e for i, e in enumerate(exemplos) if i not in val_idx] |
| val = [e for i, e in enumerate(exemplos) if i in val_idx] |
| return treino, val |
|
|
| def _calcular_class_weights(self, dataset: DatasetClassificacaoFiscal) -> torch.Tensor: |
| """Pesos inversamente proporcionais à frequência de cada classe.""" |
| contagens = [0] * len(CLASSES_OBRIGACAO) |
| for idx in range(len(dataset)): |
| item = dataset[idx] |
| contagens[item["labels"].item()] += 1 |
| total = sum(contagens) |
| pesos = [total / (len(contagens) * max(c, 1)) for c in contagens] |
| return torch.tensor(pesos, dtype=torch.float).to(self.device) |
|
|
| def treinar( |
| self, |
| exemplos: Optional[list[ExemploTreinamento]] = None, |
| salvar_em: Optional[str] = None, |
| ) -> list[dict]: |
| if exemplos is None: |
| exemplos = EXEMPLOS_CLASSIFICACAO |
|
|
| treino_ex, val_ex = self._split_dataset(exemplos) |
| ds_treino = DatasetClassificacaoFiscal(treino_ex, self.tokenizador) |
| ds_val = DatasetClassificacaoFiscal(val_ex, self.tokenizador) |
|
|
| loader_treino = DataLoader(ds_treino, batch_size=self.cfg.batch_size, shuffle=True) |
| loader_val = DataLoader(ds_val, batch_size=self.cfg.batch_size) |
|
|
| |
| class_w = ( |
| self._calcular_class_weights(ds_treino) |
| if self.cfg.use_class_weights |
| else None |
| ) |
| criterion = nn.CrossEntropyLoss(weight=class_w, reduction="none") |
|
|
| optimizer = optim.AdamW( |
| self.modelo.parameters(), |
| lr=self.cfg.lr, |
| weight_decay=self.cfg.weight_decay, |
| ) |
| total_steps = len(loader_treino) * self.cfg.epochs |
| scheduler = optim.lr_scheduler.OneCycleLR( |
| optimizer, max_lr=self.cfg.lr, total_steps=total_steps, pct_start=0.1, |
| ) |
|
|
| print(f"Treinamento: {len(treino_ex)} treino | {len(val_ex)} validação | " |
| f"device={self.device} | épocas={self.cfg.epochs}") |
|
|
| best_val_loss = float("inf") |
| best_state: dict = {} |
| paciencia = 0 |
|
|
| for epoch in range(1, self.cfg.epochs + 1): |
| self.modelo.train() |
| loss_total = corretos = total = 0 |
|
|
| for batch in loader_treino: |
| ids = batch["input_ids"].to(self.device) |
| mask = batch["attention_mask"].to(self.device) |
| labels = batch["labels"].to(self.device) |
| pesos = batch["weight"].to(self.device) |
|
|
| optimizer.zero_grad() |
| logits = self.modelo(ids, mask, task="classificar") |
| losses = criterion(logits, labels) |
| loss = (losses * pesos).mean() |
| loss.backward() |
| nn.utils.clip_grad_norm_(self.modelo.parameters(), self.cfg.grad_clip) |
| optimizer.step() |
| scheduler.step() |
|
|
| loss_total += loss.item() |
| corretos += (logits.argmax(-1) == labels).sum().item() |
| total += labels.size(0) |
|
|
| val_loss, val_acc = self._avaliar(loader_val, criterion) |
|
|
| entrada = { |
| "epoch": epoch, |
| "train_loss": loss_total / len(loader_treino), |
| "train_acc": corretos / total, |
| "val_loss": val_loss, |
| "val_acc": val_acc, |
| "lr": scheduler.get_last_lr()[0], |
| } |
| self._historico.append(entrada) |
|
|
| if epoch % self.cfg.log_interval == 0 or epoch == self.cfg.epochs: |
| print( |
| f"Época {epoch:3d}/{self.cfg.epochs} | " |
| f"Loss: {entrada['train_loss']:.4f} | Acc: {entrada['train_acc']:.3f} | " |
| f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f}" |
| ) |
|
|
| |
| if val_loss < best_val_loss - 1e-4: |
| best_val_loss = val_loss |
| best_state = {k: v.clone() for k, v in self.modelo.state_dict().items()} |
| paciencia = 0 |
| else: |
| paciencia += 1 |
| if paciencia >= self.cfg.early_stopping_patience: |
| print(f"Early stopping na época {epoch} (val_loss não melhorou há {paciencia} épocas)") |
| break |
|
|
| |
| if best_state: |
| self.modelo.load_state_dict(best_state) |
|
|
| if salvar_em: |
| caminho = Path(salvar_em) |
| caminho.parent.mkdir(parents=True, exist_ok=True) |
| self.modelo.save(caminho) |
| |
| historico_path = caminho.with_suffix(".historico.json") |
| historico_path.write_text(json.dumps({ |
| "historico": self._historico, |
| "config": { |
| "epochs": self.cfg.epochs, |
| "lr": self.cfg.lr, |
| "batch_size": self.cfg.batch_size, |
| "val_split": self.cfg.val_split, |
| "use_class_weights": self.cfg.use_class_weights, |
| "early_stopping_patience": self.cfg.early_stopping_patience, |
| "n_exemplos": len(exemplos), |
| "n_classes": len(CLASSES_OBRIGACAO), |
| }, |
| }, indent=2)) |
| print(f"Modelo salvo: {caminho} | Melhor val_loss: {best_val_loss:.4f}") |
|
|
| return self._historico |
|
|
| @torch.no_grad() |
| def _avaliar( |
| self, |
| loader: DataLoader, |
| criterion: nn.Module, |
| ) -> tuple[float, float]: |
| self.modelo.eval() |
| loss_total = 0.0 |
| corretos = 0 |
| total = 0 |
|
|
| for batch in loader: |
| ids = batch["input_ids"].to(self.device) |
| mask = batch["attention_mask"].to(self.device) |
| labels = batch["labels"].to(self.device) |
|
|
| logits = self.modelo(ids, mask, task="classificar") |
| loss = criterion(logits, labels).mean() |
|
|
| loss_total += loss.item() |
| corretos += (logits.argmax(-1) == labels).sum().item() |
| total += labels.size(0) |
|
|
| n = max(len(loader), 1) |
| return loss_total / n, corretos / max(total, 1) |
|
|
| def avaliar_exemplos(self, textos: list[str]) -> list[dict]: |
| """Avalia o modelo em textos e retorna top-3 predições com probabilidades.""" |
| import torch.nn.functional as F |
| self.modelo.eval() |
| resultados = [] |
| for texto in textos: |
| ids = self.tokenizador.encode(texto, max_len=256) |
| t = torch.tensor([ids], dtype=torch.long).to(self.device) |
| mask = torch.ones_like(t) |
| with torch.no_grad(): |
| logits = self.modelo(t, mask, task="classificar") |
| probs = F.softmax(logits[0], dim=-1) |
| top3 = probs.argsort(descending=True)[:3] |
| resultados.append({ |
| "texto": texto, |
| "predicoes": [ |
| {"classe": CLASSES_OBRIGACAO[i], "prob": float(probs[i])} |
| for i in top3 |
| ], |
| }) |
| return resultados |
|
|
| def avaliar_por_classe(self, exemplos: list[ExemploTreinamento]) -> dict: |
| """ |
| Avalia o modelo retornando acurácia por classe (matriz de confusão simplificada). |
| Útil para identificar quais classes o modelo confunde. |
| """ |
| import torch.nn.functional as F |
| self.modelo.eval() |
| classe2idx = {c: i for i, c in enumerate(CLASSES_OBRIGACAO)} |
| corretos_por_classe = {c: 0 for c in CLASSES_OBRIGACAO} |
| total_por_classe = {c: 0 for c in CLASSES_OBRIGACAO} |
| confusoes: list[dict] = [] |
|
|
| for ex in exemplos: |
| ids = self.tokenizador.encode(ex.texto, max_len=256) |
| t = torch.tensor([ids], dtype=torch.long).to(self.device) |
| mask = torch.ones_like(t) |
| with torch.no_grad(): |
| logits = self.modelo(t, mask, task="classificar") |
| pred_idx = logits[0].argmax().item() |
| pred = CLASSES_OBRIGACAO[pred_idx] |
| verdadeiro = ex.classe |
|
|
| total_por_classe[verdadeiro] += 1 |
| if pred == verdadeiro: |
| corretos_por_classe[verdadeiro] += 1 |
| else: |
| confusoes.append({"texto": ex.texto[:60], "esperado": verdadeiro, "previsto": pred}) |
|
|
| acuracia_por_classe = { |
| c: corretos_por_classe[c] / max(total_por_classe[c], 1) |
| for c in CLASSES_OBRIGACAO |
| } |
| acuracia_geral = sum(corretos_por_classe.values()) / max(sum(total_por_classe.values()), 1) |
|
|
| return { |
| "acuracia_geral": acuracia_geral, |
| "acuracia_por_classe": acuracia_por_classe, |
| "confusoes": confusoes[:20], |
| } |
|
|
|
|
| def treinar_modelo( |
| epochs: int = 30, |
| lr: float = 3e-4, |
| salvar_em: str = "./checkpoints/fiscal_llm.pt", |
| exemplos_extras: Optional[list[ExemploTreinamento]] = None, |
| ) -> TransformerFiscal: |
| """ |
| Treina o TransformerFiscal com os dados sintéticos built-in + |
| quaisquer exemplos extras fornecidos. |
| """ |
| tokenizador = TokenizadorFiscal() |
| modelo = TransformerFiscal( |
| vocab_size=len(tokenizador), |
| d_model=256, |
| nhead=8, |
| num_layers=4, |
| dim_feedforward=1024, |
| num_classes=len(CLASSES_OBRIGACAO), |
| ) |
|
|
| exemplos = list(EXEMPLOS_CLASSIFICACAO) |
| if exemplos_extras: |
| exemplos.extend(exemplos_extras) |
|
|
| cfg = ConfigTreinamento(epochs=epochs, lr=lr) |
| trainer = TrainerFiscal(modelo, tokenizador, cfg) |
| trainer.treinar(exemplos, salvar_em=salvar_em) |
|
|
| return modelo |
|
|