#!/usr/bin/env python3
"""
Neo4j Expert Model Training Script
Fine-tunes Qwen2.5-Coder-7B-Instruct using QLoRA for Neo4j/Cypher expertise.
Usage:
    python train.py

Requires:
    pip install -r requirements_train.txt
"""
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer
# === CONFIGURATION - NEO4J EXPERT MODEL ===
# Base model to fine-tune
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Dataset
DATASET_NAME = "ceperaltab/neo4j-cypher-dataset"
# Output directory for the adapter
OUTPUT_DIR = "neo4j-cypher-expert"
# Hugging Face Hub settings
HF_USERNAME = "ceperaltab"
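# NOTE: push_to_hub below requires prior authentication with write access,
# e.g. via `huggingface-cli login` or the HF_TOKEN environment variable.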
def main():
print("=" * 50)
print("Neo4j Expert Model Training")
print("=" * 50)
# Load dataset
print(f"\nLoading dataset from {DATASET_NAME}...")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Dataset size: {len(dataset)} examples")
# 4-bit Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
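    # NF4 stores the frozen base weights in 4 bits while matmuls run in
    # float16 (bnb_4bit_compute_dtype), matching fp16=True below. Optional
    # tweak, not in the original config: bnb_4bit_use_double_quant=True also
    # quantizes the quantization constants for a small extra memory saving.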
print(f"\nLoading base model: {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
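    # Reusing EOS as PAD is a common shortcut, but the default collator masks
    # pad tokens in the labels, so EOS itself gets no training signal; Qwen2.5
    # tokenizers also ship a dedicated <|endoftext|> pad token if you prefer.
    # Right padding is the safe choice for training (generation prefers left).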
    # LoRA configuration - covers all attention and MLP projections
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64, # Rank
bias="none",
task_type="CAUSAL_LM",
# Full target modules for comprehensive fine-tuning
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
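    # LoRA scales its update by lora_alpha / r = 16 / 64 = 0.25; targeting the
    # attention (q/k/v/o) and MLP (gate/up/down) projections adapts the whole
    # transformer block rather than attention only.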
# Format chat messages using tokenizer's template (TRL v0.8.x API)
def formatting_prompts_func(examples):
output_texts = []
for messages in examples['messages']:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
output_texts.append(text)
return output_texts
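    # For Qwen2.5 the chat template renders ChatML; illustrative output:
    #   <|im_start|>user\n<question><|im_end|>\n<|im_start|>assistant\n<cypher><|im_end|>
    # SFTTrainer tokenizes these rendered strings directly.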
# Training Arguments (TRL v0.8.x uses TrainingArguments from transformers)
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
learning_rate=2e-4,
logging_steps=10,
num_train_epochs=1,
optim="paged_adamw_32bit",
fp16=True,
group_by_length=True,
gradient_checkpointing=True,
save_strategy="epoch",
report_to="none",
push_to_hub=True,
hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
)
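    # Effective batch size = per_device_train_batch_size (1) x
    # gradient_accumulation_steps (8) = 8 sequences per optimizer step.
    # paged_adamw_32bit uses bitsandbytes' paged optimizer states, which can
    # spill to CPU memory to avoid OOM spikes on a single GPU.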
# SFTTrainer (TRL v0.8.x API)
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
formatting_func=formatting_prompts_func,
max_seq_length=1024,
tokenizer=tokenizer,
args=training_args,
)
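    # Optional sanity check (assumption: with peft_config passed, TRL wraps
    # the model in a PeftModel), showing how few parameters actually train:
    # trainer.model.print_trainable_parameters()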
print("\nStarting training...")
print(f" Base model: {MODEL_NAME}")
print(f" Dataset: {DATASET_NAME}")
print(f" Output: {OUTPUT_DIR}")
print(f" LoRA rank: {peft_config.r}")
print(f" Target modules: {peft_config.target_modules}")
trainer.train()
# Save the adapter
trainer.save_model(OUTPUT_DIR)
print(f"\nTraining complete! Adapter saved to {OUTPUT_DIR}")
# Push to Hub
print(f"Pushing to Hugging Face Hub: {HF_USERNAME}/{OUTPUT_DIR}")
trainer.push_to_hub()
print("\nDone!")
if __name__ == "__main__":
main()