File size: 29,820 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
import argparse
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from torch.optim import AdamW
from tqdm import tqdm
import gc
import traceback
import matplotlib.pyplot as plt
from anticipation.vocab import ANTICIPATE, AUTOREGRESS  # Import the flag token constants

# Helper function to monitor GPU memory usage
def print_gpu_memory_stats():
    """Report allocated, reserved, and peak-allocated memory (MB) per CUDA device."""
    if not torch.cuda.is_available():
        return
    for device_idx in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(device_idx) / 1024**2
        reserved = torch.cuda.memory_reserved(device_idx) / 1024**2
        peak = torch.cuda.max_memory_allocated(device_idx) / 1024**2
        print(f"GPU {device_idx} memory allocated: {allocated:.2f} MB")
        print(f"GPU {device_idx} memory reserved: {reserved:.2f} MB")
        print(f"GPU {device_idx} max memory allocated: {peak:.2f} MB")

# Check for NaN values in model parameters
def check_model_for_nans(model):
    """Return True (after reporting the first offending parameter) if any
    model parameter contains a NaN; otherwise return False."""
    offender = next(
        (name for name, param in model.named_parameters() if torch.isnan(param).any()),
        None,
    )
    if offender is None:
        return False
    print(f"NaN detected in parameter {offender}")
    return True

# Force CUDA if available
# Module-level device selection; `device` may later be overridden to CPU in
# main() when --force_cpu is passed (see the `global device` there).
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
    print(f"✓ CUDA is available with {device_count} device(s)")
    for i in range(device_count):
        device_name = torch.cuda.get_device_name(i)
        print(f"  Device {i}: {device_name}")
        props = torch.cuda.get_device_properties(i)
        print(f"    - Total memory: {props.total_memory / 1024**3:.2f} GB")
        print(f"    - CUDA capability: {props.major}.{props.minor}")
else:
    device = torch.device("cpu")
    print("✗ CUDA is not available! Training will be much slower on CPU.")

# Explicitly print which device we're using
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
# NOTE: torch.version.cuda is None on CPU-only builds, so this prints "CUDA version: None" there.
print(f"CUDA version: {torch.version.cuda}")

class SequencePackedDataset(Dataset):
    """Dataset that packs multiple tokenized sequences into fixed-length contexts.

    Each line of the input file is one tokenized sequence whose first token is
    a control flag (AUTOREGRESS or ANTICIPATE). Sequences are concatenated —
    separated by three SEPARATOR tokens — until either the context window or
    the per-pack sequence limit would be exceeded, which reduces padding waste.
    Packs are padded to ``context_length`` with SEPARATOR tokens; the attention
    mask is 1 over real sequence content and 0 over separators/padding.
    """

    def __init__(self, file_path, context_length=1024, max_packed_sequences=4):
        """Load data from tokenized file and implement sequence packing

        Args:
            file_path: Path to the tokenized data file
            context_length: Maximum context length (default 1024)
            max_packed_sequences: Maximum number of sequences to pack together (default 4)
        """
        from anticipation.vocab import SEPARATOR, AUTOREGRESS, ANTICIPATE
        import random

        # Read all individual sequences (one whitespace-separated token line each)
        individual_sequences = []
        with open(file_path, 'r') as f:
            for line in f:
                tokens = list(map(int, line.strip().split()))
                individual_sequences.append(tokens)

        print(f"Loaded {len(individual_sequences)} individual sequences")

        self.packed_sequences = []
        self.attention_masks = []

        # Packing statistics
        self.total_packed = 0
        self.avg_sequences_per_pack = 0
        sequences_per_pack = []

        # Process sequences in random order for better mixing
        random.shuffle(individual_sequences)

        current_packed = []
        current_positions = []  # (start, end) spans used to build attention masks

        def _finalize_pack():
            """Pad the current pack to context_length and store it with its mask."""
            attention_mask = torch.zeros(context_length, dtype=torch.long)
            for start, end in current_positions:
                attention_mask[start:end] = 1
            if len(current_packed) < context_length:
                current_packed.extend([SEPARATOR] * (context_length - len(current_packed)))
            self.packed_sequences.append(torch.tensor(current_packed[:context_length], dtype=torch.long))
            self.attention_masks.append(attention_mask)
            sequences_per_pack.append(len(current_positions))
            self.total_packed += 1

        for sequence in individual_sequences:
            # Extract control flag (first token)
            control_flag = sequence[0]
            assert control_flag in [AUTOREGRESS, ANTICIPATE], f"Invalid control flag: {control_flag}"

            # Rest of sequence (without control flag)
            sequence_content = sequence[1:]

            # Start a new pack when adding this sequence (3 separators + control
            # flag + content) would overflow the context window, or when the
            # current pack already holds max_packed_sequences sequences.
            # BUG FIX 1: the limit check previously used len(sequences_per_pack)
            # — the *global* statistics list — so once that list reached
            # max_packed_sequences entries the condition was permanently true
            # and packing degraded to one sequence per pack. It must count the
            # sequences in the *current* pack (current_positions).
            # BUG FIX 2: the overflow check previously omitted the +1 for the
            # control-flag token, allowing a one-token overflow that was then
            # silently truncated.
            if current_packed and (len(current_packed) + 3 + 1 + len(sequence_content) > context_length or
                                   len(current_positions) >= max_packed_sequences):
                _finalize_pack()
                current_packed = []
                current_positions = []

            # Add separator tokens between sequences (except for the first in the pack)
            start_pos = len(current_packed)
            if current_packed:
                current_packed.extend([SEPARATOR, SEPARATOR, SEPARATOR])
                start_pos += 3

            # Add control flag and sequence content, recording the span for masking
            current_packed.append(control_flag)
            current_packed.extend(sequence_content)
            current_positions.append((start_pos, len(current_packed)))

        # Flush the final (possibly partial) pack
        if current_packed:
            _finalize_pack()

        # Calculate statistics
        if sequences_per_pack:
            self.avg_sequences_per_pack = sum(sequences_per_pack) / len(sequences_per_pack)

        print(f"Created {len(self.packed_sequences)} packed sequences")
        print(f"Average sequences per pack: {self.avg_sequences_per_pack:.2f}")

    def __len__(self):
        return len(self.packed_sequences)

    def __getitem__(self, idx):
        # NOTE(review): labels equal input_ids, so the SEPARATOR padding tokens
        # also contribute to the loss (HF models only ignore label -100, not
        # attention_mask positions). Consider setting padded label positions to
        # -100 if that is not intended — confirm with the training recipe.
        return {
            "input_ids": self.packed_sequences[idx],
            "attention_mask": self.attention_masks[idx],
            "labels": self.packed_sequences[idx],
        }

def collate_packed_sequences(batch):
    """Collate function for packed sequences that includes attention masks"""
    # Stack each field across the batch dimension in one pass.
    return {
        key: torch.stack([example[key] for example in batch])
        for key in ("input_ids", "attention_mask", "labels")
    }

def evaluate_model(model, dataloader, accelerator):
    """Calculate validation loss on a dataset"""
    # Puts the model in eval mode; the caller is responsible for restoring
    # train mode afterwards.
    model.eval()
    loss_sum = 0.0
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            result = model(**batch)
            # Weight each batch's mean loss by its sample count so uneven
            # final batches don't skew the average.
            batch_count = batch["input_ids"].size(0)
            loss_sum += result.loss.item() * batch_count
            sample_count += batch_count

    return loss_sum / sample_count

def plot_losses(train_losses, val_losses, validation_steps, output_dir):
    """Plot training and validation losses and save the figure.

    Args:
        train_losses (list): Training loss history
        val_losses (list): Validation loss history
        validation_steps (list): Steps at which validation was performed
        output_dir (Path): Directory to save the plot
    """
    plt.figure(figsize=(10, 6))

    # Training curve: one point per recorded loss, x starting at 1.
    x_train = [i + 1 for i in range(len(train_losses))]
    plt.plot(x_train, train_losses, label='Training Loss', alpha=0.7, color='blue')

    # Validation curve at the recorded evaluation steps.
    plt.plot(validation_steps, val_losses, label='Validation Loss',
             linestyle='--', marker='o', markersize=5, color='red')

    plt.xlabel('Steps (x10)')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Persist the figure and release its resources.
    plot_path = output_dir / "loss_plot.png"
    plt.savefig(plot_path)
    plt.close()

    print(f"Loss plot saved to {plot_path}")

def main():
    """Fine-tune a causal language model on tokenized sequence data.

    Parses command-line arguments, builds (optionally packed) train and
    validation dataloaders, loads the pretrained model, and runs a
    gradient-accumulated training loop with periodic validation,
    checkpointing, loss logging, and plotting. The final model and loss
    history are saved best-effort in a ``finally`` block even when
    training is interrupted by an error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_file', type=Path, default=Path('./data/train.txt'))
    parser.add_argument('--val_file', type=Path, default=Path('./data/test.txt'))
    parser.add_argument('--model_name', type=str, default='stanford-crfm/music-small-800k')
    parser.add_argument('--output_dir', type=Path, default=Path('./fine_tuned'))
    parser.add_argument('--batch_size', type=int, default=8) 
    parser.add_argument('--val_batch_size', type=int, default=16)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32)  # For effective batch size 256
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--max_steps', type=int, default=3500)
    parser.add_argument('--save_steps', type=int, default=500)
    parser.add_argument('--eval_steps', type=int, default=100)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--force_cpu', action='store_true', help='Force CPU usage even if GPU is available')
    parser.add_argument('--reduce_memory', action='store_true', help='Use memory-saving techniques')
    parser.add_argument('--context_length', type=int, default=1024, help='Maximum context length')
    parser.add_argument('--max_packed_sequences', type=int, default=4, 
                       help='Maximum number of sequences to pack together (set to 1 to disable packing)')
    args = parser.parse_args()
    
    # Override device if requested
    # Rebinds the module-level `device` chosen at import time.
    global device
    if args.force_cpu:
        device = torch.device("cpu")
        print("Forcing CPU usage as requested")
    
    print(f"Effective batch size: {args.batch_size * args.gradient_accumulation_steps}")
    print(f"Final device confirmation: {device}")
    
    try:
        # Initialize accelerator with memory optimization if requested
        # Use bf16 instead of fp16 for better numerical stability
        mixed_precision = 'bf16' if torch.cuda.is_available() and not args.force_cpu else 'no'
        print(f"Mixed precision mode: {mixed_precision}")
        
        accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            cpu=args.force_cpu,
            mixed_precision=mixed_precision,
        )
        
        # Create output directory
        os.makedirs(args.output_dir, exist_ok=True)
        
        # Monitor initial GPU memory
        print("Initial GPU memory stats:")
        print_gpu_memory_stats()
        
        # Load training dataset
        print(f"Loading training dataset from {args.data_file}...")
        if args.max_packed_sequences > 1:
            print(f"Using sequence packing with max {args.max_packed_sequences} sequences per pack")
            train_dataset = SequencePackedDataset(
                args.data_file, 
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_train = collate_packed_sequences
        else:
            print("Sequence packing disabled - using single sequences")
            # Original dataset class for backward compatibility
            from anticipation.vocab import SEPARATOR
            individual_sequences = []
            with open(args.data_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    individual_sequences.append(torch.tensor(tokens, dtype=torch.long))
            
            # NOTE(review): this path assumes every line tokenizes to the same
            # length — torch.stack in collate_fn_train below fails on ragged
            # sequences; confirm the unpacked data format before relying on it.
            class TokenizedDataset(Dataset):
                def __init__(self, sequences):
                    self.sequences = sequences
                    self.sequence_length = len(self.sequences[0]) if self.sequences else 0
                    print(f"Loaded {len(self.sequences)} sequences with length {self.sequence_length}")
                
                def __len__(self):
                    return len(self.sequences)
                
                def __getitem__(self, idx):
                    tokens = self.sequences[idx]
                    return {"input_ids": tokens, "labels": tokens}
            
            train_dataset = TokenizedDataset(individual_sequences)
            
            def collate_fn_train(batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])
                labels = torch.stack([item["labels"] for item in batch])
                return {"input_ids": input_ids, "labels": labels}
            
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=args.batch_size, 
            shuffle=True,
            collate_fn=collate_fn_train,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,  # Avoid multiprocessing issues
        )
        
        # Load validation dataset
        print(f"Loading validation dataset from {args.val_file}...")
        if args.max_packed_sequences > 1:
            val_dataset = SequencePackedDataset(
                args.val_file, 
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_val = collate_packed_sequences
        else:
            # Load validation sequences
            val_sequences = []
            with open(args.val_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    val_sequences.append(torch.tensor(tokens, dtype=torch.long))
            
            val_dataset = TokenizedDataset(val_sequences)
            collate_fn_val = collate_fn_train
        
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size=args.val_batch_size,
            shuffle=False,  # No need to shuffle validation data
            collate_fn=collate_fn_val,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,
        )
        
        # Load model with memory optimizations
        print(f"Loading model {args.model_name}...")
        model_kwargs = {
            "trust_remote_code": True,
            "use_cache": False,  # Important for training
        }
        
        if args.reduce_memory and torch.cuda.is_available():
            print("Using memory reduction techniques...")
            # BF16 is more stable than FP16
            model_kwargs.update({
                "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                "low_cpu_mem_usage": True,
            })
        
        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                **model_kwargs
            )
        except Exception as e:
            # Fall back to a minimal load if the kwargs above are unsupported
            # by this transformers version or model class.
            print(f"Error loading model with advanced options: {e}")
            print("Trying with basic options...")
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                trust_remote_code=True,
                use_cache=False
            )
        
        # Check memory after loading model
        print("GPU memory after loading model:")
        print_gpu_memory_stats()
        
        # Explicitly move model to our device before creating optimizer
        model = model.to(device)
        print(f"Model moved to: {next(model.parameters()).device}")
        
        # Setup optimizer with gradient clipping to prevent exploding gradients
        # Using a lower learning rate and better epsilon value for numerical stability
        optimizer = AdamW(
            model.parameters(), 
            lr=args.learning_rate,
            eps=1e-6,  # More stable epsilon
            weight_decay=0.01,
            betas=(0.9, 0.999),  # Stable default betas
        )
        
        # Prepare for training with accelerate
        # Accelerate wraps these for device placement and mixed precision.
        model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
        val_dataloader = accelerator.prepare_data_loader(val_dataloader)
        print(f"After accelerator preparation, model device: {next(model.parameters()).device}")
        
        # Learning rate scheduler
        scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps,
        )
        
        # Check memory before training
        print("GPU memory before training:")
        print_gpu_memory_stats()
        
        # Disable anomaly detection which can cause overhead
        torch.autograd.set_detect_anomaly(False)
        
        # Set deterministic algorithms for reproducibility
        torch.backends.cudnn.deterministic = False  # Better performance
        torch.backends.cudnn.benchmark = True  # Better performance
        
        if torch.cuda.is_available():
            print("Clearing CUDA cache before training")
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)
        
        # Training loop
        print("Starting training...")
        model.train()
        # completed_steps counts optimizer updates (not micro-batches); it only
        # advances when accelerator.sync_gradients is True.
        completed_steps = 0
        step = 0
        
        # Lists to track losses
        train_losses = []
        val_losses = []
        validation_steps = []
        
        # Use standard tqdm with disable=False to ensure it always displays
        progress_bar = tqdm(total=args.max_steps, desc="Training", disable=False)
        
        try:
            while completed_steps < args.max_steps:
                for batch in train_dataloader:
                    try:
                        with accelerator.accumulate(model):
                            # Forward pass with gradient scaling
                            outputs = model(**batch)
                            loss = outputs.loss
                            
                            # Check for NaN loss
                            if torch.isnan(loss).any() or torch.isinf(loss).any():
                                print(f"WARNING: NaN or Inf loss detected: {loss.item()}")
                                # Skip this backward pass
                                optimizer.zero_grad()
                                continue
                                
                            # Backward pass
                            accelerator.backward(loss)
                            
                            # Only update optimizer and scheduler when gradients are synchronized
                            if accelerator.sync_gradients:
                                # Gradient clipping
                                accelerator.clip_grad_norm_(model.parameters(), max_norm=0.5)
                                
                                # Check for NaN in gradients
                                has_nan_grads = False
                                for name, param in model.named_parameters():
                                    if param.grad is not None and torch.isnan(param.grad).any():
                                        print(f"NaN gradient detected in {name}")
                                        has_nan_grads = True
                                        break
                                        
                                if has_nan_grads:
                                    print("Skipping update due to NaN gradients")
                                    optimizer.zero_grad()
                                    continue
                                
                                # Only update optimizer and scheduler here
                                optimizer.step()
                                scheduler.step()
                                optimizer.zero_grad()
                                
                                # Only update step counters when we actually update weights
                                completed_steps += 1
                                progress_bar.update(1)
                                
                                # Log progress
                                if completed_steps % 10 == 0:
                                    # Store the training loss every 10 steps
                                    train_losses.append(loss.item())
                                    
                                    # Print more precise learning rate
                                    print(f"Step: {completed_steps}/{args.max_steps}, Loss: {loss.item():.4f}, "
                                          f"LR: {scheduler.get_last_lr()[0]:.8e}")
                                    
                                    # Check for NaN parameters periodically
                                    if check_model_for_nans(model):
                                        print("NaN parameters detected in model! Training may be unstable.")
                                    
                                    # Check memory periodically
                                    if completed_steps % 100 == 0:
                                        print_gpu_memory_stats()
                                
                                # Run validation periodically
                                if completed_steps % args.eval_steps == 0:
                                    print(f"\nRunning validation at step {completed_steps}...")
                                    val_loss = evaluate_model(model, val_dataloader, accelerator)
                                    # Divided by 10 so validation points share an
                                    # x-axis with train_losses (one entry per 10 steps).
                                    validation_steps.append(completed_steps // 10)  # Store step number (divided by 10 for plotting)
                                    val_losses.append(val_loss)
                                    print(f"Validation Loss: {val_loss:.4f}")
                                    
                                    # Return to training mode
                                    model.train()
                                    
                                    # Free up memory after validation
                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()
                                
                                # Save checkpoint
                                if completed_steps % args.save_steps == 0:
                                    checkpoint_dir = args.output_dir / f"checkpoint-{completed_steps}"
                                    os.makedirs(checkpoint_dir, exist_ok=True)
                                    
                                    # Unwrap model before saving
                                    unwrapped_model = accelerator.unwrap_model(model)
                                    unwrapped_model.save_pretrained(
                                        checkpoint_dir,
                                        is_main_process=accelerator.is_main_process,
                                        save_function=accelerator.save,
                                    )
                                    print(f"Saved checkpoint to {checkpoint_dir}")
                                    
                                    # Save the losses so far
                                    np.savez(
                                        checkpoint_dir / "losses.npz",
                                        train_losses=np.array(train_losses),
                                        val_losses=np.array(val_losses),
                                        validation_steps=np.array(validation_steps)
                                    )
                                    
                                    # Create and save loss plot
                                    plot_losses(train_losses, val_losses, validation_steps, checkpoint_dir)
                                    
                                    # Free up memory
                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()
                            
                            # Zero gradients even if we don't sync (needed for some accelerator configurations)
                            if not accelerator.sync_gradients:
                                optimizer.zero_grad()
                                
                            # Check if we've reached max steps
                            if completed_steps >= args.max_steps:
                                break
                            
                    except RuntimeError as e:
                        # OOM is unrecoverable here; NaN/Inf batches are skipped;
                        # anything else is re-raised with a traceback.
                        if "CUDA out of memory" in str(e):
                            print(f"CUDA OOM error! Current batch size: {args.batch_size}")
                            print(f"Current memory usage:")
                            print_gpu_memory_stats()
                            print("Consider reducing batch size or model size.")
                            print(f"Error details: {str(e)}")
                            raise
                        elif "nan" in str(e).lower() or "inf" in str(e).lower():
                            print(f"NaN/Inf error: {str(e)}")
                            print("Trying to recover by skipping this batch...")
                            optimizer.zero_grad()
                            continue
                        else:
                            print(f"Runtime error: {str(e)}")
                            print(traceback.format_exc())
                            raise
            
        except Exception as e:
            print(f"Error during training: {e}")
            print(traceback.format_exc())
            raise
        finally:
            # Make sure we always close the progress bar
            progress_bar.close()
            
            # Always try to save whatever we have and generate the final plot
            try:
                # Final validation run
                print("\nRunning final validation...")
                final_val_loss = evaluate_model(model, val_dataloader, accelerator)
                validation_steps.append(completed_steps // 10)
                val_losses.append(final_val_loss)
                print(f"Final validation Loss: {final_val_loss:.4f}")
                
                # Final save
                final_dir = args.output_dir / "final"
                os.makedirs(final_dir, exist_ok=True)
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    final_dir,
                    is_main_process=accelerator.is_main_process,
                    save_function=accelerator.save,
                )
                print(f"Saved final model to {final_dir}")
                
                # Save the final losses
                np.savez(
                    final_dir / "losses.npz",
                    train_losses=np.array(train_losses),
                    val_losses=np.array(val_losses),
                    validation_steps=np.array(validation_steps)
                )
                
                # Create and save final loss plot
                plot_losses(train_losses, val_losses, validation_steps, final_dir)
                
            except Exception as save_error:
                print(f"Error saving final model or generating plot: {save_error}")
            
    except Exception as setup_error:
        print(f"Error in setup: {setup_error}")
        print(traceback.format_exc())

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()