File size: 15,536 Bytes
5106722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
"""
Model training script for financial LLM fine-tuning
"""

import torch
import json
from datetime import datetime
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    DataCollatorForLanguageModeling,
    default_data_collator,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType


def setup_model_and_tokenizer(config):
    """Load the base model and tokenizer, applying optional quantization.

    Expects a CUDA device to be present (memory diagnostics are printed
    along the way). Returns a ``(model, tokenizer)`` tuple.
    """

    # TF32 matmul is an essentially free speedup on Ampere-class GPUs (A100).
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("βœ… TF32 enabled for faster matmul")
    except Exception:
        pass

    # Report GPU memory before loading anything heavy.
    torch.cuda.empty_cache()
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    used_gb = torch.cuda.memory_allocated() / 1e9
    free_gb = total_gb - used_gb

    print("πŸ”‹ A100 Memory Status:")
    print(f"   Total: {total_gb:.1f} GB")
    print(f"   Free: {free_gb:.1f} GB")

    if free_gb < 15:
        print("⚠️ Warning: Low GPU memory, consider clearing cache")

    # Quantization mode defaults to 8-bit to keep memory in check.
    quant_mode = config.get("quantization")
    if quant_mode is None:
        quant_mode = "8bit"
    print(f"βš™οΈ Quantization mode: {quant_mode}")

    # Build the bitsandbytes config for the selected mode (None = full precision).
    if quant_mode == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quant_mode == "8bit":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        bnb_config = None

    print(f"Loading tokenizer: {config['model_name']}")
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # right-padding keeps causal-LM labels aligned

    print(f"Loading model: {config['model_name']}")
    model_kwargs = {
        "device_map": {"": 0},  # pin every layer to GPU 0
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }

    # Resolve the attention backend: flash-attn v2 when importable,
    # otherwise PyTorch SDPA, with "eager" as an explicit opt-in.
    requested = config.get("attn_impl")  # "flash" | "sdpa" | "eager" | None
    if requested in ("sdpa", "eager"):
        chosen_attn = requested
    elif requested == "flash":
        try:
            import flash_attn  # noqa: F401
            chosen_attn = "flash_attention_2"
        except Exception:
            print("⚠️ flash-attn not available; falling back to SDPA")
            chosen_attn = "sdpa"
    else:
        # No explicit preference: use flash if importable, otherwise SDPA.
        try:
            import flash_attn  # noqa: F401
            chosen_attn = "flash_attention_2"
        except Exception:
            chosen_attn = "sdpa"

    # Honored by transformers >= 4.39 for Llama-family models.
    model_kwargs["attn_implementation"] = chosen_attn
    print(f"βœ… Attention implementation: {chosen_attn}")
    if bnb_config is not None:
        model_kwargs["quantization_config"] = bnb_config

    model = AutoModelForCausalLM.from_pretrained(config['model_name'], **model_kwargs)

    model.config.use_cache = False  # KV cache is incompatible with grad checkpointing
    model.config.pretraining_tp = 1
    # Mirror the tokenizer's pad token onto the model config for train/eval.
    if getattr(model.config, "pad_token_id", None) is None and tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Trade compute for memory via gradient checkpointing (on by default).
    try:
        if config.get('gradient_checkpointing', True):
            model.gradient_checkpointing_enable()
            print("βœ… Model gradient checkpointing enabled")
    except Exception:
        pass

    # Post-load memory report.
    used_after_gb = torch.cuda.memory_allocated() / 1e9
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    pct_used = (used_after_gb / total_gb) * 100

    print("Model loaded successfully!")
    print(f"Model parameters: {model.num_parameters():,}")
    print(f"πŸ”‹ GPU Memory after loading: {used_after_gb:.1f}/{total_gb:.1f} GB ({pct_used:.1f}%)")

    if pct_used > 85:
        print("⚠️ Warning: High GPU memory usage! Consider reducing batch size.")
    else:
        print("βœ… GPU memory usage looks good for training!")

    return model, tokenizer


def setup_lora(model, config):
    """Wrap *model* with LoRA adapters sized for its architecture.

    Target modules are chosen from the model name; hyperparameters are
    read from *config* with safe defaults. Returns the PEFT-wrapped model.
    """

    # Pick adapter injection points by model family.
    model_name = config['model_name']
    if "DialoGPT" in model_name:
        # GPT-2 style fused attention / projection layers.
        target_modules = ["c_attn", "c_proj"]
    elif "Llama" in model_name or "llama" in model_name:
        # Llama 3.1 architecture: adapt all attention and MLP projections.
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
            "gate_proj", "up_proj", "down_proj",     # MLP projections
        ]
    else:
        # Conservative default for other transformer stacks.
        target_modules = ["q_proj", "v_proj"]

    # LoRA hyperparameters with safe fallbacks.
    rank = int(config.get('lora_r', 16))
    alpha = int(config.get('lora_alpha', 32))
    dropout = float(config.get('lora_dropout', 0.1))

    model = get_peft_model(
        model,
        LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=rank,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=target_modules,
            bias="none",
        ),
    )
    model.print_trainable_parameters()

    print("LoRA configuration applied successfully!")
    print(f"Target modules: {target_modules}")
    print(f"LoRA params β†’ r={rank}, alpha={alpha}, dropout={dropout}")
    return model


def tokenize_dataset(dataset, tokenizer, config):
    """Convert every split of *dataset* into fixed-length token sequences.

    Each example's "text" column is tokenized, truncated/padded to
    ``config['max_length']``, and given causal-LM labels identical to its
    input_ids. Returns the mapped dataset.
    """

    max_len = config['max_length']

    def encode(batch):
        """Tokenize a batch of texts to a fixed length and attach labels."""
        encoded = tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors=None,
            add_special_tokens=True,
        )
        # Causal LM: labels are a copy of the input ids.
        encoded["labels"] = list(encoded["input_ids"])
        return encoded

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        encode,
        batched=True,
        remove_columns=dataset["train"].column_names,
        desc="Tokenizing",
    )

    print("Tokenization complete!")

    # Sanity-check one example's lengths against the configured maximum.
    first = tokenized_dataset["train"][0]
    print(f"βœ… Sample tokenized input_ids shape: {len(first['input_ids'])}")
    print(f"βœ… Sample tokenized labels shape: {len(first['labels'])}")
    print(f"βœ… Max length setting: {config['max_length']}")

    return tokenized_dataset


def setup_training(model, tokenizer, tokenized_dataset, config):
    """Build a Hugging Face ``Trainer`` for the tokenized dataset.

    Args:
        model: The (LoRA-wrapped) causal LM to train.
        tokenizer: Matching tokenizer (kept for interface parity; the data
            is pre-padded so only the default collator is needed).
        tokenized_dataset: Dict-like with "train" and "validation" splits
            of fixed-length examples.
        config: Hyperparameter dict; see the keys referenced below.

    Returns:
        A configured ``transformers.Trainer``.
    """

    # Data is pre-padded to max_length at tokenization time, so the default
    # collator (plain tensor stacking) is sufficient.
    data_collator = default_data_collator

    import transformers
    transformers_version = transformers.__version__
    print(f"πŸ”§ Transformers version: {transformers_version}")

    # transformers >= 4.41 renamed `evaluation_strategy` to `eval_strategy`.
    # BUGFIX: check field-name membership directly instead of a substring
    # match on str(__dataclass_fields__), which was fragile.
    ta_fields = getattr(TrainingArguments, '__dataclass_fields__', {})
    eval_param_name = "eval_strategy" if "eval_strategy" in ta_fields else "evaluation_strategy"

    training_args_dict = {
        "output_dir": config['output_dir'],
        "per_device_train_batch_size": config['train_batch_size'],
        "per_device_eval_batch_size": config['eval_batch_size'],
        "gradient_accumulation_steps": config['gradient_accumulation_steps'],
        "num_train_epochs": config['num_epochs'],
        "learning_rate": config['learning_rate'],
        "logging_steps": config.get('logging_steps', 25),
        eval_param_name: "steps",
        "eval_steps": config.get('eval_steps', 50),
        # Save checkpoints frequently enough; default aligns with eval steps
        "save_steps": config.get('save_steps', config.get('eval_steps', 100)),
        "save_total_limit": 2,
        "remove_unused_columns": False,
        "push_to_hub": False,
        # BUGFIX: report_to=None means "all installed integrations" in
        # transformers; "none" is the documented value that disables reporting.
        "report_to": "none",
        "load_best_model_at_end": True,
        "group_by_length": True,
        "warmup_ratio": config.get('warmup_ratio', 0.03),
        "weight_decay": config.get('weight_decay', 0.01),
        "max_grad_norm": config.get('max_grad_norm', 1.0),
        "lr_scheduler_type": "cosine",
        "dataloader_num_workers": config.get('dataloader_num_workers', 2),
        "dataloader_pin_memory": True,
        "skip_memory_metrics": True,
        "log_level": "warning",
        "include_inputs_for_metrics": False,
        "prediction_loss_only": True,
        "gradient_checkpointing": config.get('gradient_checkpointing', True),
    }

    # load_best_model_at_end needs a checkpoint at every eval step, so by
    # default force save_steps to match eval_steps.
    if config.get('align_save_with_eval', True):
        training_args_dict["save_steps"] = training_args_dict.get("eval_steps", training_args_dict.get("save_steps", 100))

    # Prefer bf16 where the hardware supports it (no loss scaling needed).
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if use_bf16:
        training_args_dict["bf16"] = True
        training_args_dict["fp16"] = False
        print("βœ… Using bf16 precision")
    else:
        training_args_dict["fp16"] = True
        print("βœ… Using fp16 precision")

    print(f"βœ… Using {eval_param_name} parameter for evaluation")
    training_args = TrainingArguments(**training_args_dict)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        data_collator=data_collator,
    )

    print("Trainer initialized!")
    print(f"Training samples: {len(tokenized_dataset['train'])}")
    print(f"Validation samples: {len(tokenized_dataset['validation'])}")

    # Validate data shapes up front to surface tensor-stacking errors early.
    print("πŸ” Validating data shapes...")
    train_sample = tokenized_dataset["train"][0]
    val_sample = tokenized_dataset["validation"][0]

    print(f"βœ… Train sample - input_ids: {len(train_sample['input_ids'])}, labels: {len(train_sample['labels'])}")
    print(f"βœ… Val sample - input_ids: {len(val_sample['input_ids'])}, labels: {len(val_sample['labels'])}")

    # Spot-check a few samples for consistent, label-aligned lengths.
    for i in range(min(3, len(tokenized_dataset['train']))):
        sample = tokenized_dataset['train'][i]
        if len(sample['input_ids']) != config['max_length']:
            print(f"⚠️  Warning: Sample {i} has inconsistent length: {len(sample['input_ids'])} != {config['max_length']}")
        if len(sample['input_ids']) != len(sample['labels']):
            print(f"⚠️  Warning: Sample {i} input_ids and labels length mismatch: {len(sample['input_ids'])} != {len(sample['labels'])}")

    print("βœ… Data validation complete!")

    return trainer


def save_model_and_config(model, tokenizer, trainer, config):
    """Persist the adapter, tokenizer and run metadata, then evaluate.

    Writes into ``config['save_dir']``:
    - the LoRA adapter weights (via ``trainer.save_model``)
    - the tokenizer files
    - ``training_config.json`` (run configuration snapshot)
    - ``test_results.json`` (validation-set metrics)

    Returns the evaluation metrics dict.
    """

    print("Saving model...")

    save_dir = config['save_dir']

    # Adapter weights plus tokenizer files.
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)

    # Record everything needed to reproduce / reload this run.
    run_metadata = {
        "base_model": config['model_name'],
        "dataset": config['dataset_name'],
        "dataset_config": config['dataset_config'],
        "training_config": config,
        "lora_config": {
            "r": config['lora_r'],
            "alpha": config['lora_alpha'],
            "dropout": config['lora_dropout'],
        },
        "training_date": datetime.now().isoformat(),
    }

    with open(f"{save_dir}/training_config.json", "w") as f:
        json.dump(run_metadata, f, indent=2, default=str)

    print(f"Model saved to {save_dir}")

    # NOTE: despite the output file's name, these are validation-set metrics.
    print("Evaluating model on validation set...")
    test_results = trainer.evaluate()

    with open(f"{save_dir}/test_results.json", "w") as f:
        json.dump(test_results, f, indent=2)

    print(f"Evaluation complete! Results saved to {save_dir}/test_results.json")

    return test_results


def run_training(config, processed_dataset):
    """Run the complete fine-tuning pipeline end to end.

    Steps: load model/tokenizer β†’ apply LoRA β†’ tokenize β†’ build trainer β†’
    train β†’ save adapter/config and evaluate on the validation split.

    Args:
        config: Training configuration dict (see ``__main__`` for the keys).
        processed_dataset: Dataset with "train"/"validation" splits holding
            a "text" column.

    Returns:
        ``(model, tokenizer, trainer)``.
    """

    print("πŸš€ Starting financial LLM fine-tuning...")
    print(f"Base model: {config['model_name']}")
    print(f"Dataset: {config['dataset_name']}")
    print(f"Training samples: {len(processed_dataset['train'])}")

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(config)

    # Apply LoRA adapters
    model = setup_lora(model, config)

    # Tokenize dataset
    tokenized_dataset = tokenize_dataset(processed_dataset, tokenizer, config)

    # Setup training
    trainer = setup_training(model, tokenizer, tokenized_dataset, config)

    # Start training
    print("Starting training...")
    print(f"Training will run for {config['num_epochs']} epochs")
    print(f"Effective batch size: {config['train_batch_size'] * config['gradient_accumulation_steps']}")

    trainer.train()

    print("Training completed!")

    # Save model and evaluate
    test_results = save_model_and_config(model, tokenizer, trainer, config)

    print("πŸŽ‰ Fine-tuning complete! πŸŽ‰")
    print(f"βœ… Model saved to: {config['save_dir']}")
    # BUGFIX: the previous `{test_results.get('eval_loss', 'N/A'):.4f}` raised
    # ValueError when eval_loss was absent (':.4f' applied to the string
    # 'N/A'). Only apply the float format to an actual number.
    # NOTE(review): the value printed is the eval loss, not exp(loss);
    # label kept for output continuity.
    eval_loss = test_results.get('eval_loss')
    if isinstance(eval_loss, (int, float)):
        print(f"βœ… Test perplexity: {eval_loss:.4f}")
    else:
        print("βœ… Test perplexity: N/A")

    return model, tokenizer, trainer


if __name__ == "__main__":
    # Smoke-test configuration for the training pipeline: small model,
    # single short epoch. Nothing here runs on import.
    test_config = {
        # Model / data
        "model_name": "microsoft/DialoGPT-medium",
        "dataset_name": "Josephgflowers/Finance-Instruct-500k",
        "dataset_config": "default",
        "max_length": 512,
        # Optimization
        "train_batch_size": 2,
        "eval_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "learning_rate": 2e-4,
        "num_epochs": 1,
        # LoRA
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        # Output locations
        "output_dir": "./test-financial-lora",
        "save_dir": "./test-financial-final",
        # Runtime knobs
        "quantization": "8bit",  # options: none | 8bit | 4bit
        "save_steps": 100,
        "eval_steps": 50,
        "logging_steps": 25,
        "gradient_checkpointing": True,
        "dataloader_num_workers": 2,
    }

    print("Testing training pipeline...")

    # Running for real requires a processed dataset, e.g.:
    # model, tokenizer, trainer = run_training(test_config, processed_dataset)

    print("Training pipeline setup complete!")