ceperaltab committed on
Commit
29ff030
·
verified ·
1 Parent(s): 7fc8323

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +12 -17
train.py CHANGED
@@ -13,31 +13,29 @@ Requires:
13
 
14
  import os
15
  import torch
16
- from dotenv import load_dotenv
17
  from datasets import load_dataset
18
  from transformers import (
19
  AutoModelForCausalLM,
20
  AutoTokenizer,
21
  BitsAndBytesConfig,
 
22
  )
23
  from peft import LoraConfig
24
- from trl import SFTTrainer, SFTConfig
25
-
26
- load_dotenv()
27
 
28
  # === CONFIGURATION - NEO4J EXPERT MODEL ===
29
 
30
  # Base model to fine-tune
31
  MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
32
 
33
- # Dataset - loaded from environment or use default
34
- DATASET_NAME = os.getenv("HF_DATASET_NAME", "ceperaltab/neo4j-cypher-dataset")
35
 
36
  # Output directory for the adapter
37
  OUTPUT_DIR = "neo4j-cypher-expert"
38
 
39
- # Hugging Face Hub settings - loaded from environment or use default
40
- HF_USERNAME = os.getenv("HF_USERNAME", "ceperaltab")
41
 
42
 
43
  def main():
@@ -55,7 +53,6 @@ def main():
55
  load_in_4bit=True,
56
  bnb_4bit_quant_type="nf4",
57
  bnb_4bit_compute_dtype=torch.float16,
58
- bnb_4bit_use_double_quant=True,
59
  )
60
 
61
  print(f"\nLoading base model: {MODEL_NAME}...")
@@ -89,7 +86,7 @@ def main():
89
  ],
90
  )
91
 
92
- # Format chat messages using tokenizer's template
93
  def formatting_prompts_func(examples):
94
  output_texts = []
95
  for messages in examples['messages']:
@@ -101,8 +98,8 @@ def main():
101
  output_texts.append(text)
102
  return output_texts
103
 
104
- # Training Arguments (SFTConfig for TRL 0.27+)
105
- training_args = SFTConfig(
106
  output_dir=OUTPUT_DIR,
107
  per_device_train_batch_size=1,
108
  gradient_accumulation_steps=8,
@@ -115,20 +112,18 @@ def main():
115
  gradient_checkpointing=True,
116
  save_strategy="epoch",
117
  report_to="none",
118
- warmup_steps=100,
119
- lr_scheduler_type="cosine",
120
- # Push to Hugging Face Hub
121
  push_to_hub=True,
122
  hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
123
  )
124
 
125
- # Initialize trainer (TRL 0.27+ API)
126
  trainer = SFTTrainer(
127
  model=model,
128
  train_dataset=dataset,
129
  peft_config=peft_config,
130
  formatting_func=formatting_prompts_func,
131
- processing_class=tokenizer, # renamed from 'tokenizer' in TRL 0.27+
 
132
  args=training_args,
133
  )
134
 
 
13
 
14
  import os
15
  import torch
 
16
  from datasets import load_dataset
17
  from transformers import (
18
  AutoModelForCausalLM,
19
  AutoTokenizer,
20
  BitsAndBytesConfig,
21
+ TrainingArguments,
22
  )
23
  from peft import LoraConfig
24
+ from trl import SFTTrainer
 
 
25
 
26
  # === CONFIGURATION - NEO4J EXPERT MODEL ===
27
 
28
  # Base model to fine-tune
29
  MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
30
 
31
+ # Dataset
32
+ DATASET_NAME = "ceperaltab/neo4j-cypher-dataset"
33
 
34
  # Output directory for the adapter
35
  OUTPUT_DIR = "neo4j-cypher-expert"
36
 
37
+ # Hugging Face Hub settings
38
+ HF_USERNAME = "ceperaltab"
39
 
40
 
41
  def main():
 
53
  load_in_4bit=True,
54
  bnb_4bit_quant_type="nf4",
55
  bnb_4bit_compute_dtype=torch.float16,
 
56
  )
57
 
58
  print(f"\nLoading base model: {MODEL_NAME}...")
 
86
  ],
87
  )
88
 
89
+ # Format chat messages using tokenizer's template (TRL v0.8.x API)
90
  def formatting_prompts_func(examples):
91
  output_texts = []
92
  for messages in examples['messages']:
 
98
  output_texts.append(text)
99
  return output_texts
100
 
101
+ # Training Arguments (TRL v0.8.x uses TrainingArguments from transformers)
102
+ training_args = TrainingArguments(
103
  output_dir=OUTPUT_DIR,
104
  per_device_train_batch_size=1,
105
  gradient_accumulation_steps=8,
 
112
  gradient_checkpointing=True,
113
  save_strategy="epoch",
114
  report_to="none",
 
 
 
115
  push_to_hub=True,
116
  hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
117
  )
118
 
119
+ # SFTTrainer (TRL v0.8.x API)
120
  trainer = SFTTrainer(
121
  model=model,
122
  train_dataset=dataset,
123
  peft_config=peft_config,
124
  formatting_func=formatting_prompts_func,
125
+ max_seq_length=1024,
126
+ tokenizer=tokenizer,
127
  args=training_args,
128
  )
129