File size: 4,948 Bytes
3df89a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1239566
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""

Training module for KerdosAI.

"""

from typing import Dict, Any, Optional
import torch
from torch.utils.data import DataLoader
from transformers import (
    Trainer as HFTrainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from tqdm import tqdm
import wandb

class Trainer:
    """Handle training and evaluation of the LLM.

    Training is delegated to the Hugging Face ``Trainer``; evaluation is a
    plain PyTorch loop that averages the per-batch loss. When ``use_wandb``
    is true, metrics are additionally logged to Weights & Biases.
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        device: str,
        use_wandb: bool = True
    ):
        """Initialize the trainer.

        Args:
            model: The model to train.
            tokenizer: The tokenizer for the model.
            device: Device to run on (e.g. ``"cpu"``, ``"cuda"``,
                ``"cuda:0"``).
            use_wandb: Whether to use Weights & Biases for logging. When
                true, a run is started immediately in the ``"kerdosai"``
                project.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.use_wandb = use_wandb

        if use_wandb:
            wandb.init(project="kerdosai")

    def _device_is_cuda(self) -> bool:
        """Return True when the configured device names a CUDA device.

        Comparing against the exact string ``"cuda"`` misses indexed devices
        such as ``"cuda:0"`` and non-string device objects, so the check is
        done on the string form's prefix instead.
        """
        return str(self.device).startswith("cuda")

    def train(
        self,
        dataset: Dataset,
        epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 2e-5,
        gradient_accumulation_steps: int = 1,
        warmup_steps: int = 100,
        weight_decay: float = 0.01,
        logging_steps: int = 10,
        save_steps: int = 100,
        output_dir: str = "output",
        **kwargs
    ) -> Dict[str, Any]:
        """Train the model on the provided dataset.

        Args:
            dataset: Training dataset.
            epochs: Number of training epochs.
            batch_size: Per-device training batch size.
            learning_rate: Learning rate.
            gradient_accumulation_steps: Number of steps for gradient
                accumulation.
            warmup_steps: Number of warmup steps.
            weight_decay: Weight decay for the optimizer.
            logging_steps: Number of steps between logging.
            save_steps: Number of steps between model saves.
            output_dir: Directory for checkpoints and the final model.
            **kwargs: Additional ``TrainingArguments`` fields. These take
                precedence over the values derived here, so e.g. ``fp16``
                or ``report_to`` can be overridden by the caller.

        Returns:
            Dictionary containing the training metrics reported by the
            Hugging Face trainer.
        """
        # Collect the arguments in a dict first so that caller-supplied
        # kwargs can override a derived default (fp16, report_to, ...)
        # instead of triggering a duplicate-keyword TypeError.
        training_kwargs: Dict[str, Any] = {
            "output_dir": output_dir,
            "num_train_epochs": epochs,
            "per_device_train_batch_size": batch_size,
            "learning_rate": learning_rate,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "warmup_steps": warmup_steps,
            "weight_decay": weight_decay,
            "logging_steps": logging_steps,
            "save_steps": save_steps,
            # Mixed precision is only enabled when running on a CUDA device.
            "fp16": self._device_is_cuda(),
            "report_to": "wandb" if self.use_wandb else "none",
        }
        training_kwargs.update(kwargs)
        training_args = TrainingArguments(**training_kwargs)

        # Language-modeling collator with masked-LM disabled (mlm=False).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = HFTrainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator
        )

        train_result = trainer.train()

        # Persist the final model next to the checkpoints.
        trainer.save_model(output_dir)

        metrics = train_result.metrics
        if self.use_wandb:
            wandb.log(metrics)

        return metrics

    def evaluate(
        self,
        dataset: Dataset,
        batch_size: int = 4,
        **kwargs
    ) -> Dict[str, float]:
        """Evaluate the model on the provided dataset.

        Puts the model in eval mode, runs every batch through it without
        gradient tracking, and averages the per-batch ``loss``.

        Args:
            dataset: Evaluation dataset. Each batch produced by the default
                ``DataLoader`` collation must be a mapping whose values
                support ``.to(device)`` and can be passed to the model as
                keyword arguments.
            batch_size: Evaluation batch size.
            **kwargs: Accepted for interface compatibility; currently
                unused.

        Returns:
            Dictionary with a single ``"eval_loss"`` entry holding the mean
            batch loss.

        Raises:
            ValueError: If the dataset yields no batches, since a mean loss
                is undefined in that case.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False
        )

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                # Move every tensor of the batch to the target device.
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        # Guard the division: an empty dataset would otherwise surface as an
        # unexplained ZeroDivisionError.
        if num_batches == 0:
            raise ValueError(
                "Cannot evaluate on an empty dataset: no batches were produced."
            )

        avg_loss = total_loss / num_batches
        metrics = {"eval_loss": avg_loss}

        if self.use_wandb:
            wandb.log(metrics)

        return metrics