ceperaltab committed
Commit 5d72e96 · verified · 1 Parent(s): 396494d

Upload train.py with huggingface_hub

Files changed (1):
  1. train.py  +24 -27
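
The commit title says train.py was pushed with the huggingface_hub client rather than through a git checkout. As a rough sketch, such an upload is usually a single upload_file call; the repo_id below is a placeholder, since the target repository is not named on this page:

from huggingface_hub import HfApi

api = HfApi()  # assumes a token from `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="train.py",
    path_in_repo="train.py",
    repo_id="ceperaltab/elixir-model-qwen",  # placeholder repo id, not taken from this commit
    commit_message="Upload train.py with huggingface_hub",
)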
train.py CHANGED

@@ -8,11 +8,9 @@ from transformers import (
     TrainingArguments,
 )
 from peft import LoraConfig
-from trl import SFTTrainer, SFTConfig
+from trl import SFTTrainer
 
 # --- CONFIGURATION ---
-# Base model: Using a quantized Llama 3 or Mistral is recommended for consumer GPUs.
-# Ensure you have access to the model on Hugging Face (might need login).
 MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
 DATASET_NAME = "ceperaltab/elixir-golden-dataset"
 OUTPUT_DIR = "elixir-model-qwen"
@@ -21,7 +19,6 @@ def main():
     print(f"Loading dataset from {DATASET_NAME}...")
     # 1. Load Dataset
     try:
-        # Load directly from HF Hub
         dataset = load_dataset(DATASET_NAME, split="train")
     except Exception as e:
         print(f"Error loading dataset: {e}")
@@ -46,9 +43,9 @@ def main():
     # 4. Load Tokenizer
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "right"  # Critical for fp16 training
+    tokenizer.padding_side = "right"
 
-    # 5. LoRA Config (Parameter Efficient Fine-Tuning)
+    # 5. LoRA Config
     peft_config = LoraConfig(
         lora_alpha=16,
         lora_dropout=0.1,
@@ -59,41 +56,41 @@ def main():
     )
 
     # 6. Formatting Function for Chat Dataset
-    # Converts {"messages": [...]} into the model's expected prompt format
     def formatting_prompts_func(examples):
         output_texts = []
         for messages in examples['messages']:
-            # Apply chat template (e.g., <|begin_of_text|><|start_header_id|>user...)
-            # We don't tokenize yet, SFTTrainer handles it
             text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
             output_texts.append(text)
         return output_texts
 
     print("Starting SFTTrainer setup...")
-    # 7. Trainer
+
+    # 7. Training Arguments (using transformers TrainingArguments for stability)
+    training_args = TrainingArguments(
+        output_dir=OUTPUT_DIR,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=4,
+        learning_rate=2e-4,
+        logging_steps=10,
+        num_train_epochs=1,
+        optim="paged_adamw_32bit",
+        fp16=True,
+        group_by_length=True,
+        save_strategy="epoch",
+        report_to="none",
+        push_to_hub=True,
+        hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
+    )
+
+    # 8. Trainer - use older stable API
     trainer = SFTTrainer(
         model=model,
         train_dataset=dataset,
         peft_config=peft_config,
         formatting_func=formatting_prompts_func,
         tokenizer=tokenizer,
-        args=SFTConfig(
-            output_dir=OUTPUT_DIR,
-            max_seq_length=2048,  # Moved here
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            learning_rate=2e-4,
-            logging_steps=10,
-            num_train_epochs=1,
-            optim="paged_adamw_32bit",
-            fp16=True,
-            group_by_length=True,
-            save_strategy="epoch",
-            report_to="none",
-            push_to_hub=True,
-            hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
-            dataset_text_field="text",  # SFTConfig requires this or packing, though we use formatting_func
-        ),
+        args=training_args,
+        max_seq_length=2048,  # Passed directly to SFTTrainer (old API)
     )
 
     print("Starting training...")