stmasson committed on
Commit
81400b6
·
verified ·
1 Parent(s): 159c050

Upload scripts/train_sft_n8n_multitask.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_sft_n8n_multitask.py +22 -5
scripts/train_sft_n8n_multitask.py CHANGED
@@ -36,19 +36,36 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
36
  from trl import SFTTrainer, SFTConfig
37
 
38
 
39
- # Load multitask dataset
40
  print("Loading n8n-agentic-multitask dataset...")
41
- train_dataset = load_dataset(
 
 
42
  "stmasson/n8n-agentic-multitask",
43
  data_files="data/multitask_large/train.jsonl",
44
- split="train"
 
45
  )
46
- eval_dataset = load_dataset(
47
  "stmasson/n8n-agentic-multitask",
48
  data_files="data/multitask_large/val.jsonl",
49
- split="train"
 
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  print(f"Train: {len(train_dataset)} examples")
53
  print(f"Eval: {len(eval_dataset)} examples")
54
 
 
36
  from trl import SFTTrainer, SFTConfig
37
 
38
 
39
# Load multitask dataset.
# The splits are streamed so that the variable per-row schema of the
# 'metadata' field never reaches the Arrow schema-inference step (a
# non-streaming load_dataset() fails on it); only the 'messages' column
# — the one SFTTrainer needs — is kept and materialized in memory.
from datasets import Dataset

def _load_messages_split(data_file):
    """Stream one JSONL split of the multitask repo and return an
    in-memory Dataset holding only the 'messages' column.

    Args:
        data_file: repo-relative path of the split's JSONL file.

    Returns:
        datasets.Dataset with a single 'messages' column.
    """
    stream = load_dataset(
        "stmasson/n8n-agentic-multitask",
        data_files=data_file,
        split="train",
        streaming=True,
    )
    # Build plain dict rows and materialize with from_list: this avoids
    # Dataset.from_generator's lambda-fingerprint caching (fragile with
    # closures) and does not hard-code which extra columns ('task_type',
    # 'metadata', ...) must be dropped — anything but 'messages' is
    # simply never copied.
    return Dataset.from_list([{"messages": ex["messages"]} for ex in stream])

print("Loading n8n-agentic-multitask dataset...")
train_dataset = _load_messages_split("data/multitask_large/train.jsonl")
eval_dataset = _load_messages_split("data/multitask_large/val.jsonl")

print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")