ceperaltab committed on
Commit
f368100
·
verified ·
1 Parent(s): 0cea67c

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +151 -0
train.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Elasticsearch Expert Model Training Script

Fine-tunes Qwen2.5-Coder-7B-Instruct using QLoRA for Elasticsearch Query & Mapping expertise.

Usage:
    python train.py

Requires:
    pip install -r requirements_train.txt
"""

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

# === CONFIGURATION - ELASTICSEARCH EXPERT MODEL ===

# Base model to fine-tune
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Dataset (Update with your actual HF dataset)
DATASET_NAME = "ceperaltab/elasticsearch-dataset"

# Output directory for the adapter
OUTPUT_DIR = "elasticsearch-expert"

# Hugging Face Hub settings
HF_USERNAME = "ceperaltab"
def main():
    """Fine-tune the base model with QLoRA on the Elasticsearch dataset.

    Loads the training split of DATASET_NAME, loads MODEL_NAME in 4-bit
    (NF4) quantization, attaches a LoRA adapter covering all attention and
    MLP projections, runs supervised fine-tuning via TRL's SFTTrainer,
    then saves the adapter to OUTPUT_DIR and pushes it to the Hub as
    HF_USERNAME/OUTPUT_DIR.
    """
    print("=" * 50)
    print("Elasticsearch Expert Model Training")
    print("=" * 50)

    # Load dataset
    print(f"\nLoading dataset from {DATASET_NAME}...")
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset size: {len(dataset)} examples")

    # 4-bit quantization (QLoRA) keeps the 7B base model within a
    # single-GPU memory budget; NF4 with fp16 compute is the standard
    # QLoRA configuration.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    print(f"\nLoading base model: {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Only fall back to EOS as the pad token when the tokenizer does not
    # already define one; overwriting an existing pad token can corrupt
    # padding/label masking during training.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right-padding is the expected side for causal-LM supervised fine-tuning.
    tokenizer.padding_side = "right"

    # LoRA Configuration - Full coverage as specified: every attention
    # projection (q/k/v/o) and every MLP projection (gate/up/down).
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,  # Rank
        bias="none",
        task_type="CAUSAL_LM",
        # Full target modules for comprehensive fine-tuning
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # Format chat messages using tokenizer's template (TRL v0.8.x API).
    # Each example's `messages` list is rendered into a single training
    # string; no generation prompt is appended since these are complete
    # conversations. Assumes dataset rows carry a `messages` column of
    # chat-format dicts — TODO confirm against the dataset schema.
    def formatting_prompts_func(examples):
        return [
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            for messages in examples["messages"]
        ]

    # Training Arguments (TRL v0.8.x uses TrainingArguments from transformers)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch size of 8
        learning_rate=2e-4,
        logging_steps=10,
        num_train_epochs=1,
        optim="paged_adamw_32bit",  # paged optimizer smooths QLoRA memory spikes
        fp16=True,
        group_by_length=True,  # batch similar lengths to reduce padding waste
        gradient_checkpointing=True,
        save_strategy="epoch",
        report_to="none",
        push_to_hub=True,
        hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
    )

    # SFTTrainer (TRL v0.8.x API) — wraps the quantized model with the
    # LoRA adapter and handles tokenization of the formatted texts.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_args,
    )

    print("\nStarting training...")
    print(f" Base model: {MODEL_NAME}")
    print(f" Dataset: {DATASET_NAME}")
    print(f" Output: {OUTPUT_DIR}")
    print(f" LoRA rank: {peft_config.r}")
    print(f" Target modules: {peft_config.target_modules}")

    trainer.train()

    # Save the adapter
    trainer.save_model(OUTPUT_DIR)
    print(f"\nTraining complete! Adapter saved to {OUTPUT_DIR}")

    # Push to Hub
    print(f"Pushing to Hugging Face Hub: {HF_USERNAME}/{OUTPUT_DIR}")
    trainer.push_to_hub()

    print("\nDone!")


if __name__ == "__main__":
    main()