Upload 137 files
Browse files- hugging/td_fuse/heal.py +14 -9
hugging/td_fuse/heal.py
CHANGED
|
@@ -66,20 +66,25 @@ def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 66 |
print("[heal] Loading healing fine-tune data...")
|
| 67 |
|
| 68 |
# Merge-specific: use diverse data that exercises all merged capabilities
|
|
|
|
| 69 |
datasets_to_load = [
|
| 70 |
-
# General language
|
| 71 |
-
("
|
| 72 |
# Math reasoning (exercises DeepSeek/MiMo contributions)
|
| 73 |
-
("openai/gsm8k", "train", 300, "question"),
|
| 74 |
-
# Code
|
| 75 |
-
("
|
| 76 |
]
|
| 77 |
|
| 78 |
all_texts = []
|
| 79 |
|
| 80 |
-
for
|
|
|
|
| 81 |
try:
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
| 83 |
loaded = 0
|
| 84 |
for example in ds:
|
| 85 |
if loaded >= count:
|
|
@@ -200,7 +205,7 @@ def apply_qlora_unsloth(
|
|
| 200 |
|
| 201 |
trainer = SFTTrainer(
|
| 202 |
model=model,
|
| 203 |
-
|
| 204 |
train_dataset=dataset,
|
| 205 |
args=training_args,
|
| 206 |
max_seq_length=cfg.heal_seq_len,
|
|
@@ -328,7 +333,7 @@ def apply_qlora_standard(
|
|
| 328 |
|
| 329 |
trainer = Trainer(
|
| 330 |
model=model,
|
| 331 |
-
|
| 332 |
train_dataset=dataset,
|
| 333 |
args=training_args,
|
| 334 |
)
|
|
|
|
| 66 |
print("[heal] Loading healing fine-tune data...")
|
| 67 |
|
| 68 |
# Merge-specific: use diverse data that exercises all merged capabilities
|
| 69 |
+
# Each entry: (dataset_id, config_name_or_None, split, count, text_field)
|
| 70 |
datasets_to_load = [
|
| 71 |
+
# General language — same calibration data source that works reliably
|
| 72 |
+
("neuralmagic/LLM_compression_calibration", None, "train", 500, "text"),
|
| 73 |
# Math reasoning (exercises DeepSeek/MiMo contributions)
|
| 74 |
+
("openai/gsm8k", "main", "train", 300, "question"),
|
| 75 |
+
# Code — bigcode/starcoderdata is a modern alternative
|
| 76 |
+
("bigcode/starcoderdata", "python", "train", 200, "content"),
|
| 77 |
]
|
| 78 |
|
| 79 |
all_texts = []
|
| 80 |
|
| 81 |
+
for entry in datasets_to_load:
|
| 82 |
+
dataset_id, config_name, split, count, text_field = entry
|
| 83 |
try:
|
| 84 |
+
if config_name:
|
| 85 |
+
ds = load_dataset(dataset_id, config_name, split=split, streaming=True)
|
| 86 |
+
else:
|
| 87 |
+
ds = load_dataset(dataset_id, split=split, streaming=True)
|
| 88 |
loaded = 0
|
| 89 |
for example in ds:
|
| 90 |
if loaded >= count:
|
|
|
|
| 205 |
|
| 206 |
trainer = SFTTrainer(
|
| 207 |
model=model,
|
| 208 |
+
processing_class=tokenizer,
|
| 209 |
train_dataset=dataset,
|
| 210 |
args=training_args,
|
| 211 |
max_seq_length=cfg.heal_seq_len,
|
|
|
|
| 333 |
|
| 334 |
trainer = Trainer(
|
| 335 |
model=model,
|
| 336 |
+
processing_class=tokenizer,
|
| 337 |
train_dataset=dataset,
|
| 338 |
args=training_args,
|
| 339 |
)
|