sunkencity committed on
Commit
8257d75
·
verified ·
1 Parent(s): 53e0ec1

Upload train_survival.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_survival.py +16 -19
train_survival.py CHANGED
@@ -44,6 +44,18 @@ model = AutoModelForCausalLM.from_pretrained(
44
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
45
  tokenizer.pad_token = tokenizer.eos_token
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # LoRA
48
  peft_config = LoraConfig(
49
  r=16,
@@ -67,32 +79,17 @@ training_args = SFTConfig(
67
  fp16=True,
68
  packing=False,
69
  max_length=1024,
70
- dataset_text_field="text"
71
  )
72
 
73
- def formatting_prompts_func(example):
74
- output_texts = []
75
- instructions = example['instruction']
76
- responses = example['response']
77
-
78
- for i in range(len(instructions)):
79
- if i >= len(responses): break
80
- instruction = instructions[i]
81
- response = responses[i]
82
- if not instruction or not response: continue
83
-
84
- text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
85
- output_texts.append(text)
86
- return output_texts
87
-
88
  # Trainer
89
  trainer = SFTTrainer(
90
  model=model,
91
  train_dataset=dataset,
92
  peft_config=peft_config,
93
- formatting_func=formatting_prompts_func,
94
  args=training_args,
95
- processing_class=tokenizer, # CORRECTED: Using processing_class instead of tokenizer
 
96
  )
97
 
98
  print("Starting training...")
@@ -100,4 +97,4 @@ trainer.train()
100
 
101
  print("Pushing to hub...")
102
  trainer.push_to_hub()
103
- print("Done!")
 
44
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
45
  tokenizer.pad_token = tokenizer.eos_token
46
 
47
# MANUAL FORMATTING
# Formatting is applied ahead of time with dataset.map to avoid
# SFTTrainer batching issues with a formatting_func.
def format_row(example):
    """Render one instruction/response pair into a ChatML-style "text" field."""
    # NOTE(review): tokenizer.eos_token is appended after <|im_end|>; on
    # Qwen-family tokenizers eos_token is frequently <|im_end|> itself,
    # which would duplicate the end token — confirm this is intended.
    pieces = (
        "<|im_start|>user\n",
        example['instruction'],
        "<|im_end|>\n<|im_start|>assistant\n",
        example['response'],
        "<|im_end|>",
        tokenizer.eos_token,
    )
    return {"text": "".join(pieces)}

# Apply formatting manually
dataset = dataset.map(format_row)
59
  # LoRA
60
  peft_config = LoraConfig(
61
  r=16,
 
79
  fp16=True,
80
  packing=False,
81
  max_length=1024,
82
+ dataset_text_field="text" # Now this field exists and is correct
83
  )
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # Trainer
86
  trainer = SFTTrainer(
87
  model=model,
88
  train_dataset=dataset,
89
  peft_config=peft_config,
 
90
  args=training_args,
91
+ processing_class=tokenizer,
92
+ # Removed formatting_func argument
93
  )
94
 
95
  print("Starting training...")
 
97
 
98
  print("Pushing to hub...")
99
  trainer.push_to_hub()
100
+ print("Done!")