Commit: Upload train_survival.py with huggingface_hub
Browse files — train_survival.py: +6 −11
train_survival.py
CHANGED
|
@@ -17,7 +17,7 @@ OUTPUT_MODEL_ID = "sunkencity/survival-expert-3b"
|
|
| 17 |
# Load Dataset
|
| 18 |
dataset = load_dataset(DATASET_ID, split="train")
|
| 19 |
|
| 20 |
-
# Load Model
|
| 21 |
bnb_config = BitsAndBytesConfig(
|
| 22 |
load_in_4bit=True,
|
| 23 |
bnb_4bit_quant_type="nf4",
|
|
@@ -33,7 +33,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 33 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 34 |
tokenizer.pad_token = tokenizer.eos_token
|
| 35 |
|
| 36 |
-
# LoRA
|
| 37 |
peft_config = LoraConfig(
|
| 38 |
r=16,
|
| 39 |
lora_alpha=32,
|
|
@@ -43,8 +43,7 @@ peft_config = LoraConfig(
|
|
| 43 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
| 44 |
)
|
| 45 |
|
| 46 |
-
#
|
| 47 |
-
# Removed max_seq_length from SFTConfig as it caused a TypeError
|
| 48 |
training_args = SFTConfig(
|
| 49 |
output_dir="./results",
|
| 50 |
num_train_epochs=3,
|
|
@@ -55,32 +54,28 @@ training_args = SFTConfig(
|
|
| 55 |
push_to_hub=True,
|
| 56 |
hub_model_id=OUTPUT_MODEL_ID,
|
| 57 |
fp16=True,
|
| 58 |
-
dataset_text_field="text",
|
| 59 |
packing=False
|
| 60 |
)
|
| 61 |
|
| 62 |
-
# Formatting function for SFT (Chat format)
|
| 63 |
def formatting_prompts_func(example):
|
| 64 |
output_texts = []
|
| 65 |
for i in range(len(example['instruction'])):
|
| 66 |
instruction = example['instruction'][i]
|
| 67 |
response = example['response'][i]
|
| 68 |
-
|
| 69 |
-
# Qwen/Llama chat template format
|
| 70 |
text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
|
| 71 |
output_texts.append(text)
|
| 72 |
return output_texts
|
| 73 |
|
| 74 |
# Trainer
|
| 75 |
-
# max_seq_length is passed to SFTTrainer directly
|
| 76 |
trainer = SFTTrainer(
|
| 77 |
model=model,
|
| 78 |
train_dataset=dataset,
|
| 79 |
peft_config=peft_config,
|
| 80 |
formatting_func=formatting_prompts_func,
|
| 81 |
args=training_args,
|
| 82 |
-
|
| 83 |
-
max_seq_length=1024
|
| 84 |
)
|
| 85 |
|
| 86 |
print("Starting training...")
|
|
|
|
| 17 |
# Load Dataset
|
| 18 |
dataset = load_dataset(DATASET_ID, split="train")
|
| 19 |
|
| 20 |
+
# Load Model
|
| 21 |
bnb_config = BitsAndBytesConfig(
|
| 22 |
load_in_4bit=True,
|
| 23 |
bnb_4bit_quant_type="nf4",
|
|
|
|
| 33 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 34 |
tokenizer.pad_token = tokenizer.eos_token
|
| 35 |
|
| 36 |
+
# LoRA
|
| 37 |
peft_config = LoraConfig(
|
| 38 |
r=16,
|
| 39 |
lora_alpha=32,
|
|
|
|
| 43 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
| 44 |
)
|
| 45 |
|
| 46 |
+
# Args
|
|
|
|
| 47 |
training_args = SFTConfig(
|
| 48 |
output_dir="./results",
|
| 49 |
num_train_epochs=3,
|
|
|
|
| 54 |
push_to_hub=True,
|
| 55 |
hub_model_id=OUTPUT_MODEL_ID,
|
| 56 |
fp16=True,
|
| 57 |
+
dataset_text_field="text",
|
| 58 |
packing=False
|
| 59 |
)
|
| 60 |
|
|
|
|
| 61 |
def formatting_prompts_func(example):
|
| 62 |
output_texts = []
|
| 63 |
for i in range(len(example['instruction'])):
|
| 64 |
instruction = example['instruction'][i]
|
| 65 |
response = example['response'][i]
|
|
|
|
|
|
|
| 66 |
text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
|
| 67 |
output_texts.append(text)
|
| 68 |
return output_texts
|
| 69 |
|
| 70 |
# Trainer
|
|
|
|
| 71 |
trainer = SFTTrainer(
|
| 72 |
model=model,
|
| 73 |
train_dataset=dataset,
|
| 74 |
peft_config=peft_config,
|
| 75 |
formatting_func=formatting_prompts_func,
|
| 76 |
args=training_args,
|
| 77 |
+
processing_class=tokenizer, # New name for tokenizer
|
| 78 |
+
max_seq_length=1024 # Passed here
|
| 79 |
)
|
| 80 |
|
| 81 |
print("Starting training...")
|