p11-p11 committed
Commit 50a0747 · verified · 1 parent: f22a74e

Upload train.py

Files changed (1)
  1. med/train.py +265 -0
med/train.py ADDED
@@ -0,0 +1,265 @@
+ from datasets import load_dataset
+ from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
+ import torch
+ from peft import LoraConfig, get_peft_model
+ import transformers
+ from datetime import datetime
+ import os
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # GPU 0 = RTX 3090, GPU 1 = RTX 2080
+
+ def apply_chat_template(example):
+     # Define the messages for the system, user, and assistant roles
+     messages = [
+         {
+             "role": "system",
+             "content": "You are a chess grandmaster specializing in finding checkmate moves in any chess position."
+         },
+         {
+             "role": "user",
+             "content": f"Given the following chessboard, identify the move that delivers checkmate:\n\n{example['board']}\n\n"
+         },
+         {
+             "role": "assistant",
+             "content": f"The move to achieve checkmate is: {example['mate']}"
+         }
+     ]
+
+     # Format the text manually by concatenating the message contents
+     formatted_text = ""
+     for msg in messages:
+         formatted_text += f"{msg['content']} "
+
+     example["text"] = formatted_text.strip()  # Remove the trailing space
+     return example
+
+
+ def main():
+     # Define the local paths to the CSV files
+     data_files = {
+         'train': '/home/luciano/Documents/Tesis Ezequiel/Tesis/data_boards/high_train.csv',
+         'test': '/home/luciano/Documents/Tesis Ezequiel/Tesis/data_boards/high_test.csv',
+     }
+
+     # Load the dataset from the local CSV files
+     dataset = load_dataset(
+         'csv',
+         data_files=data_files,
+         delimiter=',',              # Delimiter for the CSV files
+         usecols=['board', 'mate'],  # Load only the required columns
+         on_bad_lines='skip',        # Skip rows that cause parsing errors
+     )
+
+     # Use the full train and test splits
+     train_dataset = dataset['train']
+     eval_dataset = dataset['test']
+
+     print('Train Dataset:', train_dataset, '\nTest Dataset:', eval_dataset)
+
+     # Apply the chat template to both splits
+     train_dataset = train_dataset.map(
+         apply_chat_template,
+         num_proc=2,
+         # remove_columns=['board', 'mate']  # kept: the tokenization step below still needs these columns
+     )
+
+     eval_dataset = eval_dataset.map(
+         apply_chat_template,
+         num_proc=2,
+         # remove_columns=['board', 'mate'],
+         desc="Applying chat template"
+     )
+
+     # Inspect the first example after applying the chat template
+     print("\nFirst Training Example Text:\n", train_dataset[0]['text'])
+
+     # Configure 4-bit NF4 quantization
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+
+     model_id = 'mistralai/Mistral-7B-Instruct-v0.3'
+
+     # Load the quantized base model
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         attn_implementation='eager',
+         trust_remote_code=True,
+         quantization_config=quantization_config,
+         device_map="auto"
+     )
+
+     print("Model is loaded on device:", next(model.parameters()).device)  # Should print cuda:0 when loaded onto the GPU
+
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_id,
+         padding_side="right",  # Right padding to match the label-masking strategy below
+         use_fast=False,  # needed for now, should be fixed soon
+     )
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Verify the tokenizer's special tokens
+     print("\nTokenizer Special Tokens:")
+     print("EOS Token:", tokenizer.eos_token)
+     print("BOS Token:", tokenizer.bos_token)
+     print("PAD Token:", tokenizer.pad_token)
+
+     def generate_and_tokenize_prompt(data_point):
+         # Define the prompt and the expected response
+         prompt = (
+             "You are a chess grandmaster specializing in finding checkmate moves in any chess position. "
+             "Given the following chessboard, identify the move that delivers checkmate:\n\n"
+             f"{data_point['board']}\n\n"
+         )
+         response = f"The move to achieve checkmate is: {data_point['mate']}"
+
+         # Tokenize prompt and response together, padded/truncated to a fixed length
+         tokenized = tokenizer(
+             prompt + response,
+             padding='max_length',
+             truncation=True,
+             max_length=200,
+             return_tensors='pt',
+         )
+
+         input_ids = tokenized['input_ids'][0].tolist()
+         attention_mask = tokenized['attention_mask'][0].tolist()
+
+         # Tokenize the prompt alone to find where the response starts.
+         # Special tokens are kept so the BOS token is counted, matching the
+         # combined sequence above.
+         prompt_tokenized = tokenizer(
+             prompt,
+             return_tensors='pt'
+         )
+         prompt_length = prompt_tokenized['input_ids'].shape[1]
+
+         # Create labels: mask the prompt tokens with -100 so the loss is
+         # computed only on the response tokens
+         labels = [-100] * prompt_length + input_ids[prompt_length:]
+
+         # If the total length is less than max_length, pad the remaining labels with -100
+         if len(labels) < 200:
+             labels += [-100] * (200 - len(labels))
+         else:
+             labels = labels[:200]
+
+         # Also mask the padding positions: the pad token is the EOS token here,
+         # so unmasked padding would otherwise contribute to the loss
+         labels = [label if mask == 1 else -100 for label, mask in zip(labels, attention_mask)]
+
+         # Ensure input_ids, attention_mask, and labels are exactly 200 tokens
+         input_ids = input_ids[:200]
+         attention_mask = attention_mask[:200]
+         labels = labels[:200]
+
+         """ # Debug prints to verify correctness
+         print("\n--- Tokenization Debug ---")
+         print("Prompt Text:\n", prompt)
+         print("Response Text:\n", response)
+         print("Prompt Token IDs:", prompt_tokenized['input_ids'][0].tolist())
+         print("Response Token IDs:", input_ids[prompt_length:])
+         print("Combined Input IDs:", input_ids)
+         print("Combined Attention Mask:", attention_mask)
+         print("Combined Labels:", labels)
+         print("Decoded Input IDs:\n", tokenizer.decode(input_ids, skip_special_tokens=False))
+         print("--- End of Debug ---\n")"""
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'labels': labels
+         }
+
+     # Tokenize both datasets; the intermediate 'text' column is dropped
+     tokenized_train_dataset = train_dataset.map(
+         generate_and_tokenize_prompt,
+         remove_columns=['text'],
+         batched=False,
+     )
+
+     tokenized_val_dataset = eval_dataset.map(
+         generate_and_tokenize_prompt,
+         remove_columns=['text'],
+         batched=False,
+     )
+
+     # Inspect a sample from the tokenized training dataset
+     sample = tokenized_train_dataset[0]
+     print("\n--- Tokenized Sample ---")
+     print("Input IDs:", sample['input_ids'])
+     print("Attention Mask:", sample['attention_mask'])
+     print("Labels:", sample['labels'])
+     print("Decoded Input IDs:\n", tokenizer.decode(sample['input_ids'], skip_special_tokens=False))
+     print("--- End of Sample ---\n")
+
+     # Set up LoRA on the attention projection matrices
+     lora_config = LoraConfig(
+         r=64,
+         lora_alpha=16,
+         lora_dropout=0.1,
+         bias="none",
+         task_type="CAUSAL_LM",
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+     )
+
+     model = get_peft_model(model, lora_config)
+
+     project = "tesis"
+     base_model_name = "med"
+     run_name = f"{base_model_name}-{project}"
+     output_dir = f"./{run_name}"
+
+     # Define the TrainingArguments
+     training_args = transformers.TrainingArguments(
+         output_dir=output_dir,
+         max_grad_norm=1.0,  # Clip gradients to prevent exploding gradients
+         warmup_steps=100,
+         num_train_epochs=1,  # Adjust as needed
+         per_device_train_batch_size=11,  # fits on the RTX 3090
+         per_device_eval_batch_size=10,   # fits on the RTX 3090
+         gradient_accumulation_steps=4,   # Simulates a larger effective batch size
+         evaluation_strategy="epoch",
+         eval_steps=50,     # Only used if evaluation_strategy is switched to "steps"
+         save_steps=1000,   # Adjust based on dataset size
+         logging_steps=10,  # Frequent logging for debugging
+         learning_rate=1e-5,
+         fp16=True,  # note: the bnb compute dtype above is bfloat16
+         logging_dir=r"/home/luciano/Documents/Tesis Ezequiel/Tesis/med/logs_med",
+         report_to="tensorboard",  # Switch to "wandb" if needed
+         run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
+     )
+
+     # Initialize the Trainer. default_data_collator preserves the custom
+     # labels built above; DataCollatorForLanguageModeling(mlm=False) would
+     # overwrite them with a copy of input_ids and undo the prompt masking.
+     trainer = transformers.Trainer(
+         model=model,
+         train_dataset=tokenized_train_dataset,
+         eval_dataset=tokenized_val_dataset,
+         args=training_args,
+         data_collator=transformers.default_data_collator,
+     )
+
+     # Disable the KV cache to silence warnings during training
+     model.config.use_cache = False
+
+     # Start training, resuming from the saved checkpoint
+     trainer.train(resume_from_checkpoint=r'/home/luciano/Documents/Tesis Ezequiel/Tesis/med/med_checkpoint')
+     # trainer.train()  # Use this instead to train from scratch
+
+     # Save the model and tokenizer
+     trainer.save_model("./fine-tuned-model_high")
+     tokenizer.save_pretrained("./fine-tuned-model_high")
+
+
+ if __name__ == "__main__":
+     main()
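
After training, trainer.save_model writes only the LoRA adapter weights to ./fine-tuned-model_high, so inference requires re-loading the quantized base model and attaching the adapter with peft. Below is a minimal sketch, assuming the same model id and output paths as in train.py; the `board` placeholder is hypothetical and stands in for a board string formatted like the 'board' CSV column:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

# Load the 4-bit base model exactly as in train.py
base = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.3',
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",
)
# Attach the fine-tuned LoRA adapter saved by trainer.save_model
model = PeftModel.from_pretrained(base, "./fine-tuned-model_high")
tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model_high", use_fast=False)

board = "..."  # hypothetical: a board string in the same format as the training data
prompt = (
    "You are a chess grandmaster specializing in finding checkmate moves in any chess position. "
    f"Given the following chessboard, identify the move that delivers checkmate:\n\n{board}\n\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
# Decode only the newly generated tokens after the prompt
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Greedy decoding (do_sample=False) fits this task, since the expected completion is a single short move string in the fixed "The move to achieve checkmate is: ..." format the model was trained on.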