jowilke77 committed
Commit 37f6677 · verified · 1 Parent(s): 2062acc

Update train_lora.py

Files changed (1)
  1. train_lora.py +131 -21
train_lora.py CHANGED
@@ -1,47 +1,157 @@
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
- from peft import LoraConfig, get_peft_model
  from datasets import load_dataset

  MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

  lora_config = LoraConfig(
-     r=8,
-     lora_alpha=16,
-     target_modules=["q_proj", "v_proj"],
-     lora_dropout=0.05,
      bias="none",
      task_type="CAUSAL_LM"
  )

  model = get_peft_model(model, lora_config)

  dataset = load_dataset("json", data_files="train.jsonl")

- def tokenize(example):
-     text = tokenizer.apply_chat_template(
-         example["messages"],
-         tokenize=False
      )
-     return tokenizer(text, truncation=True)

- tokenized = dataset.map(tokenize)

- args = TrainingArguments(
      output_dir="./brad-ai-lora",
-     per_device_train_batch_size=1,
-     num_train_epochs=3,
-     learning_rate=2e-4,
-     logging_steps=5,
-     save_strategy="epoch"
  )

  trainer = Trainer(
      model=model,
-     args=args,
-     train_dataset=tokenized["train"]
  )

  trainer.train()
 
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
  from datasets import load_dataset
+ import torch
+ import json

  MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
+ MAX_LENGTH = 512

+ # Load tokenizer and model
+ print("Loading model and tokenizer...")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )

+ # Improved LoRA configuration
  lora_config = LoraConfig(
+     r=16,  # Increased from 8 for better capacity
+     lora_alpha=32,  # Increased from 16
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # More modules
+     lora_dropout=0.1,  # Increased for better regularization
      bias="none",
      task_type="CAUSAL_LM"
  )

  model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()

+ # Load and split dataset
+ print("Loading dataset...")
  dataset = load_dataset("json", data_files="train.jsonl")

+ # Split into train/validation (80/20)
+ split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
+ train_dataset = split_dataset["train"]
+ eval_dataset = split_dataset["test"]
+
+ print(f"Training samples: {len(train_dataset)}")
+ print(f"Validation samples: {len(eval_dataset)}")
+
+ def tokenize_function(examples):
+     """Tokenize the examples with proper formatting"""
+     texts = []
+     for messages in examples["messages"]:
+         # Apply chat template
+         text = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=False
+         )
+         texts.append(text)
+
+     # Tokenize with padding and truncation
+     tokenized = tokenizer(
+         texts,
+         truncation=True,
+         max_length=MAX_LENGTH,
+         padding="max_length",
+         return_tensors=None
      )
+
+     # Labels are the same as input_ids for causal LM
+     tokenized["labels"] = tokenized["input_ids"].copy()
+
+     return tokenized

+ # Tokenize datasets
+ print("Tokenizing datasets...")
+ tokenized_train = train_dataset.map(
+     tokenize_function,
+     batched=True,
+     remove_columns=train_dataset.column_names
+ )

+ tokenized_eval = eval_dataset.map(
+     tokenize_function,
+     batched=True,
+     remove_columns=eval_dataset.column_names
+ )
+
+ # Improved training arguments
+ training_args = TrainingArguments(
      output_dir="./brad-ai-lora",
+
+     # Training hyperparameters
+     num_train_epochs=5,  # Increased from 3
+     per_device_train_batch_size=2,  # Increased from 1
+     per_device_eval_batch_size=2,
+     gradient_accumulation_steps=4,  # Effective batch size = 8
+
+     # Learning rate and scheduling
+     learning_rate=3e-4,  # Slightly increased
+     lr_scheduler_type="cosine",  # Better than default
+     warmup_ratio=0.1,  # Warmup for 10% of training
+
+     # Optimization
+     optim="adamw_torch",
+     weight_decay=0.01,
+     max_grad_norm=1.0,
+
+     # Logging and evaluation
+     logging_steps=10,
+     eval_strategy="steps",
+     eval_steps=50,
+     save_strategy="steps",
+     save_steps=50,
+     save_total_limit=3,  # Keep only best 3 checkpoints
+
+     # Performance
+     fp16=True,  # Mixed precision training
+     dataloader_num_workers=2,
+
+     # Monitoring
+     load_best_model_at_end=True,
+     metric_for_best_model="eval_loss",
+     greater_is_better=False,
+
+     # Misc
+     report_to="none",  # Change to "tensorboard" if you want logging
+     seed=42
  )

+ # Create trainer
  trainer = Trainer(
      model=model,
+     args=training_args,
+     train_dataset=tokenized_train,
+     eval_dataset=tokenized_eval,
+     tokenizer=tokenizer
  )

+ # Train the model
+ print("Starting training...")
  trainer.train()
+
+ # Save the final model
+ print("Saving model...")
+ trainer.save_model("./brad-ai-lora-final")
+ tokenizer.save_pretrained("./brad-ai-lora-final")
+
+ # Evaluate final model
+ print("Final evaluation:")
+ eval_results = trainer.evaluate()
+ print(eval_results)
+
+ print("Training complete!")