rtferraz committed · verified
Commit 1dfd4e2 · 1 Parent(s): 28118c7

Add data_pipeline.py — tokenize_user_sequences, pack_sequences, prepare_clm_dataset

src/domain_tokenizer/training/data_pipeline.py ADDED
@@ -0,0 +1,92 @@
+ """
+ Data pipeline for domain sequence CLM pre-training.
+
+ Pipeline:
+ 1. Raw events -> DomainTokenizerBuilder.tokenize_sequence() -> token strings
+ 2. Token strings -> HF tokenizer -> token IDs (variable length)
+ 3. Token ID sequences -> pack into fixed-length blocks (group_texts pattern)
+ 4. Packed blocks -> DataCollatorForLanguageModeling -> {input_ids, labels, attention_mask}
+
+ Packing follows the official HF run_clm.py pattern: concatenate all tokenized
+ sequences, then split into fixed-length blocks. No padding; only the tail remainder is dropped.
+ """
+
+ import logging
+ from typing import Any, Dict, List, Sequence
+
+ from datasets import Dataset as HFDataset
+
+ from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder
+
+ logger = logging.getLogger(__name__)
+
+
+ def tokenize_user_sequences(
+     user_sequences: Sequence[Sequence[Dict[str, Any]]],
+     builder: DomainTokenizerBuilder,
+     hf_tokenizer,
+     add_bos: bool = True,
+     add_eos: bool = True,
+     num_proc: int = 1,  # not yet used; tokenization currently runs in a single process
+ ) -> List[List[int]]:
+     """Tokenize user event sequences into token ID lists."""
+     all_token_ids = []
+     for events in user_sequences:
+         token_strings = builder.tokenize_sequence(events, add_bos=add_bos, add_eos=add_eos)
+         token_text = " ".join(token_strings)
+         encoding = hf_tokenizer(token_text, add_special_tokens=False)
+         all_token_ids.append(encoding["input_ids"])
+     return all_token_ids
+
+
+ def pack_sequences(token_id_sequences: List[List[int]], block_size: int = 512) -> HFDataset:
+     """Pack variable-length token sequences into fixed-length blocks.
+
+     Follows the official HF run_clm.py pattern: concatenate all sequences
+     into one long stream, split into fixed-length blocks, drop the remainder.
+     No block is padded; at most block_size - 1 trailing tokens are dropped.
+     """
+     concatenated = []
+     for seq in token_id_sequences:
+         concatenated.extend(seq)
+
+     total_tokens = len(concatenated)
+     n_blocks = total_tokens // block_size
+     dropped = total_tokens - n_blocks * block_size
+
+     if n_blocks == 0:
+         raise ValueError(
+             f"Not enough tokens ({total_tokens}) to form even one block of size {block_size}. "
+             "Reduce block_size or add more data."
+         )
+
+     logger.info(f"Packing: {total_tokens:,} tokens -> {n_blocks:,} blocks of {block_size} "
+                 f"({dropped} tokens dropped, {dropped/total_tokens*100:.1f}% waste)")
+
+     packed = [concatenated[i * block_size : (i + 1) * block_size] for i in range(n_blocks)]
+     return HFDataset.from_dict({"input_ids": packed})
+
+
+ def prepare_clm_dataset(
+     user_sequences: Sequence[Sequence[Dict[str, Any]]],
+     builder: DomainTokenizerBuilder,
+     hf_tokenizer,
+     block_size: int = 512,
+     add_bos: bool = True,
+     add_eos: bool = True,
+ ) -> HFDataset:
+     """Full pipeline: user event sequences -> packed CLM training dataset.
+
+     Example:
+         >>> dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=512)
+         >>> collator = DataCollatorForLanguageModeling(tokenizer=hf_tokenizer, mlm=False)
+         >>> trainer = Trainer(model=model, train_dataset=dataset, data_collator=collator, ...)
+     """
+     token_id_sequences = tokenize_user_sequences(
+         user_sequences, builder, hf_tokenizer, add_bos=add_bos, add_eos=add_eos,
+     )
+     total_tokens = sum(len(seq) for seq in token_id_sequences)
+     avg_tokens = total_tokens / max(len(token_id_sequences), 1)
+     logger.info(f"Tokenized {len(token_id_sequences)} user sequences -> "
+                 f"{total_tokens:,} tokens (avg {avg_tokens:.1f} tokens/sequence)")
+     return pack_sequences(token_id_sequences, block_size=block_size)
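
A quick sanity check of the packing arithmetic described in the pack_sequences docstring (not part of the commit; the import path assumes `src/` is on PYTHONPATH): ten tokens packed at block_size=4 yield two full blocks, and the two-token tail is dropped.

from domain_tokenizer.training.data_pipeline import pack_sequences

ds = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]], block_size=4)
print(ds["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- tokens 9 and 10 are the dropped tail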
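
And a hedged sketch of step 4 from the module docstring, wiring the packed dataset into the causal-LM collator; `hf_tokenizer` stands in for whatever trained tokenizer the project pairs with this pipeline.

from transformers import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(tokenizer=hf_tokenizer, mlm=False)
batch = collator([ds[i] for i in range(len(ds))])
# Every block has exactly block_size tokens, so the collator pads nothing;
# with mlm=False it copies input_ids into labels (any pad positions become -100).
# batch["input_ids"].shape == batch["labels"].shape == (num_blocks, block_size)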