File size: 5,691 Bytes
58d9159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import List, Dict, Any, Optional
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
from functools import partial


def tokenize_function(examples: Dict[str, Any], tokenizer: Any) -> Dict[str, List[int]]:
    """
    Aplica a template de chat do tokenizer e gera os token ids.

    Args:
        examples (Dict[str, Any]): Dicionário contendo a lista de mensagens sob a chave "messages".
        tokenizer (Any): Instância do tokenizer que deverá possuir a propriedade 'chat_template'.

    Returns:
        Dict[str, List[int]]: Dicionário com os token ids gerados.
    """
    full_text = tokenizer.apply_chat_template(
        examples["messages"],
        tokenize=True,
        add_generation_prompt=True
    )
    return {"input_ids": full_text}


def custom_collate_fn(
    batch: List[Dict[str, List[int]]],
    pad_token_id: int = 29797,
    ignore_index: int = -100,
    allowed_max_length: Optional[int] = None,
    device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """
    • Faz padding das sequências
    • Cria pares (input, label) deslocando 1 posição
    • Aplica `ignore_index` (-100) APENAS nos labels depois do 1.º PAD
    """

    # 1) Lista → Tensor  +  PAD final
    seqs = [torch.tensor(s["input_ids"] + [pad_token_id]) for s in batch]

    # 2) Padding até o comprimento máximo do batch
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)

    # 3) Desloca 1 posição e CLONA para quebrar o compartilhamento de memória
    input_ids = padded[:, :-1].clone()     # ← nunca terá -100
    labels    = padded[:, 1:].clone()      # ← vamos editar aqui

    # 4) Define -100 após o primeiro PAD de cada sequência
    pad_mask = (labels == pad_token_id)
    if pad_mask.any():
        # índice da primeira ocorrência de PAD em cada linha
        first_pad_pos = pad_mask.float().cumsum(1).eq(1) & pad_mask
        # tudo que vem depois do primeiro PAD recebe -100
        mask_after_first_pad = pad_mask & ~first_pad_pos
        labels[mask_after_first_pad] = ignore_index

    # 5) Trunca se for solicitado
    if allowed_max_length is not None:
        input_ids = input_ids[:, :allowed_max_length]
        labels    = labels[:, :allowed_max_length]

    return {
        "input_ids": input_ids.to(device),
        "labels":    labels.to(device),
    }



def create_data_loader_fine_tuning(
        tokenizer: Any,
        batch_size: int,
        path_folder: str,
        split: str = "train",
        pad_token_id: int = 0,
        ignore_index: int = -100,
        allowed_max_length: Optional[int] = None,
        device: str = "cpu"
) -> DataLoader:
    """
    Cria o DataLoader para fine-tuning, a partir de um dataset_files tokenizado.

    Esta função carrega o dataset_files, aplica a tokenização utilizando uma template de chat,
    e retorna um DataLoader que utiliza a função custom_collate_fn para o processamento
    adequado das batches.

    Args:
        tokenizer (Any): Tokenizer pré-treinado que suporte chat templates.
        batch_size (int): Número de amostras por batch.
        path_folder (str): Caminho ou identificador do dataset_files.
        split (str): Divisão do dataset_files a ser utilizada (por exemplo, "train" ou "test").
        pad_token_id (int): ID do token para padding.
        ignore_index (int): Valor a ser ignorado na função de perda.
        allowed_max_length (Optional[int]): Se definido, trunca as sequências para este tamanho máximo.
        device (str): Dispositivo para onde os tensores serão enviados ("cpu" ou "cuda").

    Returns:
        DataLoader: Instância do DataLoader pronta para o fine-tuning.
    """
    # Define a template de chat e atribui ao tokenizer.
    chat_template = """
    {% for message in messages %}
        {% if message['role'] == 'user' %}
            {{ '<|user_start|>' + message['content'] + '<|user_end|>' + '\n'}}
        {% elif message['role'] == 'assistant' %}
            {{ '<|assistant_start|>' + message['content'] + '<|assistant_end|>' + '\n' }}
        {% endif %}
    {% endfor %}
    """
    tokenizer.chat_template = chat_template

    # Carrega o dataset_files.
    raw_dataset = load_dataset(path=path_folder, split=split, download_mode="force_redownload")

    # Aplica a tokenização utilizando a função definida.
    tokenized_dataset = raw_dataset.map(
        lambda examples: tokenize_function(examples, tokenizer),
        batched=True,
        remove_columns=raw_dataset.column_names,
        desc="Tokenizando dataset_files"
    )

    # Configura o collate_fn com os parâmetros desejados.
    collate = partial(
        custom_collate_fn,
        pad_token_id=pad_token_id,
        ignore_index=ignore_index,
        allowed_max_length=allowed_max_length,
        device=device
    )

    print("Criando DataLoader...")
    return DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        collate_fn=collate
    )


if __name__ == "__main__":
    # Carrega o tokenizer pré-treinado.
    tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")

    # Cria o DataLoader para a divisão de treino do dataset_files "conversational".
    loader = create_data_loader_fine_tuning(
        tokenizer=tokenizer,
        batch_size=100,
        path_folder="conversational",
        split="test"
    )

    # Testa a extração de uma batch.
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)