File size: 5,694 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
import time
import os
from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats

class Trainer:
    """Run the training loop, periodic checkpointing and validation for a model.

    Args:
        model: Module called as ``model(input_ids, labels)``; its return value
            must support ``output["loss"]`` yielding a scalar loss tensor.
        optimizer: Optimizer that is stepped through ``scaler``.
        scaler: Object providing ``scale``/``unscale_``/``step``/``update``
            (the ``torch.amp.GradScaler`` interface).
        config: Object exposing ``torch_dtype``; ``torch.float32`` disables
            autocast, any other value enables it.
        device: ``torch.device`` the model and every batch are moved to.
        checkpoint_dir: Directory passed through to ``save_checkpoint``.
        logger: Stored on the instance but not used by this class; progress is
            reported with ``print``.
    """

    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger

        self.model.to(self.device)

    def _autocast(self):
        """Return the autocast context used around every forward pass.

        CUDA devices use float16; every other device runs CPU autocast in
        bfloat16. Autocast is disabled entirely (a no-op context) when
        ``config.torch_dtype`` is ``torch.float32``.
        """
        device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
        dtype = torch.float16 if device_type == 'cuda' else torch.bfloat16
        return torch.amp.autocast(
            device_type,
            dtype=dtype,
            enabled=self.config.torch_dtype != torch.float32,
        )

    def validate(self, val_loader, global_step, max_batches=100):
        """Compute the mean loss over (at most) ``max_batches`` validation batches.

        Args:
            val_loader: Iterable of ``(input_ids, labels)`` tensor pairs.
            global_step: Training step, used only in the printed messages.
            max_batches: Upper bound on batches evaluated, to keep validation
                cheap; ``None`` evaluates the whole loader.

        Returns:
            The mean of the per-batch losses, or 0.0 if no batch was evaluated.
            The model's train/eval mode is restored to what it was on entry.
        """
        from itertools import islice  # local import keeps the module header untouched

        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        total_items = 0

        print(f"\n[Validation] Starting validation at step {global_step}...")

        try:
            with torch.no_grad():
                # islice stops before fetching the batch after the limit.
                for input_ids, labels in islice(val_loader, max_batches):
                    input_ids = input_ids.to(self.device, non_blocking=True)
                    labels = labels.to(self.device, non_blocking=True)

                    with self._autocast():
                        output = self.model(input_ids, labels)

                    total_loss += output["loss"].item()
                    total_items += 1
        finally:
            # Restore the caller's mode even if a batch raised.
            self.model.train(was_training)

        avg_loss = total_loss / total_items if total_items > 0 else 0.0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
        return avg_loss

    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training",
                    val_loader=None, eval_every=500, log_every=10, save_every=500,
                    max_grad_norm=0.5):
        """Train until ``global_step`` reaches ``total_steps``.

        Args:
            train_loader: Iterable of ``(input_ids, labels)`` tensor pairs; it
                is re-iterated from the start whenever it is exhausted.
            total_steps: Absolute step count at which training stops.
            start_step: Step to resume from (e.g. after loading a checkpoint).
            stage_name: Label used in log output and passed to ``save_checkpoint``.
            val_loader: Optional validation data; run every ``eval_every`` steps.
            eval_every: Validation interval in steps.
            log_every: Interval, in steps, of the progress line.
            save_every: Checkpoint interval in steps.
            max_grad_norm: Gradient-norm clipping threshold.

        Returns:
            The final global step.

        Raises:
            RuntimeError: If ``train_loader`` yields no batches.
        """
        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
        self.model.train()
        global_step = start_step

        # Per-logging-window statistics. Counting steps explicitly keeps the
        # average correct when start_step is not a multiple of log_every.
        running_loss = 0.0
        window_steps = 0
        window_tokens = 0
        window_start = time.time()

        print("Initializing data iterator...")
        train_iter = iter(train_loader)

        print("Fetching first batch...")

        while global_step < total_steps:
            if window_steps == 0:
                # Start timing a fresh window here so time spent checkpointing
                # or validating at the end of the previous step is excluded.
                window_start = time.time()

            self.optimizer.zero_grad(set_to_none=True)

            try:
                batch = next(train_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the data.
                train_iter = iter(train_loader)
                batch = next(train_iter, None)
                if batch is None:
                    raise RuntimeError("train_loader yielded no batches") from None
            if global_step == start_step:
                print("✓ First batch loaded! Starting training loop...")

            input_ids, labels = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            with self._autocast():
                loss = self.model(input_ids, labels)["loss"]

            # Backward on the scaled loss, then unscale so clipping and the
            # finiteness check see the true gradient magnitudes.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_grad_norm)

            if torch.isfinite(grad_norm):
                self.scaler.step(self.optimizer)
            else:
                print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")

            # Always update so the scaler can react to overflow steps.
            self.scaler.update()

            global_step += 1
            running_loss += loss.item()
            window_steps += 1
            window_tokens += input_ids.numel()

            if global_step % log_every == 0:
                avg_loss = running_loss / window_steps
                elapsed = time.time() - window_start
                tokens_per_sec = window_tokens / elapsed if elapsed > 0 else float("inf")
                if self.device.type == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1e9
                    max_mem = torch.cuda.max_memory_allocated() / 1e9
                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
                else:
                    mem_str = "CPU Mode"

                print(f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
                      f"{mem_str} | {tokens_per_sec:.0f} tok/s")
                running_loss = 0.0
                window_steps = 0
                window_tokens = 0

            if global_step % save_every == 0:
                save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)

            if val_loader is not None and global_step % eval_every == 0:
                self.validate(val_loader, global_step)

        return global_step