# Zenith_Copilot / train_simple.py
# Uploaded by algorythmtechnologies ("Upload folder using huggingface_hub").
# Commit 4599e09 (verified).
import os
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
TrainingArguments
)
from trl import SFTTrainer
from peft import LoraConfig
# 1. Configuration
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
# NOTE(review): there is no organization prefix, so this presumably resolves
# to a local directory (or needs the full Hub repo id) -- confirm.
BASE_MODEL = "DeepSeek-Coder-V2-Lite-Instruct"
# Directory used as the Trainer output dir and as the final `save_model` target.
OUTPUT_DIR = "outputs/zenith-lora-simple"
# JSONL files passed together to `load_dataset("json", ...)`; paths are
# relative, so they depend on the working directory the script is run from.
DATA_FILES = [
    "data/zenith.jsonl",
    "data/training_data_v2.jsonl",
    "data/genesis_dataset_identity.jsonl",
    "data/genesis_dataset_code.jsonl",
    "data/genesis_dataset_orchestration.jsonl",
    "data/genesis_dataset_tools.jsonl",
    "data/genesis_dataset_teaching.jsonl",
    "data/genesis_dataset_generation.jsonl",
]
# 2. Quantization Configuration
# Compute in bfloat16 when CUDA is available and device 0 reports a compute
# capability major version of at least 8; otherwise fall back to float16.
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    else torch.float16
)

# 4-bit NF4 weights with double quantization; matmuls run in `compute_dtype`.
# NOTE(review): `llm_int8_enable_fp32_cpu_offload` presumably exists so that
# `device_map="auto"` may place overflow modules on the CPU -- confirm it is
# wanted for a training run.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
    llm_int8_enable_fp32_cpu_offload=True,
)
# 3. Load Model and Tokenizer
print("Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Batching needs a pad token; reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    # Keep auto for now, it's the most flexible
    device_map="auto",
    quantization_config=bnb_config,
)
# Turn the generation KV cache off for the training run.
model.config.use_cache = False
# 4. Load and Prepare Dataset
print(f"Loading datasets: {DATA_FILES}")
# Every file in DATA_FILES is read with the generic "json" builder and
# requested as the "train" split.
dataset = load_dataset(
    path="json",
    data_files=DATA_FILES,
    split="train",
)
def _valid(example):
msgs = example.get("messages")
if not isinstance(msgs, list) or not msgs:
return False
for m in msgs:
if not isinstance(m, dict) or "role" not in m or "content" not in m:
return False
return True
def _to_text(example):
    """Render one example's ``messages`` into a single training string.

    The module-level ``tokenizer`` applies its chat template without
    tokenizing and without appending a generation prompt.

    Rendering is best-effort: any failure (including a missing
    ``"messages"`` key) yields ``{"text": ""}`` instead of raising, so one
    bad row cannot abort the whole ``map``. The failure is logged so that
    dropped rows are visible rather than vanishing silently.

    Args:
        example: Mapping for one dataset row; expected to carry "messages".

    Returns:
        dict: ``{"text": <rendered conversation>}``, or ``{"text": ""}`` when
        the template could not be applied.
    """
    import logging  # local import keeps this block self-contained

    try:
        text = tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )
    except Exception as exc:  # template errors vary by tokenizer; stay best-effort
        logging.getLogger(__name__).warning(
            "Chat template failed; example will be dropped (%s: %s)",
            type(exc).__name__,
            exc,
        )
        return {"text": ""}
    return {"text": text}
# Keep only well-formed conversations, then swap every original column for
# the single rendered "text" column.
dataset = dataset.filter(_valid).map(
    _to_text, remove_columns=dataset.column_names
)
# Drop empty or pathological items
dataset = dataset.filter(
    lambda row: isinstance(row.get("text"), str) and row["text"] != ""
)

# 5. Create fixed train/validation split
print("Creating train/validation split...")
# 10% of the rows are held out for evaluation; the fixed seed keeps the
# split reproducible across runs.
split_dataset = dataset.train_test_split(seed=42, test_size=0.1)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]
# 6. LoRA Configuration
# Rank-16 adapters with alpha 32 and 10% dropout; bias terms are not trained.
# NOTE(review): no `target_modules` is given, so PEFT has to pick the layers
# itself -- confirm it can do so for this architecture.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)
# 7. Training Arguments
print("Defining training arguments...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # Schedule: 200 steps in total, 50 of them warmup, cosine scheduler.
    max_steps=200,
    warmup_steps=50,  # Warmup steps
    lr_scheduler_type="cosine",  # Cosine decay scheduler
    learning_rate=5e-5,  # Lower learning rate for stability
    # Batching: 1 sample per device, accumulated over 4 steps per update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    max_grad_norm=1.0,  # Gradient clipping
    # Mixed precision mirrors the dtype chosen for the quantized compute path.
    fp16=compute_dtype == torch.float16,
    bf16=compute_dtype == torch.bfloat16,
    # Logging / evaluation / checkpoint cadence, all measured in steps.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- confirm against the pinned version.
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    save_total_limit=2,  # Save only the best and the last checkpoints
    # Load the best model (lowest eval_loss) at the end of training.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
# 8. Initialize Trainer
print("Initializing trainer...")
# NOTE(review): newer TRL releases expect `dataset_text_field`,
# `max_seq_length` and `packing` on an SFTConfig instead of the trainer --
# confirm against the pinned version.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    packing=False,
)

# 9. Train
print("Starting training...")
trainer.train()

# 10. Save Model
print("Saving final model...")
trainer.save_model(OUTPUT_DIR)
print(f"✅ Training complete! Model saved to {OUTPUT_DIR}")