from typing import Dict, List, Optional

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from datasets import load_dataset


def tokenize_function(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer
) -> Dict[str, List[List[int]]]:
    """
    Tokeniza os exemplos sem aplicar truncamento ou padding.

    Retorna apenas os input_ids.
    """
    tokenized_output = tokenizer(examples["text"], truncation=False, padding=False)
    return {"input_ids": tokenized_output["input_ids"]}


def pack_documents(
    examples: Dict[str, List[List[int]]],
    max_length: int,
    eos_token_id: Optional[int] = None
) -> Dict[str, List[List[int]]]:
    """
    Aplica Document Packing e retorna apenas os inputs de tamanho fixo (max_length),
    descartando o último token extra usado para labels.
    """
    # Concatena tokens de todo o batch
    concatenated: List[int] = []
    separator = [eos_token_id] if eos_token_id is not None else []
    first = True
    for doc in examples["input_ids"]:
        if not first and separator:
            concatenated.extend(separator)
        concatenated.extend(doc)
        first = False

    block_size = max_length + 1
    total_len = (len(concatenated) // block_size) * block_size
    if total_len == 0:
        return {"input_ids": []}

    concatenated = concatenated[:total_len]
    # Split into blocks of block_size, then drop the last token of each block
    blocks = [
        concatenated[i : i + block_size]
        for i in range(0, total_len, block_size)
    ]
    inputs = [blk[:-1] for blk in blocks]

    # Filter out any empty blocks
    inputs = [inp for inp in inputs if len(inp) > 0]
    return {"input_ids": inputs}


def create_train_dataloader(
    folder_path: str,
    tokenizer: PreTrainedTokenizerFast,
    batch_size: int = 4,
    max_length: int = 512,
    drop_last: bool = True,
    num_workers: int = 5
) -> DataLoader:
    """
    Carrega .txt de folder_path, tokeniza, aplica packing só de inputs
    e retorna um DataLoader que fornece batches de input_ids.
    """
    raw_dataset = load_dataset(folder_path, split="train", streaming=False)
    print(f"Dataset bruto carregado: {raw_dataset}")

    # 1) Tokenization
    tokenized = raw_dataset.map(
        lambda ex: tokenize_function(ex, tokenizer),
        batched=True,
        batch_size=1000,
        num_proc=20,
        remove_columns=raw_dataset.column_names,
    )
    print(f"Dataset tokenizado: {tokenized}")

    # 2) Document packing without labels
    packed = tokenized.map(
        lambda ex: pack_documents(
            ex,
            max_length=max_length,
            eos_token_id=tokenizer.eos_token_id
        ),
        batched=True,
        batch_size=10000,
        num_proc=20,
    )

    # 3) Set the output format to PyTorch tensors
    packed.set_format(type="torch", columns=["input_ids"])

    print("Criando DataLoader...")
    return DataLoader(
        packed,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
    )
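

if __name__ == "__main__":
    # Minimal usage sketch. "gpt2" and "./data" are placeholder values chosen
    # for illustration; they are not defaults assumed anywhere in this module.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    train_dataloader = create_train_dataloader(
        folder_path="./data",
        tokenizer=tokenizer,
        batch_size=4,
        max_length=512,
    )
    # Each batch is a dict with a single "input_ids" tensor of shape
    # (batch_size, max_length).
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape)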