amos1088 committed on
Commit
9730244
·
1 Parent(s): d8bb5bb
fix_compatibility.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
REM Reinstall a pinned, mutually compatible dependency set for DPO training
REM on Python 3.12.
REM
REM Every step goes through "python -m pip" so that packages land in the same
REM interpreter that later runs "python train_dpo_hf_fixed.py", and the script
REM stops at the first step that fails instead of reporting "Done!" regardless.
setlocal

echo Fixing Python 3.12 compatibility issues for DPO training...
echo.

REM Uninstall problematic packages
echo Removing conflicting packages...
python -m pip uninstall -y tensorflow keras protobuf || goto :error

REM Install tf-keras for compatibility
echo Installing tf-keras...
python -m pip install tf-keras || goto :error

REM Install specific protobuf version
echo Installing compatible protobuf...
python -m pip install protobuf==3.20.3 || goto :error

REM Install PyTorch with CUDA 11.8 support.
REM torch 2.1.x publishes no CPython 3.12 wheels, so the 2.2.x series (the
REM first one that does) is pinned together with its matching companions.
echo Installing PyTorch...
python -m pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 || goto :error

REM Install other dependencies with specific versions
REM NOTE(review): train_dpo_hf_fixed.py imports trl.trainer.dpo_config.DPOConfig;
REM confirm that the trl version pinned below actually provides it.
echo Installing other dependencies...
python -m pip install transformers==4.36.2 || goto :error
python -m pip install accelerate==0.25.0 || goto :error
python -m pip install peft==0.7.1 || goto :error
python -m pip install trl==0.7.10 || goto :error
python -m pip install bitsandbytes==0.42.0 || goto :error
python -m pip install datasets || goto :error
python -m pip install pandas || goto :error
python -m pip install scipy || goto :error
python -m pip install sentencepiece || goto :error

echo.
echo Done! Now try running: python train_dpo_hf_fixed.py
pause
exit /b 0

:error
echo.
echo Setup failed. Fix the error reported above and run this script again.
pause
exit /b 1
fix_compatibility.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Fix Python 3.12 compatibility issues for DPO training.
#
# Reinstalls a pinned dependency set. Every step goes through "python -m pip"
# so that packages land in the same interpreter that later runs
# "python train_dpo_hf_fixed.py".

# Abort on the first failing command (or unset variable) instead of carrying
# on and printing "Done!" over a half-installed environment.
set -euo pipefail

echo "Fixing Python 3.12 compatibility issues for DPO training..."
echo

# Uninstall problematic packages
echo "Removing conflicting packages..."
python -m pip uninstall -y tensorflow keras protobuf

# Install tf-keras for compatibility
echo "Installing tf-keras..."
python -m pip install tf-keras

# Install specific protobuf version
echo "Installing compatible protobuf..."
python -m pip install protobuf==3.20.3

# Install PyTorch with CUDA 11.8 support.
# torch 2.1.x publishes no CPython 3.12 wheels, so the 2.2.x series (the first
# one that does) is pinned together with its matching companion packages.
echo "Installing PyTorch..."
python -m pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118

# Install other dependencies with specific versions
# NOTE(review): train_dpo_hf_fixed.py imports trl.trainer.dpo_config.DPOConfig;
# confirm that the trl version pinned below actually provides it.
echo "Installing other dependencies..."
python -m pip install transformers==4.36.2
python -m pip install accelerate==0.25.0
python -m pip install peft==0.7.1
python -m pip install trl==0.7.10
python -m pip install bitsandbytes==0.42.0
python -m pip install datasets
python -m pip install pandas
python -m pip install scipy
python -m pip install sentencepiece

echo
echo "Done! Now try running: python train_dpo_hf_fixed.py"
setup_environment.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Setup script to ensure all dependencies are correctly installed
3
+ """
4
+
5
+ import subprocess
6
+ import sys
7
+ import os
8
+
9
def run_command(cmd):
    """Run a command and report whether it exited successfully.

    Args:
        cmd: Either a sequence of arguments, which is executed directly
            without a shell, or a single command string, which is handed to
            the shell exactly as before.

    Returns:
        True if the command exited with status 0; False if it exited with a
        non-zero status or could not be started at all.
    """
    # Only plain strings go through the shell. Argument lists bypass it so
    # that characters such as ">" in "torch>=2.0.0" are not treated as output
    # redirection and paths containing spaces need no quoting.
    use_shell = isinstance(cmd, str)
    try:
        subprocess.check_call(cmd, shell=use_shell)
    except (subprocess.CalledProcessError, OSError):
        return False
    return True


def _pip(*args):
    """Run ``pip`` for the interpreter that is executing this script."""
    return run_command([sys.executable, "-m", "pip", *args])


def main():
    """Install the DPO training dependencies and verify the environment.

    Steps: check the Python version, pin protobuf, install tf-keras, install
    the core and optional BEIR packages, report CUDA availability, test that
    the key modules import, and generate sample data when ``train.csv`` is
    missing. Individual installation failures are reported but do not abort
    the run; only an unsupported Python version exits (status 1).
    """
    import importlib

    print("🔧 Setting up environment for DPO training...")
    print("=" * 60)

    # Python version check
    version = sys.version_info
    print(f"Python version: {version.major}.{version.minor}.{version.micro}")
    if version < (3, 8):
        print("❌ Python 3.8+ is required!")
        sys.exit(1)

    # Fix protobuf issues
    print("\n📦 Fixing protobuf issues...")
    _pip("uninstall", "-y", "protobuf")
    _pip("install", "protobuf==3.20.3")

    # Install tf-keras for compatibility
    print("\n📦 Installing tf-keras for compatibility...")
    _pip("install", "tf-keras")

    # Core dependencies
    print("\n📦 Installing core dependencies...")
    dependencies = [
        "torch>=2.0.0",
        "transformers>=4.36.0",
        "datasets",
        "accelerate>=0.25.0",
        "peft>=0.7.0",
        "trl>=0.7.0",
        "bitsandbytes>=0.41.0",
        "pandas",
        "scipy",
        "sentencepiece",  # Required for some tokenizers
        "protobuf==3.20.3",  # Specific version to avoid issues
    ]
    for dep in dependencies:
        print(f"Installing {dep}...")
        if not _pip("install", dep):
            print(f"⚠️ Failed to install {dep}, continuing...")

    # BEIR dependencies (optional, so failures are not reported)
    print("\n📦 Installing BEIR dependencies (optional)...")
    for dep in ("beir", "scikit-learn", "tqdm"):
        print(f"Installing {dep}...")
        _pip("install", dep)

    # Packages installed after interpreter start-up may be invisible to the
    # import system until its directory caches are refreshed.
    importlib.invalidate_caches()

    # Check CUDA
    print("\n🔍 Checking CUDA availability...")
    try:
        import torch

        if torch.cuda.is_available():
            print("✅ CUDA is available!")
            print(f" Device: {torch.cuda.get_device_name(0)}")
            print(f" CUDA version: {torch.version.cuda}")
        else:
            print("⚠️ CUDA not available. Training will be slow on CPU.")
    except Exception as e:
        print(f"⚠️ Could not check CUDA: {e}")

    # Test imports
    print("\n🧪 Testing imports...")
    test_imports = [
        "torch",
        "transformers",
        "trl",
        "peft",
        "datasets",
        "accelerate",
        "bitsandbytes",
        "pandas",
    ]
    failed = []
    for module in test_imports:
        try:
            importlib.import_module(module)
        except ImportError as e:
            print(f"❌ {module}: {e}")
            failed.append(module)
        else:
            print(f"✅ {module}")

    if failed:
        print(f"\n⚠️ Some imports failed: {', '.join(failed)}")
        print("Try running: pip install --upgrade " + " ".join(failed))
    else:
        print("\n✅ All imports successful!")

    # Generate sample data if needed. run_command reports failure through its
    # return value rather than by raising, so the result is checked directly.
    if not os.path.exists("train.csv"):
        print("\n📊 Generating sample data...")
        if not run_command([sys.executable, "generate_sample_data.py"]):
            print("⚠️ Could not generate sample data")

    print("\n✅ Setup complete!")
    print("\nTo start training, run:")
    print(f" {sys.executable} train_dpo_hf_fixed.py")


if __name__ == "__main__":
    main()
train_dpo_hf_fixed.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DPO Training Script for Phi-3 Mini - Fixed version
3
+ Handles dependency issues and provides cleaner error handling
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+ # Set environment variables to avoid TensorFlow issues
13
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
15
+
16
+ try:
17
+ import torch
18
+ import pandas as pd
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ TrainingArguments,
23
+ TrainerCallback,
24
+ TrainerState,
25
+ TrainerControl
26
+ )
27
+ from trl import DPOTrainer
28
+ from trl.trainer.dpo_config import DPOConfig
29
+ from datasets import Dataset
30
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
31
+ from datetime import datetime
32
+ import logging
33
+ except ImportError as e:
34
+ print(f"Missing dependency: {e}")
35
+ print("\nPlease install required packages:")
36
+ print("pip install torch transformers trl peft datasets accelerate bitsandbytes pandas")
37
+ print("\nIf you get Keras errors, also run:")
38
+ print("pip install tf-keras")
39
+ sys.exit(1)
40
+
41
# Module-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
# HF_USERNAME and HF_TOKEN are read from the environment; the token is None
# when the variable is not set.
PROJECT_NAME = "phi3-dpo-beir"
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
HF_USERNAME = os.getenv("HF_USERNAME", "your-username")
HF_TOKEN = os.getenv("HF_TOKEN")
OUTPUT_DIR = "./" + PROJECT_NAME + "-checkpoints"
50
+
51
class ValidationCallback(TrainerCallback):
    """Periodically spot-check the model by generating on validation samples.

    Every ``eval_freq`` optimizer steps, up to five validation samples are
    re-formatted into an inference prompt, a response is generated, and the
    sample counts as correct when its ``chosen`` text appears
    (case-insensitively) in that response.
    """

    def __init__(self, tokenizer, val_dataset, eval_freq=500):
        """Store the tokenizer, the validation dataset and the step interval."""
        self.tokenizer = tokenizer
        self.val_dataset = val_dataset
        self.eval_freq = eval_freq

    def format_prompt_for_inference(self, query, document):
        """Format for inference matching evaluate.py style"""
        prompt = f"""You are an AI content analyst.

Task:
1. Given the following content and a user query, decide if the content is relevant.
2. If it is relevant:
- Extract the top 2-3 key sentences
- Suggest 3-5 relevant tags
- Provide a short explanation or content extension (~2-3 sentences)

Format your response in JSON with:
{{
"relevant": true or false,
"key_sentences": [...],
"tags": [...],
"expansion": "..."
}}

User Query:
{query}

Content:
{document}

Response:"""
        return prompt

    @staticmethod
    def _split_prompt(prompt_text):
        """Return ``(query, document)`` parsed from a training prompt.

        The query is the line following a ``Query:`` line and the document is
        everything after a ``Document:`` line (the last occurrence of each
        marker wins). Returns None when either part cannot be located.
        """
        lines = prompt_text.split("\n")
        query_idx = -1
        doc_idx = -1
        for i, line in enumerate(lines):
            marker = line.strip()
            if marker == "Query:":
                query_idx = i + 1
            elif marker == "Document:":
                doc_idx = i + 1

        if query_idx == -1 or doc_idx == -1 or query_idx >= len(lines):
            return None

        query = lines[query_idx].strip()
        document = "\n".join(lines[doc_idx:]).strip()
        return query, document

    def _generate_response(self, model, inference_prompt):
        """Generate a completion and return only the newly generated text."""
        inputs = self.tokenizer(
            inference_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_tokens = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.1,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Cut by token count rather than by len(prompt) characters: the prompt
        # may have been truncated to 512 tokens, and decoded text is not
        # guaranteed to reproduce the prompt string character for character.
        generated = outputs[0][prompt_tokens:]
        return self.tokenizer.decode(generated, skip_special_tokens=True).strip()

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Run validation every N steps"""
        if state.global_step <= 0 or state.global_step % self.eval_freq != 0:
            return control

        logger.info(f"\n🔍 Running custom validation at step {state.global_step}")

        model = kwargs["model"]
        was_training = model.training
        model.eval()
        try:
            self._validate(model)
        finally:
            # Leave the model in whatever mode the trainer had it in.
            if was_training:
                model.train()

        return control

    def _validate(self, model):
        """Score a fixed sample of the validation set and log the accuracy."""
        sample_size = min(5, len(self.val_dataset))
        samples = self.val_dataset.shuffle(seed=42).select(range(sample_size))

        correct = 0
        evaluated = 0
        for sample in samples:
            try:
                parsed = self._split_prompt(sample["prompt"])
                if parsed is None:
                    continue
                query, document = parsed

                inference_prompt = self.format_prompt_for_inference(query, document)
                response = self._generate_response(model, inference_prompt)

                # Simple accuracy check
                expected = sample["chosen"].lower()
                evaluated += 1
                if expected in response.lower():
                    correct += 1

                logger.info(f"Expected: {expected}, Got: {response[:100]}...")
            except Exception as e:
                logger.error(f"Validation error: {e}")
                continue

        # Only samples that were actually scored enter the denominator, so
        # unparseable prompts or generation errors are not reported as wrong
        # answers.
        if evaluated > 0:
            accuracy = correct / evaluated
            logger.info(
                f"✅ Validation accuracy: {accuracy:.2%} "
                f"({correct}/{evaluated} scored, {sample_size - evaluated} skipped)"
            )
        elif sample_size > 0:
            logger.warning(
                f"⚠️ Validation skipped: none of the {sample_size} samples could be scored"
            )
159
+
160
def prepare_datasets():
    """Load the DPO train/val/test CSV files from the working directory.

    ``train.csv`` is mandatory; ``val.csv`` and ``test.csv`` are optional.
    Every file that is used must provide the ``prompt``, ``chosen`` and
    ``rejected`` columns.

    Returns:
        A ``(train, val, test)`` tuple of ``datasets.Dataset`` objects. The
        optional entries are None when their file is absent or unusable. All
        three are None when the training data is missing or lacks a required
        column.
    """
    logger.info("📊 Loading datasets...")

    # Check if data files exist
    if not os.path.exists("train.csv"):
        logger.error("train.csv not found!")
        logger.info("Please run download_beir_datasets.py first or use generate_sample_data.py")
        return None, None, None

    required_columns = ("prompt", "chosen", "rejected")

    def load_split(path):
        """Return the split as a Dataset, or None if absent or malformed."""
        if not os.path.exists(path):
            return None
        frame = pd.read_csv(path)
        missing = [name for name in required_columns if name not in frame.columns]
        if missing:
            logger.error(f"{path} is missing required column(s): {', '.join(missing)}")
            return None
        return Dataset.from_pandas(frame)

    train_dataset = load_split("train.csv")
    if train_dataset is None:
        # Without valid training data there is nothing to train on; fail the
        # same way as for a missing file so the caller's None check applies.
        return None, None, None

    val_dataset = load_split("val.csv")
    test_dataset = load_split("test.csv")

    logger.info(f"✅ Loaded {len(train_dataset)} training examples")
    if val_dataset:
        logger.info(f"✅ Loaded {len(val_dataset)} validation examples")

    return train_dataset, val_dataset, test_dataset
185
+
186
def get_model_and_tokenizer():
    """Load the base model and its tokenizer.

    On CUDA machines the model is loaded 4-bit (NF4) quantized and prepared
    for k-bit training; if that fails, it is reloaded unquantized in float16.
    Without CUDA it is loaded in float32 on the CPU.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Imported here because the module-level transformers import does not
    # include it.
    from transformers import BitsAndBytesConfig

    logger.info(f"🤖 Loading model: {MODEL_ID}")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # Important for DPO

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logger.warning("⚠️ CUDA not available. Loading model in CPU mode (will be slow!)")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        return model, tokenizer

    # Quantization settings are passed as one explicit config object instead
    # of loose from_pretrained keyword arguments, which is the supported way
    # of requesting bitsandbytes quantization.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=quantization_config,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        model = prepare_model_for_kbit_training(model)
    except Exception as e:
        logger.error(f"Failed to load model in 4-bit: {e}")
        # The fallback is unquantized but still half precision.
        logger.info("Falling back to unquantized float16...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

    return model, tokenizer
229
+
230
def get_peft_config():
    """Return the LoRA configuration used for DPO fine-tuning.

    The adapters (rank 16, alpha 32, dropout 0.1, no bias terms) are attached
    to the attention and MLP projection layers.
    """
    return LoraConfig(
        r=16,
        lora_alpha=32,
        # Phi-3 fuses its projections: query/key/value live in "qkv_proj" and
        # gate/up in "gate_up_proj". The separate LLaMA-style names
        # (q_proj, k_proj, v_proj, gate_proj, up_proj) match nothing in this
        # architecture, which would leave only o_proj and down_proj adapted.
        target_modules=[
            "qkv_proj", "o_proj",
            "gate_up_proj", "down_proj",
        ],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
243
+
244
def main():
    """Run the complete DPO fine-tuning pipeline.

    Loads the CSV datasets and the (quantized) base model, configures and
    runs ``DPOTrainer`` with LoRA adapters, saves the final model under
    ``OUTPUT_DIR/final`` and, when ``HF_TOKEN`` is set, pushes it to the Hub.
    Each stage logs its failure and returns early instead of raising.
    """
    logger.info("=" * 60)
    logger.info("🚀 Starting DPO Training for Phi-3 Mini")
    logger.info("=" * 60)

    # Load datasets
    train_dataset, val_dataset, test_dataset = prepare_datasets()
    if train_dataset is None:
        return
    has_validation = val_dataset is not None and len(val_dataset) > 0

    # Load model and tokenizer
    try:
        model, tokenizer = get_model_and_tokenizer()
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return

    # LoRA config
    peft_config = get_peft_config()

    # The evaluation-strategy argument is called "evaluation_strategy" in
    # older transformers releases and "eval_strategy" in newer ones; use
    # whichever field this installation declares.
    config_fields = getattr(DPOConfig, "__dataclass_fields__", {})
    strategy_arg = "eval_strategy" if "eval_strategy" in config_fields else "evaluation_strategy"

    # Checkpoint cadence. load_best_model_at_end requires the save and
    # evaluation strategies to match and save_steps to be a multiple of
    # eval_steps, so both use the same interval.
    checkpoint_steps = 100
    strategy_kwargs = {strategy_arg: "steps" if has_validation else "no"}

    # Training arguments optimized for A10G
    try:
        training_args = DPOConfig(
            output_dir=OUTPUT_DIR,
            num_train_epochs=3,
            per_device_train_batch_size=2,  # A10G can handle this
            per_device_eval_batch_size=2,
            gradient_accumulation_steps=4,  # Effective batch size = 8
            gradient_checkpointing=True,
            learning_rate=5e-5,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=checkpoint_steps,  # Save every 100 steps
            eval_steps=checkpoint_steps,
            save_total_limit=5,  # Keep last 5 checkpoints
            # Picking a "best" checkpoint needs evaluation results, which
            # only exist when a validation set is available.
            load_best_model_at_end=has_validation,
            metric_for_best_model="loss",
            greater_is_better=False,

            # DPO specific
            beta=0.1,  # DPO regularization

            # Optimization
            optim="paged_adamw_8bit" if torch.cuda.is_available() else "adamw_torch",
            fp16=torch.cuda.is_available(),

            # Logging
            report_to="none",  # Disable wandb for simplicity
            run_name=f"{PROJECT_NAME}-{datetime.now().strftime('%Y%m%d-%H%M')}",

            # Hub integration
            push_to_hub=bool(HF_TOKEN),
            hub_model_id=f"{HF_USERNAME}/{PROJECT_NAME}" if HF_TOKEN else None,
            hub_strategy="checkpoint",  # Push every checkpoint
            hub_token=HF_TOKEN,
            **strategy_kwargs,
        )
    except Exception as e:
        logger.error(f"Failed to build training configuration: {e}")
        return

    # Initialize trainer
    # NOTE(review): the tokenizer / max_prompt_length / max_length keyword
    # arguments are accepted by DPOTrainer only in some trl releases; confirm
    # against the installed version.
    try:
        dpo_trainer = DPOTrainer(
            model=model,
            ref_model=None,  # Will create a reference model copy
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset if has_validation else None,
            tokenizer=tokenizer,
            peft_config=peft_config,
            max_prompt_length=512,
            max_length=768,
        )
    except Exception as e:
        logger.error(f"Failed to initialize trainer: {e}")
        return

    # Add custom validation callback
    if has_validation:
        dpo_trainer.add_callback(ValidationCallback(tokenizer, val_dataset))

    # Start training
    logger.info("🚀 Starting DPO training...")
    logger.info(f"💾 Checkpoints will be saved to: {OUTPUT_DIR}")
    if HF_TOKEN:
        logger.info(f"🤗 Model will be pushed to: https://huggingface.co/{HF_USERNAME}/{PROJECT_NAME}")

    # Print some info about the data
    logger.info("\n📊 Data Statistics:")
    logger.info(f"Training samples: {len(train_dataset)}")
    if has_validation:
        logger.info(f"Validation samples: {len(val_dataset)}")

    # Show a sample
    logger.info("\n📝 Sample training data:")
    sample = train_dataset[0]
    logger.info(f"Prompt (first 200 chars): {sample['prompt'][:200]}...")
    logger.info(f"Chosen: {sample['chosen']}")
    logger.info(f"Rejected: {sample['rejected']}")

    try:
        dpo_trainer.train()
    except KeyboardInterrupt:
        # An interrupted run still falls through to saving, so the progress
        # made so far is kept.
        logger.info("\n⚠️ Training interrupted by user")
    except Exception as e:
        logger.error(f"\n❌ Training failed: {e}")
        return

    # Save final model
    logger.info("💾 Saving final model...")
    dpo_trainer.save_model(f"{OUTPUT_DIR}/final")

    # Push to hub
    if HF_TOKEN:
        logger.info("🤗 Pushing final model to Hub...")
        try:
            dpo_trainer.push_to_hub()
        except Exception as e:
            logger.error(f"Failed to push to hub: {e}")

    logger.info("✅ Training complete!")
    logger.info(f"📁 Model saved to: {OUTPUT_DIR}/final")


if __name__ == "__main__":
    main()