ceperaltab committed on
Commit
c30a02e
·
verified ·
1 Parent(s): 1ef766b

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +158 -0
train.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Neo4j Expert Model Training Script

Fine-tunes Qwen2.5-Coder-7B-Instruct using QLoRA for Neo4j/Cypher expertise.

Usage:
    python train.py

Requires:
    pip install -r requirements_train.txt
"""

import os

import torch
from datasets import load_dataset
from dotenv import load_dotenv
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

# Pull HF_* settings from a local .env file, if one exists.
load_dotenv()

# === CONFIGURATION - NEO4J EXPERT MODEL ===

# Base model to fine-tune.
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Training dataset on the Hugging Face Hub; overridable via environment.
DATASET_NAME = os.getenv("HF_DATASET_NAME", "your-username/neo4j-dataset")

# Local directory where the trained LoRA adapter is written.
OUTPUT_DIR = "neo4j-cypher-expert"

# Hub account used when pushing the adapter; overridable via environment.
HF_USERNAME = os.getenv("HF_USERNAME", "your-username")
def main():
    """Run the QLoRA fine-tune end to end: load data, train, save, push to Hub."""
    print("=" * 50)
    print("Neo4j Expert Model Training")
    print("=" * 50)

    # --- Dataset ---
    print(f"\nLoading dataset from {DATASET_NAME}...")
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset size: {len(dataset)} examples")

    # --- 4-bit NF4 quantization so the 7B base model fits in GPU memory ---
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    print(f"\nLoading base model: {MODEL_NAME}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # No dedicated pad token on this checkpoint; reuse EOS and right-pad.
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # --- LoRA over every attention and MLP projection (full coverage) ---
    peft_config = LoraConfig(
        r=64,  # Rank
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    def formatting_prompts_func(examples):
        """Render each chat transcript in the batch via the chat template."""
        return [
            tok.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
            for conversation in examples['messages']
        ]

    # --- Training hyperparameters ---
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch size of 8
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=10,
        optim="paged_adamw_32bit",
        fp16=True,
        group_by_length=True,
        gradient_checkpointing=True,
        save_strategy="epoch",
        report_to="none",
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        # Push checkpoints to the Hugging Face Hub.
        push_to_hub=True,
        hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
    )

    # --- Trainer ---
    trainer = SFTTrainer(
        model=base_model,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        max_seq_length=1024,
        tokenizer=tok,
    )

    print("\nStarting training...")
    print(f" Base model: {MODEL_NAME}")
    print(f" Dataset: {DATASET_NAME}")
    print(f" Output: {OUTPUT_DIR}")
    print(f" LoRA rank: {peft_config.r}")
    print(f" Target modules: {peft_config.target_modules}")

    trainer.train()

    # Persist the trained adapter locally.
    trainer.save_model(OUTPUT_DIR)
    print(f"\nTraining complete! Adapter saved to {OUTPUT_DIR}")

    # Upload the adapter to the Hub.
    print(f"Pushing to Hugging Face Hub: {HF_USERNAME}/{OUTPUT_DIR}")
    trainer.push_to_hub()

    print("\nDone!")


if __name__ == "__main__":
    main()