from typing import List, Dict, Any, Optional import torch from torch.utils.data import DataLoader from datasets import load_dataset from transformers import AutoTokenizer from torch.nn.utils.rnn import pad_sequence from functools import partial def tokenize_function(examples: Dict[str, Any], tokenizer: Any) -> Dict[str, List[int]]: """ Aplica a template de chat do tokenizer e gera os token ids. Args: examples (Dict[str, Any]): Dicionário contendo a lista de mensagens sob a chave "messages". tokenizer (Any): Instância do tokenizer que deverá possuir a propriedade 'chat_template'. Returns: Dict[str, List[int]]: Dicionário com os token ids gerados. """ full_text = tokenizer.apply_chat_template( examples["messages"], tokenize=True, add_generation_prompt=True ) return {"input_ids": full_text} def custom_collate_fn( batch: List[Dict[str, List[int]]], pad_token_id: int = 29797, ignore_index: int = -100, allowed_max_length: Optional[int] = None, device: str = "cpu", ) -> Dict[str, torch.Tensor]: """ • Faz padding das sequências • Cria pares (input, label) deslocando 1 posição • Aplica `ignore_index` (-100) APENAS nos labels depois do 1.º PAD """ # 1) Lista → Tensor + PAD final seqs = [torch.tensor(s["input_ids"] + [pad_token_id]) for s in batch] # 2) Padding até o comprimento máximo do batch padded = pad_sequence(seqs, batch_first=True, padding_value=pad_token_id) # 3) Desloca 1 posição e CLONA para quebrar o compartilhamento de memória input_ids = padded[:, :-1].clone() # ← nunca terá -100 labels = padded[:, 1:].clone() # ← vamos editar aqui # 4) Define -100 após o primeiro PAD de cada sequência pad_mask = (labels == pad_token_id) if pad_mask.any(): # índice da primeira ocorrência de PAD em cada linha first_pad_pos = pad_mask.float().cumsum(1).eq(1) & pad_mask # tudo que vem depois do primeiro PAD recebe -100 mask_after_first_pad = pad_mask & ~first_pad_pos labels[mask_after_first_pad] = ignore_index # 5) Trunca se for solicitado if allowed_max_length is not None: input_ids = input_ids[:, :allowed_max_length] labels = labels[:, :allowed_max_length] return { "input_ids": input_ids.to(device), "labels": labels.to(device), } def create_data_loader_fine_tuning( tokenizer: Any, batch_size: int, path_folder: str, split: str = "train", pad_token_id: int = 0, ignore_index: int = -100, allowed_max_length: Optional[int] = None, device: str = "cpu" ) -> DataLoader: """ Cria o DataLoader para fine-tuning, a partir de um dataset_files tokenizado. Esta função carrega o dataset_files, aplica a tokenização utilizando uma template de chat, e retorna um DataLoader que utiliza a função custom_collate_fn para o processamento adequado das batches. Args: tokenizer (Any): Tokenizer pré-treinado que suporte chat templates. batch_size (int): Número de amostras por batch. path_folder (str): Caminho ou identificador do dataset_files. split (str): Divisão do dataset_files a ser utilizada (por exemplo, "train" ou "test"). pad_token_id (int): ID do token para padding. ignore_index (int): Valor a ser ignorado na função de perda. allowed_max_length (Optional[int]): Se definido, trunca as sequências para este tamanho máximo. device (str): Dispositivo para onde os tensores serão enviados ("cpu" ou "cuda"). Returns: DataLoader: Instância do DataLoader pronta para o fine-tuning. """ # Define a template de chat e atribui ao tokenizer. chat_template = """ {% for message in messages %} {% if message['role'] == 'user' %} {{ '<|user_start|>' + message['content'] + '<|user_end|>' + '\n'}} {% elif message['role'] == 'assistant' %} {{ '<|assistant_start|>' + message['content'] + '<|assistant_end|>' + '\n' }} {% endif %} {% endfor %} """ tokenizer.chat_template = chat_template # Carrega o dataset_files. raw_dataset = load_dataset(path=path_folder, split=split, download_mode="force_redownload") # Aplica a tokenização utilizando a função definida. tokenized_dataset = raw_dataset.map( lambda examples: tokenize_function(examples, tokenizer), batched=True, remove_columns=raw_dataset.column_names, desc="Tokenizando dataset_files" ) # Configura o collate_fn com os parâmetros desejados. collate = partial( custom_collate_fn, pad_token_id=pad_token_id, ignore_index=ignore_index, allowed_max_length=allowed_max_length, device=device ) print("Criando DataLoader...") return DataLoader( tokenized_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0, collate_fn=collate ) if __name__ == "__main__": # Carrega o tokenizer pré-treinado. tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased") # Cria o DataLoader para a divisão de treino do dataset_files "conversational". loader = create_data_loader_fine_tuning( tokenizer=tokenizer, batch_size=100, path_folder="conversational", split="test" ) # Testa a extração de uma batch. batch = next(iter(loader)) print(batch["input_ids"].shape, batch["labels"].shape)