LH-Tech-AI committed on
Commit
a47f627
·
verified ·
1 Parent(s): e1845ba

Create prepare_finetune.py

Browse files

The data preparation script for finetuning the model.

Files changed (1) hide show
  1. prepare_finetune.py +120 -0
prepare_finetune.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm

# Directory where the packed training files (train.bin / train_mask.bin) are written.
OUTPUT_DIR = "data/alpaca_cleaned_mixed"
# tiktoken encoding name; "gpt2" is the GPT-2 BPE vocabulary (50,257 tokens).
TOKENIZER_NAME = "gpt2"
# RNG seed applied to numpy in main() for reproducibility.
SEED = 1337

# Number of FineWeb-Edu documents mixed into the instruction data
# (used below as an "anti-forgetting" plain-text component).
FINEWEB_SAMPLES = 2500

# Module-level tokenizer shared by format_prompt_with_mask() and main().
enc = tiktoken.get_encoding(TOKENIZER_NAME)
# End-of-sequence marker appended after every response / document.
EOS_TOKEN = "<|endoftext|>"
15
+
16
def format_prompt_with_mask(instruction, input_text, output):
    """
    Build the token sequence and loss mask for one instruction example.

    Prompt layout:
        Instruction: ...
        Input: ...        (only when input_text is non-empty)
        Response: ... <|endoftext|>

    Returns:
        (full_ids, mask): token ids for prompt + completion, and a parallel
        list where 0 marks prompt tokens (loss ignored) and 1 marks
        completion tokens (loss trained on).
    """
    has_input = bool(input_text) and bool(input_text.strip())
    if has_input:
        prompt_text = (
            f"Instruction:\n{instruction}\n\n"
            f"Input:\n{input_text}\n\n"
            "Response:\n"
        )
    else:
        prompt_text = f"Instruction:\n{instruction}\n\nResponse:\n"

    completion_text = f"{output}{EOS_TOKEN}"

    # Allow the EOS marker to be encoded as its single special-token id.
    special = {'<|endoftext|>'}
    prompt_ids = enc.encode(prompt_text, allowed_special=special)
    completion_ids = enc.encode(completion_text, allowed_special=special)

    # Prompt tokens contribute no loss (0); completion tokens do (1).
    mask = [0] * len(prompt_ids) + [1] * len(completion_ids)
    return prompt_ids + completion_ids, mask
39
+
40
def main():
    """Prepare the mixed Alpaca + FineWeb finetuning dataset.

    Tokenizes yahma/alpaca-cleaned with a loss mask (prompt=0, response=1),
    appends FINEWEB_SAMPLES FineWeb-Edu documents fully unmasked (all 1s) as
    an anti-forgetting component, and writes two parallel flat files to
    OUTPUT_DIR: train.bin (uint16 token ids) and train_mask.bin (uint8 mask).
    """
    np.random.seed(SEED)
    # BUGFIX: was an f-string with no placeholders.
    print("🚀 Starting Prepare-Script for SmaLLMPro (350M Instruct)...")
    print(f"📚 Tokenizer: {TOKENIZER_NAME}")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("📥 Loading 'yahma/alpaca-cleaned' (Chat-Instructions)...")
    alpaca = load_dataset("yahma/alpaca-cleaned", split='train')

    print(f"📥 Loading 'HuggingFaceFW/fineweb-edu' (Sample-10BT) for {FINEWEB_SAMPLES} Samples...")
    # Streaming avoids downloading the full 10BT sample just to take a prefix.
    fineweb = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split='train', streaming=True)

    all_tokens = []
    all_masks = []

    print("⚙️ Processing Alpaca...")
    for ex in tqdm(alpaca, desc="Alpaca"):
        ids, mask = format_prompt_with_mask(ex['instruction'], ex['input'], ex['output'])
        all_tokens.extend(ids)
        all_masks.extend(mask)

    alpaca_len = len(all_tokens)
    print(f" -> Alpaca Tokens: {alpaca_len:,}")

    print("⚙️ Processing FineWeb (Anti-Forgetting)...")
    fw_iter = iter(fineweb)
    fw_count = 0
    fw_tokens_count = 0

    for _ in tqdm(range(FINEWEB_SAMPLES), desc="FineWeb"):
        try:
            ex = next(fw_iter)
        except StopIteration:
            break  # stream exhausted before reaching FINEWEB_SAMPLES

        text = ex['text'] + EOS_TOKEN
        ids = enc.encode(text, allowed_special={EOS_TOKEN})

        # Plain-text documents are trained on in full: mask is all 1s.
        all_tokens.extend(ids)
        all_masks.extend([1] * len(ids))

        fw_tokens_count += len(ids)
        fw_count += 1

    print(f" -> FineWeb Tokens: {fw_tokens_count:,} (from {fw_count} documents)")

    total_tokens = len(all_tokens)
    print(f"\n💾 Saving {total_tokens:,} Tokens in '{OUTPUT_DIR}'...")

    # uint16 is safe for the GPT-2 vocabulary (50,257 < 65,536).
    token_arr = np.array(all_tokens, dtype=np.uint16)
    token_arr.tofile(os.path.join(OUTPUT_DIR, "train.bin"))

    mask_arr = np.array(all_masks, dtype=np.uint8)
    mask_arr.tofile(os.path.join(OUTPUT_DIR, "train_mask.bin"))

    print("\n🔍 --- SANITY CHECK ---")
    check_len = 100
    # BUGFIX: the message previously claimed "first 50 tokens" while the code
    # decodes check_len (= 100); derive the message from check_len instead.
    print(f"I decode the first {check_len} tokens of the first sample, to check, if everything is okay.")
    print("Green (TRAIN) = The things the model learns. Grey (IGNORE) = The things the model only reads.")

    sample_ids = all_tokens[:check_len]
    sample_mask = all_masks[:check_len]

    decoded_parts = []
    for t_id, m_val in zip(sample_ids, sample_mask):
        token_str = enc.decode([t_id])
        if m_val == 1:
            decoded_parts.append(f"\033[92m{token_str}\033[0m")  # green = trained
        else:
            decoded_parts.append(f"\033[90m{token_str}\033[0m")  # grey = ignored

    print("".join(decoded_parts))
    # BUGFIX: the green half of the legend used the malformed escape
    # "\033[Green=...", which printed a raw broken escape sequence instead of
    # coloring the text; the correct ANSI bright-green code is "\033[92m".
    print("\n(Legend: \033[90mGrey=Prompt/Ignored\033[0m, \033[92mGreen=Response/Learned\033[0m)")

    if len(token_arr) != len(mask_arr):
        print("\n❌ Warning: Token and Mask Array have different lengths! Something has gone wrong!")
    else:
        print("\n✅ Everything seems to be fine. The arrays are synchronized. You can now start the training.")
118
+
119
# Script entry point: run the full data-preparation pipeline.
if __name__ == "__main__":
    main()