Prithvik-1 commited on
Commit
8514fc9
·
verified ·
1 Parent(s): eada9ff

Upload scripts/training/finetune_mistral7b.py with huggingface_hub

Browse files
scripts/training/finetune_mistral7b.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fine-tuning script for Mistral models (7B, 3B, etc.) using LoRA (Low-Rank Adaptation)
4
+ This script uses Hugging Face Transformers, PEFT, and BitsAndBytes for efficient training.
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ from datasets import load_dataset
10
+ from transformers import (
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ TrainingArguments,
14
+ BitsAndBytesConfig,
15
+ Trainer,
16
+ DataCollatorForLanguageModeling
17
+ )
18
+ from peft import (
19
+ LoraConfig,
20
+ PeftModel,
21
+ get_peft_model,
22
+ prepare_model_for_kbit_training,
23
+ TaskType,
24
+ )
25
+ import json
26
+
27
+ def get_device_info():
28
+ """Detect and return available compute device"""
29
+ device_info = {
30
+ "device": "cpu",
31
+ "device_type": "cpu",
32
+ "use_quantization": False,
33
+ "dtype": torch.float32
34
+ }
35
+
36
+ if torch.cuda.is_available():
37
+ device_info["device"] = "cuda"
38
+ device_info["device_type"] = "cuda"
39
+ device_info["use_quantization"] = True
40
+ device_info["dtype"] = torch.float16
41
+ device_info["device_count"] = torch.cuda.device_count()
42
+ device_info["device_name"] = torch.cuda.get_device_name(0)
43
+ print(f"✓ CUDA GPU detected: {device_info['device_name']} (Count: {device_info['device_count']})")
44
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
45
+ device_info["device"] = "mps"
46
+ device_info["device_type"] = "mps"
47
+ device_info["use_quantization"] = False # BitsAndBytes doesn't support MPS
48
+ device_info["dtype"] = torch.float16
49
+ print("✓ Apple Silicon GPU (MPS) detected")
50
+ else:
51
+ print("⚠ No GPU detected, using CPU (training will be very slow)")
52
+ device_info["dtype"] = torch.float32
53
+
54
+ return device_info
55
+
56
+ # Defaults
57
+ DEFAULT_BASE_MODEL = "mistralai/Mistral-7B-v0.1"
58
+ DEFAULT_OUTPUT_DIR = "./mistral-finetuned"
59
+ DEFAULT_DATASET_PATH = "./training_data.jsonl" # Path to your training data
60
+
61
+ # LoRA Configuration - Updated with increased dropout for regularization
62
+ LORA_CONFIG = LoraConfig(
63
+ r=16, # Rank
64
+ lora_alpha=32, # LoRA alpha scaling parameter
65
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
66
+ lora_dropout=0.1, # Increased from 0.05 to 0.1 for better regularization
67
+ bias="none",
68
+ task_type=TaskType.CAUSAL_LM,
69
+ )
70
+
71
+ # BitsAndBytes Configuration for 4-bit quantization (CUDA only)
72
+ def get_bitsandbytes_config():
73
+ """Get BitsAndBytes config if CUDA is available, otherwise None"""
74
+ if torch.cuda.is_available():
75
+ return BitsAndBytesConfig(
76
+ load_in_4bit=True,
77
+ bnb_4bit_quant_type="nf4",
78
+ bnb_4bit_compute_dtype=torch.float16,
79
+ bnb_4bit_use_double_quant=True,
80
+ )
81
+ return None
82
+
83
+ def load_and_prepare_model(model_name: str, adapter_path: str | None = None):
84
+ """Load the specified Mistral model, optionally warm-starting from an existing LoRA adapter."""
85
+ device_info = get_device_info()
86
+ print(f"\nLoading model: {model_name}")
87
+
88
+ tokenizer_source = adapter_path if adapter_path and os.path.isdir(adapter_path) else model_name
89
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
90
+ if tokenizer.pad_token is None:
91
+ tokenizer.pad_token = tokenizer.eos_token
92
+ tokenizer.pad_token_id = tokenizer.eos_token_id
93
+
94
+ # Get quantization config (CUDA only)
95
+ bnb_config = get_bitsandbytes_config()
96
+
97
+ # Prepare model loading kwargs
98
+ model_kwargs = {
99
+ "trust_remote_code": True,
100
+ }
101
+
102
+ if bnb_config is not None:
103
+ # Use 4-bit quantization on CUDA
104
+ print("Using 4-bit quantization (CUDA)")
105
+ model_kwargs["quantization_config"] = bnb_config
106
+ model_kwargs["device_map"] = "auto"
107
+ elif device_info["device_type"] == "mps":
108
+ # Use MPS with float16
109
+ print(f"Using MPS device with {device_info['dtype']}")
110
+ model_kwargs["torch_dtype"] = device_info["dtype"]
111
+ model_kwargs["device_map"] = "auto"
112
+ else:
113
+ # CPU fallback
114
+ print("Using CPU (no quantization)")
115
+ model_kwargs["torch_dtype"] = torch.float32
116
+ model_kwargs["device_map"] = "cpu"
117
+
118
+ # Load base model
119
+ base_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
120
+
121
+ # Prepare model for k-bit training (only if using quantization)
122
+ if bnb_config is not None:
123
+ base_model = prepare_model_for_kbit_training(base_model)
124
+
125
+ if adapter_path:
126
+ print(f"Loading existing LoRA adapter from: {adapter_path}")
127
+ model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
128
+ else:
129
+ model = get_peft_model(base_model, LORA_CONFIG)
130
+
131
+ # Enable gradient checkpointing to save memory
132
+ model.gradient_checkpointing_enable()
133
+
134
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
135
+ total_params = sum(p.numel() for p in model.parameters())
136
+ print(f"Model loaded successfully!")
137
+ print(f" - Device: {device_info['device']}")
138
+ print(f" - Trainable parameters: {trainable_params:,}")
139
+ print(f" - Total parameters: {total_params:,}")
140
+ print(f" - Trainable ratio: {100 * trainable_params / total_params:.2f}%\n")
141
+
142
+ return model, tokenizer, device_info
143
+
144
+ def load_training_data(file_path):
145
+ """Load training data from JSONL file"""
146
+ print(f"Loading training data from {file_path}")
147
+
148
+ if not os.path.exists(file_path):
149
+ print(f"Warning: {file_path} not found. Creating a sample dataset...")
150
+ # Create a sample dataset for demonstration
151
+ sample_data = [
152
+ {"instruction": "What is AI?", "response": "AI (Artificial Intelligence) is the simulation of human intelligence by machines."},
153
+ {"instruction": "Explain machine learning", "response": "Machine learning is a subset of AI that enables systems to learn from data."},
154
+ ]
155
+ with open(file_path, 'w') as f:
156
+ for item in sample_data:
157
+ f.write(json.dumps(item) + '\n')
158
+ print(f"Sample dataset created at {file_path}")
159
+
160
+ data = []
161
+ with open(file_path, 'r') as f:
162
+ for line in f:
163
+ data.append(json.loads(line))
164
+
165
+ return data
166
+
167
+ def clean_completion(completion):
168
+ """Remove format markers from completion"""
169
+ if not completion:
170
+ return completion
171
+ # Remove format markers if present
172
+ if "### Strict JSON ###" in completion:
173
+ completion = completion.split("### Strict JSON ###")[1]
174
+ if "### End ###" in completion:
175
+ completion = completion.split("### End ###")[0]
176
+ return completion.strip()
177
+
178
+ def format_prompt(instruction, response=None):
179
+ """Format training examples as prompts"""
180
+ # Clean response to remove format markers
181
+ if response:
182
+ response = clean_completion(response)
183
+ prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
184
+ if response:
185
+ prompt += f"{response}"
186
+ return prompt
187
+
188
+ def tokenize_function(examples, tokenizer, max_length=512):
189
+ """Tokenize the training examples"""
190
+ texts = [format_prompt(inst, resp) for inst, resp in zip(examples["instruction"], examples["response"])]
191
+
192
+ tokenized = tokenizer(
193
+ texts,
194
+ truncation=True,
195
+ padding="max_length",
196
+ max_length=max_length,
197
+ return_tensors="pt"
198
+ )
199
+
200
+ tokenized["labels"] = tokenized["input_ids"].clone()
201
+ return tokenized
202
+
203
+ def main():
204
+ import argparse
205
+
206
+ parser = argparse.ArgumentParser(description="Fine-tune Mistral models with LoRA")
207
+ parser.add_argument("--base-model", default=DEFAULT_BASE_MODEL, help="HF model id (e.g. mistralai/Mistral-7B-v0.1 or mistralai/Mistral-3B-v0.1)")
208
+ parser.add_argument("--adapter-path", default=None, help="Optional path to existing LoRA adapters to continue training")
209
+ parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR, help="Where to write the fine-tuned adapters")
210
+ parser.add_argument("--dataset", default=DEFAULT_DATASET_PATH, help="Path to training data JSONL")
211
+ parser.add_argument("--max-length", type=int, default=512, help="Max sequence length for tokenization")
212
+ args = parser.parse_args()
213
+
214
+ print("Starting Mistral Fine-tuning with LoRA")
215
+ print("=" * 50)
216
+ print(f"Base model: {args.base_model}")
217
+ print(f"Training data: {args.dataset}")
218
+ print(f"Output dir: {args.output_dir}\n")
219
+
220
+ # Load model and tokenizer
221
+ model, tokenizer, device_info = load_and_prepare_model(args.base_model, args.adapter_path)
222
+
223
+ # Load training data
224
+ training_data = load_training_data(args.dataset)
225
+
226
+ # Convert to dataset format
227
+ instructions = []
228
+ responses = []
229
+
230
+ for item in training_data:
231
+ if "instruction" in item:
232
+ instructions.append(item["instruction"])
233
+ responses.append(item.get("response", ""))
234
+ elif "prompt" in item and "completion" in item:
235
+ instructions.append(item["prompt"])
236
+ completion_value = item["completion"]
237
+ if isinstance(completion_value, (dict, list)):
238
+ responses.append(json.dumps(completion_value))
239
+ else:
240
+ responses.append(str(completion_value))
241
+ elif "messages" in item:
242
+ messages = item["messages"]
243
+ if not isinstance(messages, list) or len(messages) == 0:
244
+ raise KeyError("'messages' entries must be non-empty lists.")
245
+
246
+ prompt_parts = []
247
+ assistant_reply = None
248
+
249
+ for idx, message in enumerate(messages):
250
+ role = message.get("role", "user")
251
+ content = str(message.get("content", "")).strip()
252
+
253
+ if idx == len(messages) - 1 and role == "assistant":
254
+ assistant_reply = content
255
+ else:
256
+ role_label = role.upper()
257
+ prompt_parts.append(f"{role_label}: {content}")
258
+
259
+ if assistant_reply is None:
260
+ assistant_reply = str(messages[-1].get("content", "")).strip()
261
+
262
+ prompt_text = "\n\n".join(part for part in prompt_parts if part)
263
+ instructions.append(prompt_text)
264
+ responses.append(assistant_reply)
265
+ else:
266
+ raise KeyError("Each training example must include either 'instruction'/'response', 'prompt'/'completion', or 'messages'.")
267
+
268
+ # Create a simple dataset dict
269
+ from datasets import Dataset
270
+ dataset = Dataset.from_dict({
271
+ "instruction": instructions,
272
+ "response": responses
273
+ })
274
+
275
+ # Tokenize dataset
276
+ print("Tokenizing dataset...")
277
+ tokenized_dataset = dataset.map(
278
+ lambda x: tokenize_function(x, tokenizer, max_length=args.max_length),
279
+ batched=True,
280
+ remove_columns=dataset.column_names
281
+ )
282
+
283
+ # Split dataset into train/validation (80/20)
284
+ print("Splitting dataset into train/validation (80/20)...")
285
+ train_val_split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
286
+ train_dataset = train_val_split["train"]
287
+ val_dataset = train_val_split["test"]
288
+
289
+ print(f" - Training samples: {len(train_dataset)}")
290
+ print(f" - Validation samples: {len(val_dataset)}")
291
+
292
+ # Training arguments - adjust based on device
293
+ use_fp16 = device_info["device_type"] in ["cuda", "mps"]
294
+
295
+ # Calculate total steps and appropriate warmup
296
+ effective_batch_size = (2 if device_info["device_type"] != "cpu" else 1) * 4 # batch_size * gradient_accumulation
297
+ total_steps = (len(train_dataset) // effective_batch_size) * 3 # 3 epochs
298
+ warmup_steps = max(10, int(0.1 * total_steps)) # 10% warmup, minimum 10 steps
299
+
300
+ print(f"\nTraining Configuration:")
301
+ print(f" - Total training steps: {total_steps}")
302
+ print(f" - Warmup steps: {warmup_steps} ({100*warmup_steps/total_steps:.1f}% of training)")
303
+
304
+ training_args = TrainingArguments(
305
+ output_dir=args.output_dir,
306
+ num_train_epochs=3,
307
+ per_device_train_batch_size=2 if device_info["device_type"] != "cpu" else 1,
308
+ gradient_accumulation_steps=4,
309
+ warmup_steps=warmup_steps, # Dynamic warmup (10% of total steps)
310
+ learning_rate=5e-5, # Reduced from 2e-4 to prevent overfitting
311
+ weight_decay=0.01, # Added L2 regularization
312
+ fp16=use_fp16, # Only enable on GPU (CUDA/MPS)
313
+ bf16=False, # Can enable for newer CUDA GPUs if needed
314
+ logging_steps=10,
315
+ save_steps=50, # Save more frequently
316
+ eval_strategy="steps", # Enable evaluation
317
+ eval_steps=50, # Evaluate every 50 steps
318
+ save_total_limit=3,
319
+ load_best_model_at_end=True, # Load best checkpoint based on validation loss
320
+ metric_for_best_model="eval_loss",
321
+ greater_is_better=False,
322
+ lr_scheduler_type="cosine", # Cosine learning rate decay
323
+ max_grad_norm=1.0, # Gradient clipping
324
+ report_to="none",
325
+ push_to_hub=False,
326
+ dataloader_pin_memory=device_info["device_type"] == "cuda", # Only pin memory for CUDA
327
+ remove_unused_columns=False,
328
+ )
329
+
330
+ print(f"Training Configuration:")
331
+ print(f" - Device: {device_info['device']}")
332
+ print(f" - Mixed precision (fp16): {use_fp16}")
333
+ print(f" - Batch size: {training_args.per_device_train_batch_size}")
334
+ print(f" - Gradient accumulation: {training_args.gradient_accumulation_steps}")
335
+ print(f" - Learning rate: {training_args.learning_rate}")
336
+ print(f" - Weight decay: {training_args.weight_decay}")
337
+ print(f" - LR scheduler: {training_args.lr_scheduler_type}")
338
+ print(f" - Max grad norm: {training_args.max_grad_norm}")
339
+ print("=" * 50)
340
+
341
+ # Data collator
342
+ data_collator = DataCollatorForLanguageModeling(
343
+ tokenizer=tokenizer,
344
+ mlm=False,
345
+ )
346
+
347
+ # Add early stopping callback
348
+ from transformers import EarlyStoppingCallback
349
+
350
+ # Create trainer with validation set and early stopping
351
+ trainer = Trainer(
352
+ model=model,
353
+ args=training_args,
354
+ train_dataset=train_dataset,
355
+ eval_dataset=val_dataset, # Add validation set
356
+ data_collator=data_collator,
357
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], # Stop if no improvement for 3 evals
358
+ )
359
+
360
+ # Train
361
+ print("\nStarting training...")
362
+ trainer.train()
363
+
364
+ # Save model
365
+ print(f"\nSaving fine-tuned model to {args.output_dir}")
366
+ trainer.save_model(args.output_dir)
367
+ tokenizer.save_pretrained(args.output_dir)
368
+
369
+ # Save LoRA adapters separately
370
+ model.save_pretrained(args.output_dir)
371
+
372
+ print("\nFine-tuning complete!")
373
+ print(f"Model saved to: {args.output_dir}")
374
+ print(f"To load for inference, use the inference script with: {args.output_dir}")
375
+
376
+ if __name__ == "__main__":
377
+ main()
378
+