File size: 4,026 Bytes
c30a02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ff030
c30a02e
 
29ff030
c30a02e
 
 
 
 
 
29ff030
 
c30a02e
 
 
 
29ff030
 
c30a02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ff030
c30a02e
 
 
 
 
 
 
 
 
 
 
29ff030
 
c30a02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ff030
c30a02e
 
 
 
 
29ff030
 
c30a02e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
"""
Neo4j Expert Model Training Script

Fine-tunes Qwen2.5-Coder-7B-Instruct using QLoRA for Neo4j/Cypher expertise.

Usage:
    python train.py

Requires:
    pip install -r requirements_train.txt
"""

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

# === CONFIGURATION - NEO4J EXPERT MODEL ===

# Base model to fine-tune (Hugging Face Hub repo id)
MODEL_NAME: str = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Dataset (Hub repo id; `main()` loads its "train" split of chat-format examples)
DATASET_NAME: str = "ceperaltab/neo4j-cypher-dataset"

# Output directory for the adapter (also reused as the Hub repo name below)
OUTPUT_DIR: str = "neo4j-cypher-expert"

# Hugging Face Hub settings (account the adapter is pushed to)
HF_USERNAME: str = "ceperaltab"


def main():
    """Fine-tune the base model on the Cypher dataset with QLoRA and push the adapter."""
    banner = "=" * 50
    print(banner)
    print("Neo4j Expert Model Training")
    print(banner)

    # Pull the chat-formatted training split from the Hub.
    print(f"\nLoading dataset from {DATASET_NAME}...")
    train_data = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset size: {len(train_data)} examples")

    # NF4 4-bit quantization keeps the 7B base model within a modest GPU budget.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Tokenizer setup: no dedicated pad token is assumed, so EOS is reused
    # and right-padding is chosen to match the causal-LM training setup.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print(f"\nLoading base model: {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # LoRA over every attention and MLP projection for full coverage.
    lora_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,  # Rank
        bias="none",
        task_type="CAUSAL_LM",
        # Full target modules for comprehensive fine-tuning
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    def render_batch(batch):
        # Serialize each conversation with the tokenizer's chat template
        # (TRL v0.8.x formatting_func API: batched in, list of strings out).
        return [
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
            for conversation in batch['messages']
        ]

    # Training Arguments (TRL v0.8.x uses TrainingArguments from transformers)
    hparams = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        num_train_epochs=1,
        optim="paged_adamw_32bit",
        fp16=True,
        group_by_length=True,
        gradient_checkpointing=True,
        save_strategy="epoch",
        report_to="none",
        push_to_hub=True,
        hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
    )

    # SFTTrainer (TRL v0.8.x API)
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_data,
        peft_config=lora_config,
        formatting_func=render_batch,
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=hparams,
    )

    print("\nStarting training...")
    print(f"  Base model: {MODEL_NAME}")
    print(f"  Dataset: {DATASET_NAME}")
    print(f"  Output: {OUTPUT_DIR}")
    print(f"  LoRA rank: {lora_config.r}")
    print(f"  Target modules: {lora_config.target_modules}")

    trainer.train()

    # Persist the LoRA adapter locally before mirroring it to the Hub.
    trainer.save_model(OUTPUT_DIR)
    print(f"\nTraining complete! Adapter saved to {OUTPUT_DIR}")

    print(f"Pushing to Hugging Face Hub: {HF_USERNAME}/{OUTPUT_DIR}")
    trainer.push_to_hub()

    print("\nDone!")


# Script entry point: run the full fine-tuning pipeline when executed directly.
if __name__ == "__main__":
    main()