sunkencity committed on
Commit
d71ac87
·
verified ·
1 Parent(s): 29ee62e

Upload train_survival.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_survival.py +6 -11
train_survival.py CHANGED
@@ -17,7 +17,7 @@ OUTPUT_MODEL_ID = "sunkencity/survival-expert-3b"
17
  # Load Dataset
18
  dataset = load_dataset(DATASET_ID, split="train")
19
 
20
- # Load Model with Quantization (for efficiency)
21
  bnb_config = BitsAndBytesConfig(
22
  load_in_4bit=True,
23
  bnb_4bit_quant_type="nf4",
@@ -33,7 +33,7 @@ model = AutoModelForCausalLM.from_pretrained(
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
36
- # LoRA Configuration
37
  peft_config = LoraConfig(
38
  r=16,
39
  lora_alpha=32,
@@ -43,8 +43,7 @@ peft_config = LoraConfig(
43
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
44
  )
45
 
46
- # Training Arguments
47
- # Removed max_seq_length from SFTConfig as it caused a TypeError
48
  training_args = SFTConfig(
49
  output_dir="./results",
50
  num_train_epochs=3,
@@ -55,32 +54,28 @@ training_args = SFTConfig(
55
  push_to_hub=True,
56
  hub_model_id=OUTPUT_MODEL_ID,
57
  fp16=True,
58
- dataset_text_field="text",
59
  packing=False
60
  )
61
 
62
- # Formatting function for SFT (Chat format)
63
  def formatting_prompts_func(example):
64
  output_texts = []
65
  for i in range(len(example['instruction'])):
66
  instruction = example['instruction'][i]
67
  response = example['response'][i]
68
-
69
- # Qwen/Llama chat template format
70
  text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
71
  output_texts.append(text)
72
  return output_texts
73
 
74
  # Trainer
75
- # max_seq_length is passed to SFTTrainer directly
76
  trainer = SFTTrainer(
77
  model=model,
78
  train_dataset=dataset,
79
  peft_config=peft_config,
80
  formatting_func=formatting_prompts_func,
81
  args=training_args,
82
- tokenizer=tokenizer,
83
- max_seq_length=1024
84
  )
85
 
86
  print("Starting training...")
 
17
  # Load Dataset
18
  dataset = load_dataset(DATASET_ID, split="train")
19
 
20
+ # Load Model
21
  bnb_config = BitsAndBytesConfig(
22
  load_in_4bit=True,
23
  bnb_4bit_quant_type="nf4",
 
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
36
+ # LoRA
37
  peft_config = LoraConfig(
38
  r=16,
39
  lora_alpha=32,
 
43
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
44
  )
45
 
46
+ # Args
 
47
  training_args = SFTConfig(
48
  output_dir="./results",
49
  num_train_epochs=3,
 
54
  push_to_hub=True,
55
  hub_model_id=OUTPUT_MODEL_ID,
56
  fp16=True,
57
+ dataset_text_field="text",
58
  packing=False
59
  )
60
 
 
61
def formatting_prompts_func(example):
    """Format a batch of instruction/response pairs into ChatML-style strings.

    Used as the SFTTrainer ``formatting_func``: receives a batched dataset
    slice and returns one training string per example, wrapped in the
    Qwen/ChatML ``<|im_start|>`` chat template.

    Args:
        example: Batched dataset slice with parallel lists under the
            ``'instruction'`` and ``'response'`` keys.

    Returns:
        list[str]: One formatted conversation per instruction/response pair.
    """
    # zip() walks the parallel columns directly instead of indexing with
    # range(len(...)); HF batched dataset columns are equal-length, so this
    # is equivalent to the index loop. A comprehension replaces the manual
    # append loop.
    return [
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{response}<|im_end|>"
        for instruction, response in zip(example['instruction'], example['response'])
    ]
69
 
70
# Trainer
# Reconstructed from the diff rendering: the scraped page interleaved diff
# line-number artifacts (e.g. a bare "72") inside the call's parentheses,
# which would not parse; this is the clean statement the commit introduces.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    formatting_func=formatting_prompts_func,
    args=training_args,
    # Newer TRL renamed the `tokenizer=` kwarg to `processing_class=`
    # (the commit's own note: "New name for tokenizer").
    processing_class=tokenizer,
    # NOTE(review): recent TRL versions expect max_seq_length on SFTConfig
    # rather than on SFTTrainer — confirm the installed TRL version still
    # accepts it here (the commit message says a config-level max_seq_length
    # previously raised TypeError, so this placement is version-sensitive).
    max_seq_length=1024,
)
80
 
81
  print("Starting training...")