File size: 4,163 Bytes
b50a848 4599e09 b50a848 4599e09 b50a848 4599e09 b50a848 4599e09 b50a848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
TrainingArguments
)
from trl import SFTTrainer
from peft import LoraConfig
# 1. Configuration
BASE_MODEL = "DeepSeek-Coder-V2-Lite-Instruct"
OUTPUT_DIR = "outputs/zenith-lora-simple"

# Training corpora: two standalone JSONL files plus the genesis_dataset_*
# family, enumerated by topic so additions stay one-line changes.
_GENESIS_TOPICS = (
    "identity",
    "code",
    "orchestration",
    "tools",
    "teaching",
    "generation",
)
DATA_FILES = [
    "data/zenith.jsonl",
    "data/training_data_v2.jsonl",
] + [f"data/genesis_dataset_{topic}.jsonl" for topic in _GENESIS_TOPICS]
# 2. Quantization Configuration
# Use bf16 compute on Ampere-class GPUs (compute capability >= 8), fp16
# otherwise. Capability is only queried when CUDA is actually available.
if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NF4 4-bit weight quantization
    bnb_4bit_use_double_quant=True,        # also quantize the quant constants
    bnb_4bit_compute_dtype=compute_dtype,
    llm_int8_enable_fp32_cpu_offload=True, # permit CPU offload of overflow layers
)
# 3. Load Model and Tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Some causal-LM tokenizers ship without a pad token; batched training needs
# one, so fall back to reusing EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the base model 4-bit quantized via the bnb_config defined above.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # Keep auto for now, it's the most flexible
    trust_remote_code=True,
)
# Disable the KV cache: it is useless during training and conflicts with
# gradient checkpointing (enabled in TrainingArguments below).
model.config.use_cache = False
# 4. Load and Prepare Dataset
print(f"Loading datasets: {DATA_FILES}")
# All JSONL files are concatenated into one flat "train" split.
dataset = load_dataset("json", data_files=DATA_FILES, split="train")
def _valid(example):
msgs = example.get("messages")
if not isinstance(msgs, list) or not msgs:
return False
for m in msgs:
if not isinstance(m, dict) or "role" not in m or "content" not in m:
return False
return True
def _to_text(example):
try:
text = tokenizer.apply_chat_template(
example["messages"], tokenize=False, add_generation_prompt=False
)
return {"text": text}
except Exception:
return {"text": ""}
def _has_text(example):
    # Keep only rows whose rendered "text" is a non-empty string.
    value = example.get("text")
    return isinstance(value, str) and len(value) > 0

dataset = dataset.filter(_valid)
dataset = dataset.map(_to_text, remove_columns=dataset.column_names)
dataset = dataset.filter(_has_text)

# 5. Create fixed train/validation split
print("Creating train/validation split...")
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]
# 6. LoRA Configuration
# Rank-16 adapters scaled by alpha=32 (effective scale alpha/r = 2.0) with
# 10% dropout; no bias terms are trained.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
# 7. Training Arguments
print("Defining training arguments...")
# Effective batch size = 1 (per device) * 4 (gradient accumulation) = 4.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,  # Lower learning rate for stability
    lr_scheduler_type="cosine",  # Cosine decay scheduler
    warmup_steps=50,  # Linear warmup before the cosine decay kicks in
    logging_steps=10,
    max_steps=200,
    save_steps=50,
    save_total_limit=2,  # Save only the best and the last checkpoints
    # NOTE(review): this kwarg was renamed to `eval_strategy` in newer
    # transformers releases — confirm the pinned version accepts this spelling.
    evaluation_strategy="steps",
    eval_steps=50,  # Keep aligned with save_steps for load_best_model_at_end
    load_best_model_at_end=True,  # Reload the best checkpoint when done
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # Lower eval_loss is better
    max_grad_norm=1.0,  # Gradient clipping
    # Exactly one of fp16/bf16 is set, matching the bnb compute dtype chosen
    # above (bf16 on capability >= 8, fp16 otherwise).
    fp16=compute_dtype == torch.float16,
    bf16=compute_dtype == torch.bfloat16,
    gradient_checkpointing=True,  # Trade recompute for activation memory
)
# 8. Initialize Trainer
print("Initializing trainer...")
# SFTTrainer wraps the quantized base model with the LoRA adapters from
# peft_config and tokenizes the pre-rendered "text" column.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    # NOTE(review): newer trl releases moved dataset_text_field /
    # max_seq_length / packing into SFTConfig — confirm the pinned trl
    # version still accepts them as direct kwargs.
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,
)
# 9. Train
print("Starting training...")
# Runs for max_steps (200) optimizer steps; because load_best_model_at_end
# is set, the lowest-eval-loss checkpoint is reloaded when training ends.
trainer.train()
# 10. Save Model
print("Saving final model...")
# presumably saves only the LoRA adapter weights given the peft_config,
# not the full quantized base model — TODO confirm for the pinned trl version
trainer.save_model(OUTPUT_DIR)
print(f"✅ Training complete! Model saved to {OUTPUT_DIR}")
|