# File size: 8,898 Bytes
# 3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Base classes for models, trainers, and datasets."""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Any, Iterator
import torch
import torch.nn as nn
from torch.utils.data import Dataset as TorchDataset
from taoTrain.config import TrainingConfig, ModelConfig


# ============================================================================
# Base Model
# ============================================================================


class BaseModel(nn.Module, ABC):
    """Common interface implemented by every language model.

    Keeps a reference to the model configuration and offers small
    introspection helpers; the network itself is supplied by subclasses
    through ``forward``.
    """

    def __init__(self, config: ModelConfig):
        """Set up the ``nn.Module`` base and remember ``config``."""
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Run a forward pass (must be overridden).

        Args:
            input_ids: Shape (batch_size, seq_length).
            attention_mask: Shape (batch_size, seq_length), optional.
            labels: Shape (batch_size, seq_length), optional; supplied
                when a loss should be computed.

        Returns:
            Dict with keys:
                - 'logits': Shape (batch_size, seq_length, vocab_size)
                - 'loss': Scalar (if labels provided)
        """
        ...

    def count_parameters(self) -> int:
        """Return the number of elements across all trainable parameters."""
        trainable = (param for param in self.parameters() if param.requires_grad)
        return sum(param.numel() for param in trainable)

    def get_num_layers(self) -> int:
        """Return ``num_layers`` as recorded in the model config."""
        model_config = self.config
        return model_config.num_layers


# ============================================================================
# Base Dataset
# ============================================================================


class BaseDataset(TorchDataset, ABC):
    """Interface shared by all datasets.

    Concrete datasets must provide ``__len__`` and ``__getitem__``.
    ``load_dataset`` and ``preprocess`` are optional hooks that do nothing
    unless a subclass overrides them.
    """

    def __init__(self, config: "TrainingConfig"):
        """Remember ``config``; no data is loaded at construction time."""
        self.data = None
        self.config = config

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of samples (must be overridden)."""
        ...

    @abstractmethod
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Fetch one sample (must be overridden).

        Returns:
            Dict with keys:
                - 'input_ids': 1D tensor of token IDs
                - 'attention_mask': 1D tensor of attention mask
                - 'labels': 1D tensor of labels (optional)
        """
        ...

    def load_dataset(self) -> None:
        """Hook for loading data from HuggingFace or another source.

        The base implementation is a no-op.
        """
        return None

    def preprocess(self) -> None:
        """Hook for preprocessing such as tokenization.

        The base implementation is a no-op.
        """
        return None


# ============================================================================
# Base Trainer
# ============================================================================


class BaseTrainer(ABC):
    """Abstract base class for trainers."""
    
    def __init__(

        self,

        model: BaseModel,

        train_dataset: BaseDataset,

        val_dataset: Optional[BaseDataset],

        config: TrainingConfig,

        device: torch.device,

    ):
        """Initialize trainer."""
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config
        self.device = device
        
        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_loss = float('inf')
        
        # Logging
        self.logger = None
        
        # Optimizer and scheduler (to be set up by subclass)
        self.optimizer = None
        self.scheduler = None
    
    @abstractmethod
    def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """

        Single training step.

        

        Args:

            batch: Training batch with input_ids, attention_mask, labels, etc.

        

        Returns:

            Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8})

        """
        pass
    
    @abstractmethod
    def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """

        Single validation step.

        

        Args:

            batch: Validation batch

        

        Returns:

            Dict with validation metrics

        """
        pass
    
    @abstractmethod
    def train_epoch(self) -> dict[str, float]:
        """

        Train for one epoch.

        

        Returns:

            Dict with epoch-level metrics

        """
        pass
    
    @abstractmethod
    def validate(self) -> dict[str, float]:
        """

        Run validation on the entire validation set.

        

        Returns:

            Dict with validation metrics

        """
        pass
    
    def save_checkpoint(self, path: str | Path) -> None:
        """

        Save checkpoint in canonical format.

        

        Uses canonical checkpoint format:

        {

            'step': int,

            'model_state': state_dict,

            'optimizer_state': state_dict,

            'config': dict,

            'metrics': dict,

            'global_step': int,           # Legacy compat

            'current_epoch': int,         # Legacy compat

            'best_loss': float,           # Legacy compat

        }

        

        Args:

            path: Path to save checkpoint

        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        
        # Save in canonical format
        checkpoint = {
            # Canonical format keys
            'step': self.global_step,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict() if self.optimizer else None,
            'config': self.config.to_dict(),
            'metrics': {},
            # Legacy format keys (for backward compatibility with code that reads them)
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_loss': self.best_loss,
        }
        
        torch.save(checkpoint, path)
    
    def load_checkpoint(self, path: str | Path) -> None:
        """

        Load checkpoint (handles both canonical and legacy formats).

        

        Args:

            path: Path to checkpoint

        """
        path = Path(path)
        checkpoint = torch.load(path, map_location=self.device)
        
        # Try canonical keys first, fall back to legacy keys
        model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict'
        optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict'
        
        self.model.load_state_dict(checkpoint[model_state_key])
        if self.optimizer and checkpoint.get(optimizer_state_key):
            self.optimizer.load_state_dict(checkpoint[optimizer_state_key])
        
        # Try canonical 'step' first, fall back to legacy 'global_step'
        self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0))
        self.current_epoch = checkpoint.get('current_epoch', 0)
        self.best_loss = checkpoint.get('best_loss', float('inf'))
    
    def _get_lr(self) -> float:
        """Get current learning rate from optimizer."""
        for param_group in self.optimizer.param_groups:
            return param_group['lr']
        return 0.0


# ============================================================================
# Utility functions
# ============================================================================


def create_model(config: TrainingConfig, device: torch.device) -> BaseModel:
    """Build the model described by ``config.model`` via the model registry.

    Args:
        config: Training configuration; only its ``model`` attribute is read.
        device: Forwarded to the registry as the ``device`` keyword argument.

    Returns:
        Whatever ``taoTrain.models.get_model`` returns for that model config.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from taoTrain.models import get_model

    model_config = config.model
    return get_model(model_config, device=device)


def create_datasets(
    config: TrainingConfig,
) -> tuple[BaseDataset, Optional[BaseDataset]]:
    """Build the training dataset and, when applicable, a validation dataset.

    Both datasets are produced by ``DatasetFactory.create_dataset``. A
    validation dataset is only created when ``config.dataset.local`` is falsy
    and ``config.dataset`` has a ``validation_split`` attribute.

    Args:
        config: Training configuration passed through to the factory.

    Returns:
        ``(train_dataset, val_dataset)``; ``val_dataset`` is ``None`` when no
        validation dataset is created.
    """
    # Deferred import: avoids a circular import at module load time.
    from taoTrain.data import DatasetFactory

    train_dataset = DatasetFactory.create_dataset(config, split="train")

    # Local datasets, and configs without a validation split, get no
    # validation dataset.
    # NOTE(review): hasattr() is true for any declared attribute, even one
    # set to None -- confirm this matches how the dataset config defines
    # ``validation_split``.
    if config.dataset.local or not hasattr(config.dataset, "validation_split"):
        return train_dataset, None

    return train_dataset, DatasetFactory.create_dataset(config, split="validation")