Upload train_qwen3_distillation.py with huggingface_hub
Browse files
train_qwen3_distillation.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
# dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
|
| 3 |
# ///
|
| 4 |
|
|
|
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from datasets import load_dataset
|
|
@@ -17,6 +18,9 @@ import trackio
|
|
| 17 |
from typing import Dict, Optional
|
| 18 |
import numpy as np
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
print("="*50)
|
| 21 |
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
|
| 22 |
print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
|
|
@@ -167,7 +171,7 @@ class MiniLLMTrainer(Trainer):
|
|
| 167 |
self.alpha = alpha
|
| 168 |
self.use_teacher_sampling = True # MiniLLM uses teacher sampling
|
| 169 |
|
| 170 |
-
def compute_loss(self, model, inputs, return_outputs=False):
|
| 171 |
"""
|
| 172 |
MiniLLM Loss:
|
| 173 |
1. Sample tokens from teacher distribution
|
|
@@ -290,7 +294,7 @@ training_args = TrainingArguments(
|
|
| 290 |
hub_private_repo=False,
|
| 291 |
|
| 292 |
# Performance
|
| 293 |
-
dataloader_num_workers=
|
| 294 |
remove_unused_columns=False,
|
| 295 |
)
|
| 296 |
|
|
|
|
| 2 |
# dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
|
| 3 |
# ///
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from datasets import load_dataset
|
|
|
|
| 18 |
from typing import Dict, Optional
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
+
# Disable the tokenizers library's internal parallelism; this also silences its fork-related warning
|
| 22 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 23 |
+
|
| 24 |
print("="*50)
|
| 25 |
print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
|
| 26 |
print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
|
|
|
|
| 171 |
self.alpha = alpha
|
| 172 |
self.use_teacher_sampling = True # MiniLLM uses teacher sampling
|
| 173 |
|
| 174 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
| 175 |
"""
|
| 176 |
MiniLLM Loss:
|
| 177 |
1. Sample tokens from teacher distribution
|
|
|
|
| 294 |
hub_private_repo=False,
|
| 295 |
|
| 296 |
# Performance
|
| 297 |
+
dataloader_num_workers=0, # Avoid multiprocessing issues with tokenizers
|
| 298 |
remove_unused_columns=False,
|
| 299 |
)
|
| 300 |
|