td-builder committed on
Commit
2212c4a
·
verified ·
1 Parent(s): e123506

Upload 137 files

Browse files
Files changed (1) hide show
  1. hugging/td_fuse/heal.py +14 -9
hugging/td_fuse/heal.py CHANGED
@@ -66,20 +66,25 @@ def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
66
  print("[heal] Loading healing fine-tune data...")
67
 
68
  # Merge-specific: use diverse data that exercises all merged capabilities
 
69
  datasets_to_load = [
70
- # General language (from Pile)
71
- ("EleutherAI/pile", "validation", 500, "text"),
72
  # Math reasoning (exercises DeepSeek/MiMo contributions)
73
- ("openai/gsm8k", "train", 300, "question"),
74
- # Code (exercises Llama contribution)
75
- ("codeparrot/github-code", "train", 200, "code"),
76
  ]
77
 
78
  all_texts = []
79
 
80
- for dataset_id, split, count, text_field in datasets_to_load:
 
81
  try:
82
- ds = load_dataset(dataset_id, split=split, streaming=True, trust_remote_code=True)
 
 
 
83
  loaded = 0
84
  for example in ds:
85
  if loaded >= count:
@@ -200,7 +205,7 @@ def apply_qlora_unsloth(
200
 
201
  trainer = SFTTrainer(
202
  model=model,
203
- tokenizer=tokenizer,
204
  train_dataset=dataset,
205
  args=training_args,
206
  max_seq_length=cfg.heal_seq_len,
@@ -328,7 +333,7 @@ def apply_qlora_standard(
328
 
329
  trainer = Trainer(
330
  model=model,
331
- tokenizer=tokenizer,
332
  train_dataset=dataset,
333
  args=training_args,
334
  )
 
66
  print("[heal] Loading healing fine-tune data...")
67
 
68
  # Merge-specific: use diverse data that exercises all merged capabilities
69
+ # Each entry: (dataset_id, config_name_or_None, split, count, text_field)
70
  datasets_to_load = [
71
+ # General language - same calibration data source that works reliably
72
+ ("neuralmagic/LLM_compression_calibration", None, "train", 500, "text"),
73
  # Math reasoning (exercises DeepSeek/MiMo contributions)
74
+ ("openai/gsm8k", "main", "train", 300, "question"),
75
+ # Code - bigcode/starcoderdata is a modern alternative
76
+ ("bigcode/starcoderdata", "python", "train", 200, "content"),
77
  ]
78
 
79
  all_texts = []
80
 
81
+ for entry in datasets_to_load:
82
+ dataset_id, config_name, split, count, text_field = entry
83
  try:
84
+ if config_name:
85
+ ds = load_dataset(dataset_id, config_name, split=split, streaming=True)
86
+ else:
87
+ ds = load_dataset(dataset_id, split=split, streaming=True)
88
  loaded = 0
89
  for example in ds:
90
  if loaded >= count:
 
205
 
206
  trainer = SFTTrainer(
207
  model=model,
208
+ processing_class=tokenizer,
209
  train_dataset=dataset,
210
  args=training_args,
211
  max_seq_length=cfg.heal_seq_len,
 
333
 
334
  trainer = Trainer(
335
  model=model,
336
+ processing_class=tokenizer,
337
  train_dataset=dataset,
338
  args=training_args,
339
  )