sunkencity committed on
Commit
9828b8a
·
verified ·
1 Parent(s): 8257d75

Upload train_survival.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_survival.py +6 -8
train_survival.py CHANGED
@@ -39,21 +39,19 @@ model = AutoModelForCausalLM.from_pretrained(
39
  MODEL_ID,
40
  quantization_config=bnb_config,
41
  device_map="auto",
42
- use_cache=False
 
43
  )
44
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
45
  tokenizer.pad_token = tokenizer.eos_token
46
 
47
  # MANUAL FORMATTING
48
- # We do this manually to avoid SFTTrainer batching issues
49
  def format_row(example):
50
  instruction = example['instruction']
51
  response = example['response']
52
- # Qwen/Llama chat template format
53
  text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>{tokenizer.eos_token}"
54
  return {"text": text}
55
 
56
- # Apply formatting manually
57
  dataset = dataset.map(format_row)
58
 
59
  # LoRA
@@ -76,10 +74,11 @@ training_args = SFTConfig(
76
  logging_steps=10,
77
  push_to_hub=True,
78
  hub_model_id=OUTPUT_MODEL_ID,
79
- fp16=True,
 
80
  packing=False,
81
  max_length=1024,
82
- dataset_text_field="text" # Now this field exists and is correct
83
  )
84
 
85
  # Trainer
@@ -89,7 +88,6 @@ trainer = SFTTrainer(
89
  peft_config=peft_config,
90
  args=training_args,
91
  processing_class=tokenizer,
92
- # Removed formatting_func argument
93
  )
94
 
95
  print("Starting training...")
@@ -97,4 +95,4 @@ trainer.train()
97
 
98
  print("Pushing to hub...")
99
  trainer.push_to_hub()
100
- print("Done!")
 
39
  MODEL_ID,
40
  quantization_config=bnb_config,
41
  device_map="auto",
42
+ use_cache=False,
43
+ torch_dtype=torch.float16 # Explicitly set float16
44
  )
45
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
46
  tokenizer.pad_token = tokenizer.eos_token
47
 
48
  # MANUAL FORMATTING
 
49
def format_row(example):
    """Render one dataset row as a single chat-formatted training string.

    Wraps the row's instruction/response pair in Qwen-style
    <|im_start|>/<|im_end|> chat markers, appends the tokenizer's EOS
    token, and returns the result under the "text" key so SFTTrainer can
    consume it via `dataset_text_field="text"`.
    """
    user_turn = f"<|im_start|>user\n{example['instruction']}<|im_end|>\n"
    assistant_turn = f"<|im_start|>assistant\n{example['response']}<|im_end|>"
    # NOTE(review): eos_token is appended after <|im_end|>; if <|im_end|>
    # is already the EOS token this yields two terminators — confirm intended.
    return {"text": user_turn + assistant_turn + tokenizer.eos_token}
54
 
 
55
  dataset = dataset.map(format_row)
56
 
57
  # LoRA
 
74
  logging_steps=10,
75
  push_to_hub=True,
76
  hub_model_id=OUTPUT_MODEL_ID,
77
+ fp16=True, # Force FP16
78
+ bf16=False, # Disable BF16 explicitly
79
  packing=False,
80
  max_length=1024,
81
+ dataset_text_field="text"
82
  )
83
 
84
  # Trainer
 
88
  peft_config=peft_config,
89
  args=training_args,
90
  processing_class=tokenizer,
 
91
  )
92
 
93
  print("Starting training...")
 
95
 
96
  print("Pushing to hub...")
97
  trainer.push_to_hub()
98
+ print("Done!")