"""
Knowledge Distillation for MiniMind
Train smaller models using larger teacher models.
"""

import math
from typing import Optional, Dict, Any, Callable
from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast


@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation."""
    # Distillation parameters
    temperature: float = 2.0
    alpha_ce: float = 0.5      # Weight for hard label loss
    alpha_kd: float = 0.5      # Weight for distillation loss
    alpha_hidden: float = 0.0  # Weight for hidden state matching

    # Optimization
    learning_rate: float = 1e-4
    min_learning_rate: float = 1e-5
    weight_decay: float = 0.1
    warmup_steps: int = 500
    grad_clip: float = 1.0

    # Training
    num_epochs: int = 5
    batch_size: int = 4
    gradient_accumulation_steps: int = 8
    max_steps: Optional[int] = None

    # Mixed precision
    use_amp: bool = True

    # Checkpointing
    save_steps: int = 500
    output_dir: str = "./distill_outputs"
    log_steps: int = 10
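
# Illustrative note (not part of the original file): the alpha_* weights above are
# independent multipliers over the loss terms computed in
# DistillationTrainer.distillation_loss():
#
#   total = alpha_ce * CE(student_logits, labels)
#         + alpha_kd * T**2 * KL(softmax(teacher/T) || softmax(student/T))
#         + alpha_hidden * MSE(projection(student_hidden), teacher_hidden)
#
# For example, DistillationConfig(temperature=4.0, alpha_ce=0.3, alpha_kd=0.7) leans
# more heavily on the teacher's soft targets; alpha_hidden only takes effect when a
# projection_layer is passed to DistillationTrainer.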


class DistillationTrainer:
    """
    Knowledge Distillation Trainer.
    Supports:
    - Soft label distillation (KL divergence)
    - Hard label training (CE loss)
    - Hidden state matching (optional)
    - Online and offline distillation
    """

    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        train_dataloader: DataLoader = None,
        config: Optional[DistillationConfig] = None,
        projection_layer: Optional[nn.Module] = None,
    ):
        self.student = student_model
        self.teacher = teacher_model
        self.train_dataloader = train_dataloader
        self.config = config or DistillationConfig()
        self.projection_layer = projection_layer  # For hidden state matching

        self.device = next(student_model.parameters()).device

        if self.teacher is not None:
            self.teacher.eval()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        self.scaler = GradScaler() if self.config.use_amp else None

        self.global_step = 0
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

    def _create_optimizer(self) -> torch.optim.Optimizer:
        params = list(self.student.parameters())
        if self.projection_layer is not None:
            params += list(self.projection_layer.parameters())

        return torch.optim.AdamW(
            params,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

    def _create_scheduler(self):
        total_steps = self._get_total_steps()

        def lr_lambda(step):
            if step < self.config.warmup_steps:
                return step / max(1, self.config.warmup_steps)
            progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
            return max(
                self.config.min_learning_rate / self.config.learning_rate,
                0.5 * (1.0 + math.cos(math.pi * progress))
            )

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _get_total_steps(self) -> int:
        if self.config.max_steps:
            return self.config.max_steps
        steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
        return steps_per_epoch * self.config.num_epochs

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        student_hidden: Optional[torch.Tensor] = None,
        teacher_hidden: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute combined distillation loss.

        Args:
            student_logits: Student model output logits [B, T, V]
            teacher_logits: Teacher model output logits [B, T, V]
            labels: Ground truth labels [B, T]
            student_hidden: Student hidden states (optional)
            teacher_hidden: Teacher hidden states (optional)

        Returns:
            Dictionary with loss components and total loss
        """
        # Temperature-scaled soft labels
        T = self.config.temperature

        # Soft label loss (KL divergence)
        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)
        # Flatten to [B*T, V] so that "batchmean" averages over tokens; with a 3-D
        # input it divides by the batch dimension only, which would inflate the KD
        # term by the sequence length relative to the per-token CE loss below. The
        # T**2 factor keeps gradient magnitudes comparable across temperatures.
        kd_loss = F.kl_div(
            student_log_probs.view(-1, student_log_probs.size(-1)),
            teacher_probs.view(-1, teacher_probs.size(-1)),
            reduction="batchmean",
        ) * (T ** 2)

        # Hard label loss (Cross entropy)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Hidden state matching (optional)
        hidden_loss = torch.tensor(0.0, device=self.device)
        if student_hidden is not None and teacher_hidden is not None and self.projection_layer is not None:
            projected_student = self.projection_layer(student_hidden)
            hidden_loss = F.mse_loss(projected_student, teacher_hidden)

        # Combined loss
        total_loss = (
            self.config.alpha_ce * ce_loss +
            self.config.alpha_kd * kd_loss +
            self.config.alpha_hidden * hidden_loss
        )

        return {
            "total_loss": total_loss,
            "ce_loss": ce_loss,
            "kd_loss": kd_loss,
            "hidden_loss": hidden_loss,
        }

    def train(self) -> Dict[str, float]:
        """Main distillation training loop."""
        self.student.train()
        total_steps = self._get_total_steps()

        print(f"Starting knowledge distillation for {total_steps} steps")
        print(f"  Temperature: {self.config.temperature}")
        print(f"  Alpha CE: {self.config.alpha_ce}, Alpha KD: {self.config.alpha_kd}")

        running_losses = {"total": 0.0, "ce": 0.0, "kd": 0.0}

        for epoch in range(self.config.num_epochs):
            for step, batch in enumerate(self.train_dataloader):
                losses = self._training_step(batch)

                for key in running_losses:
                    running_losses[key] += losses[f"{key}_loss"].item()

                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    self._optimizer_step()
                    self.global_step += 1

                    if self.global_step % self.config.log_steps == 0:
                        denom = self.config.log_steps * self.config.gradient_accumulation_steps
                        avg_losses = {k: v / denom for k, v in running_losses.items()}
                        print(
                            f"Step {self.global_step}/{total_steps} | "
                            f"Total: {avg_losses['total']:.4f} | "
                            f"CE: {avg_losses['ce']:.4f} | "
                            f"KD: {avg_losses['kd']:.4f}"
                        )
                        running_losses = {k: 0.0 for k in running_losses}

                    if self.global_step % self.config.save_steps == 0:
                        self._save_checkpoint()

                    if self.config.max_steps and self.global_step >= self.config.max_steps:
                        break

            if self.config.max_steps and self.global_step >= self.config.max_steps:
                break

        self._save_checkpoint(final=True)
        return {"final_step": self.global_step}

    def _training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Single distillation training step."""
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        labels = batch["labels"].to(self.device)

        # Check for pre-computed teacher logits
        teacher_logits = batch.get("teacher_logits")
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(self.device)
        elif self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(input_ids, attention_mask)

        if self.config.use_amp:
            with autocast(dtype=torch.float16):
                _, student_logits, _, _ = self.student(input_ids, attention_mask)
                losses = self.distillation_loss(student_logits, teacher_logits, labels)
                loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            self.scaler.scale(loss).backward()
        else:
            _, student_logits, _, _ = self.student(input_ids, attention_mask)
            losses = self.distillation_loss(student_logits, teacher_logits, labels)
            loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            loss.backward()

        return losses

    def _optimizer_step(self):
        if self.config.use_amp:
            self.scaler.unscale_(self.optimizer)

        # Clip the same parameter set the optimizer updates (student plus the
        # optional projection layer used for hidden-state matching).
        clip_params = [p for group in self.optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(clip_params, self.config.grad_clip)

        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()

    def _save_checkpoint(self, final: bool = False):
        name = "final" if final else f"step_{self.global_step}"
        path = Path(self.config.output_dir) / name
        path.mkdir(parents=True, exist_ok=True)

        torch.save(self.student.state_dict(), path / "student_model.pt")
        if self.projection_layer is not None:
            torch.save(self.projection_layer.state_dict(), path / "projection.pt")

        print(f"Checkpoint saved to {path}")


def generate_teacher_logits(
    teacher_model: nn.Module,
    dataloader: DataLoader,
    output_path: str,
    device: str = "cuda",
    top_k: int = 100,  # Only save top-k logits to reduce storage
):
    """
    Pre-generate teacher logits for offline distillation.
    Saves storage by only keeping top-k logits per position.
    """
    teacher_model.eval()
    teacher_model.to(device)

    all_logits = []

    print(f"Generating teacher logits for {len(dataloader)} batches...")

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            _, logits, _, _ = teacher_model(input_ids, attention_mask)

            # Keep only top-k logits
            if top_k > 0 and top_k < logits.shape[-1]:
                topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)
                sparse_logits = {
                    "values": topk_values.cpu(),
                    "indices": topk_indices.cpu(),
                }
                all_logits.append(sparse_logits)
            else:
                all_logits.append(logits.cpu())

    torch.save(all_logits, output_path)
    print(f"Teacher logits saved to {output_path}")