# tynerox/src/dataset/pre_train.py
# (scraped page header preserved: Ubuntu — "Re-adiciona model.safetensors via LFS" — commit 58d9159)
import os
from typing import Any, Dict, List, Optional, Union
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from datasets import load_dataset, Dataset, DatasetDict
def tokenize_function(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer
) -> Dict[str, List[List[int]]]:
    """
    Tokenize a batch of raw texts without truncation or padding.

    Args:
        examples: Batch dict containing a "text" column (list of strings).
        tokenizer: Hugging Face tokenizer applied to the whole batch.

    Returns:
        Dict holding only the "input_ids" column — one variable-length
        token-id list per input text.
    """
    texts = examples["text"]
    encoding = tokenizer(texts, truncation=False, padding=False)
    return {"input_ids": encoding["input_ids"]}
def pack_documents(
    examples: Dict[str, List[List[int]]],
    max_length: int,
    eos_token_id: Optional[int] = None
) -> Dict[str, List[List[int]]]:
    """
    Pack tokenized documents into fixed-size input sequences.

    Every document in the batch is concatenated into one token stream
    (optionally separated by the EOS token), the stream is cut into chunks
    of ``max_length + 1`` tokens, and the trailing token of each chunk —
    the slot reserved for shifted labels — is discarded, so each returned
    sequence has exactly ``max_length`` tokens. Leftover tokens that do not
    fill a whole chunk are dropped.

    Args:
        examples: Batch dict with an "input_ids" column of token-id lists.
        max_length: Target length of each packed sequence.
        eos_token_id: Optional separator inserted *between* documents
            (never after the last one).

    Returns:
        Dict with an "input_ids" column of equal-length packed sequences
        (possibly empty when the batch is too small to fill one chunk).
    """
    stream: List[int] = []
    for idx, doc in enumerate(examples["input_ids"]):
        # Separator goes between documents only, mirroring the original join.
        if idx > 0 and eos_token_id is not None:
            stream.append(eos_token_id)
        stream.extend(doc)

    chunk = max_length + 1
    usable = (len(stream) // chunk) * chunk
    if usable == 0:
        return {"input_ids": []}

    # Slicing to chunk-1 tokens per window == taking each chunk minus its
    # final (label) token.
    packed = [stream[start : start + chunk - 1] for start in range(0, usable, chunk)]
    # Drop degenerate empty sequences (only possible when max_length == 0).
    packed = [seq for seq in packed if seq]
    return {"input_ids": packed}
def create_train_dataloader(
    folder_path: str,
    tokenizer: PreTrainedTokenizerFast,
    batch_size: int = 4,
    max_length: int = 512,
    drop_last: bool = True,
    num_workers: int = 5,
    num_proc: int = 20,
) -> Optional[DataLoader]:
    """
    Build a pre-training DataLoader that yields fixed-length input_ids.

    Loads the "train" split from ``folder_path``, tokenizes it without
    truncation/padding, packs the token streams into ``max_length``-sized
    blocks (inputs only, no labels column), and wraps the result in a
    PyTorch DataLoader.

    Args:
        folder_path: Path or dataset name forwarded to ``datasets.load_dataset``.
        tokenizer: Fast tokenizer used for encoding; its ``eos_token_id``
            separates documents during packing.
        batch_size: Sequences per DataLoader batch.
        max_length: Fixed length of each packed sequence.
        drop_last: Whether to drop the final incomplete batch.
        num_workers: DataLoader worker processes.
        num_proc: Processes used by both ``datasets.map`` stages
            (was a hard-coded 20; default preserves old behavior).

    Returns:
        A DataLoader producing batches with an ``input_ids`` tensor of shape
        (batch_size, max_length).
    """
    raw_dataset = load_dataset(folder_path, split="train", streaming=False)
    print(f"Dataset bruto carregado: {raw_dataset}")
    # 1) Tokenization: keep only input_ids, dropping the original text columns.
    tokenized = raw_dataset.map(
        lambda ex: tokenize_function(ex, tokenizer),
        batched=True,
        batch_size=1000,
        num_proc=num_proc,
        remove_columns=raw_dataset.column_names,
    )
    print(f"Dataset tokenizado: {tokenized}")
    # 2) Document packing without labels; large map batches so packing sees
    #    long concatenated streams and wastes fewer leftover tokens.
    packed = tokenized.map(
        lambda ex: pack_documents(
            ex,
            max_length=max_length,
            eos_token_id=tokenizer.eos_token_id,
        ),
        batched=True,
        batch_size=10000,
        num_proc=num_proc,
    )
    # 3) Expose the packed column as PyTorch tensors.
    packed.set_format(type="torch", columns=["input_ids"])
    print("Criando DataLoader...")
    return DataLoader(
        packed,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
    )