from typing import Dict, List, Optional

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def tokenize_function(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer,
) -> Dict[str, List[List[int]]]:
    """
    Tokenize the examples without applying truncation or padding.
    Returns only the input_ids.
    """
    tokenized_output = tokenizer(examples["text"], truncation=False, padding=False)
    return {"input_ids": tokenized_output["input_ids"]}


def pack_documents(
    examples: Dict[str, List[List[int]]],
    max_length: int,
    eos_token_id: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """
    Apply document packing and return only fixed-size inputs (max_length),
    discarding the extra trailing token that would otherwise serve as a label.
    """
    # Concatenate the tokens of the whole batch, inserting an EOS separator
    # between consecutive documents when the tokenizer defines one.
    concatenated: List[int] = []
    separator = [eos_token_id] if eos_token_id is not None else []
    first = True
    for doc in examples["input_ids"]:
        if not first and separator:
            concatenated.extend(separator)
        concatenated.extend(doc)
        first = False

    # Each block holds max_length + 1 tokens so that, after the last token of
    # the block is dropped, every input is exactly max_length tokens long.
    block_size = max_length + 1
    total_len = (len(concatenated) // block_size) * block_size
    if total_len == 0:
        return {"input_ids": []}
    concatenated = concatenated[:total_len]

    # Split into blocks of block_size and drop the last token of each block.
    blocks = [
        concatenated[i : i + block_size] for i in range(0, total_len, block_size)
    ]
    inputs = [blk[:-1] for blk in blocks]

    # Filter out any empty block (only possible in the degenerate
    # max_length == 0 case).
    inputs = [inp for inp in inputs if len(inp) > 0]
    return {"input_ids": inputs}


def create_train_dataloader(
    folder_path: str,
    tokenizer: PreTrainedTokenizerFast,
    batch_size: int = 4,
    max_length: int = 512,
    drop_last: bool = True,
    num_workers: int = 5,
) -> DataLoader:
    """
    Load the .txt files from folder_path, tokenize them, pack inputs only
    (no labels) and return a DataLoader that yields batches of input_ids.
    """
    # The "text" builder reads every .txt file under folder_path,
    # producing one example per line.
    raw_dataset = load_dataset("text", data_dir=folder_path, split="train")
    print(f"Raw dataset loaded: {raw_dataset}")

    # 1) Tokenization
    tokenized = raw_dataset.map(
        lambda ex: tokenize_function(ex, tokenizer),
        batched=True,
        batch_size=1000,
        num_proc=20,
        remove_columns=raw_dataset.column_names,
    )
    print(f"Tokenized dataset: {tokenized}")

    # 2) Document packing without labels
    packed = tokenized.map(
        lambda ex: pack_documents(
            ex, max_length=max_length, eos_token_id=tokenizer.eos_token_id
        ),
        batched=True,
        batch_size=10000,
        num_proc=20,
    )

    # 3) Expose the packed dataset as PyTorch tensors
    packed.set_format(type="torch", columns=["input_ids"])

    print("Creating DataLoader...")
    return DataLoader(
        packed,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
    )
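

# A minimal usage sketch, not part of the original module: it assumes a
# Hugging Face tokenizer such as "gpt2" and a local folder of .txt files at
# "./data"; both names are illustrative placeholders, not values taken from
# the original code.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Tiny illustration of pack_documents: two documents separated by an EOS
    # token (id 0) and packed into fixed-size inputs of max_length=4.
    demo = pack_documents(
        {"input_ids": [[1, 2, 3], [4, 5, 6, 7]]}, max_length=4, eos_token_id=0
    )
    print(demo)  # {'input_ids': [[1, 2, 3, 0]]}

    # End-to-end: build the DataLoader and inspect one batch. Packing yields
    # fixed-length blocks, so no pad token is needed even for GPT-2.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataloader = create_train_dataloader(
        folder_path="./data",  # hypothetical path to the .txt corpus
        tokenizer=tokenizer,
        batch_size=4,
        max_length=512,
    )
    batch = next(iter(dataloader))
    # Each batch is a dict with one tensor of shape (batch_size, max_length).
    print(batch["input_ids"].shape)  # torch.Size([4, 512])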