Upload train_survival.py with huggingface_hub
Browse files
- train_survival.py +6 -8
train_survival.py
CHANGED
|
@@ -39,21 +39,19 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 39 |
MODEL_ID,
|
| 40 |
quantization_config=bnb_config,
|
| 41 |
device_map="auto",
|
| 42 |
-
use_cache=False
|
|
|
|
| 43 |
)
|
| 44 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 45 |
tokenizer.pad_token = tokenizer.eos_token
|
| 46 |
|
| 47 |
# MANUAL FORMATTING
|
| 48 |
-
# We do this manually to avoid SFTTrainer batching issues
|
| 49 |
def format_row(example):
|
| 50 |
instruction = example['instruction']
|
| 51 |
response = example['response']
|
| 52 |
-
# Qwen/Llama chat template format
|
| 53 |
text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>{tokenizer.eos_token}"
|
| 54 |
return {"text": text}
|
| 55 |
|
| 56 |
-
# Apply formatting manually
|
| 57 |
dataset = dataset.map(format_row)
|
| 58 |
|
| 59 |
# LoRA
|
|
@@ -76,10 +74,11 @@ training_args = SFTConfig(
|
|
| 76 |
logging_steps=10,
|
| 77 |
push_to_hub=True,
|
| 78 |
hub_model_id=OUTPUT_MODEL_ID,
|
| 79 |
-
fp16=True,
|
|
|
|
| 80 |
packing=False,
|
| 81 |
max_length=1024,
|
| 82 |
-
dataset_text_field="text"
|
| 83 |
)
|
| 84 |
|
| 85 |
# Trainer
|
|
@@ -89,7 +88,6 @@ trainer = SFTTrainer(
|
|
| 89 |
peft_config=peft_config,
|
| 90 |
args=training_args,
|
| 91 |
processing_class=tokenizer,
|
| 92 |
-
# Removed formatting_func argument
|
| 93 |
)
|
| 94 |
|
| 95 |
print("Starting training...")
|
|
@@ -97,4 +95,4 @@ trainer.train()
|
|
| 97 |
|
| 98 |
print("Pushing to hub...")
|
| 99 |
trainer.push_to_hub()
|
| 100 |
-
print("Done!")
|
|
|
|
| 39 |
MODEL_ID,
|
| 40 |
quantization_config=bnb_config,
|
| 41 |
device_map="auto",
|
| 42 |
+
use_cache=False,
|
| 43 |
+
torch_dtype=torch.float16 # Explicitly set float16
|
| 44 |
)
|
| 45 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 46 |
tokenizer.pad_token = tokenizer.eos_token
|
| 47 |
|
| 48 |
# MANUAL FORMATTING
|
|
|
|
| 49 |
def format_row(example):
|
| 50 |
instruction = example['instruction']
|
| 51 |
response = example['response']
|
|
|
|
| 52 |
text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>{tokenizer.eos_token}"
|
| 53 |
return {"text": text}
|
| 54 |
|
|
|
|
| 55 |
dataset = dataset.map(format_row)
|
| 56 |
|
| 57 |
# LoRA
|
|
|
|
| 74 |
logging_steps=10,
|
| 75 |
push_to_hub=True,
|
| 76 |
hub_model_id=OUTPUT_MODEL_ID,
|
| 77 |
+
fp16=True, # Force FP16
|
| 78 |
+
bf16=False, # Disable BF16 explicitly
|
| 79 |
packing=False,
|
| 80 |
max_length=1024,
|
| 81 |
+
dataset_text_field="text"
|
| 82 |
)
|
| 83 |
|
| 84 |
# Trainer
|
|
|
|
| 88 |
peft_config=peft_config,
|
| 89 |
args=training_args,
|
| 90 |
processing_class=tokenizer,
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
print("Starting training...")
|
|
|
|
| 95 |
|
| 96 |
print("Pushing to hub...")
|
| 97 |
trainer.push_to_hub()
|
| 98 |
+
print("Done!")
|