wlabchoi committed on
Commit
dc45fe9
·
verified ·
1 Parent(s): e5ff53a

Upload train_qwen3_distillation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_distillation.py +6 -2
train_qwen3_distillation.py CHANGED
@@ -2,6 +2,7 @@
2
  # dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
3
  # ///
4
 
 
5
  import torch
6
  import torch.nn.functional as F
7
  from datasets import load_dataset
@@ -17,6 +18,9 @@ import trackio
17
  from typing import Dict, Optional
18
  import numpy as np
19
 
 
 
 
20
  print("="*50)
21
  print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
22
  print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
@@ -167,7 +171,7 @@ class MiniLLMTrainer(Trainer):
167
  self.alpha = alpha
168
  self.use_teacher_sampling = True # MiniLLM uses teacher sampling
169
 
170
- def compute_loss(self, model, inputs, return_outputs=False):
171
  """
172
  MiniLLM Loss:
173
  1. Sample tokens from teacher distribution
@@ -290,7 +294,7 @@ training_args = TrainingArguments(
290
  hub_private_repo=False,
291
 
292
  # Performance
293
- dataloader_num_workers=4,
294
  remove_unused_columns=False,
295
  )
296
 
 
2
  # dependencies = ["transformers>=4.40.0", "datasets", "torch", "accelerate", "peft>=0.7.0", "trackio", "bitsandbytes"]
3
  # ///
4
 
5
+ import os
6
  import torch
7
  import torch.nn.functional as F
8
  from datasets import load_dataset
 
18
  from typing import Dict, Optional
19
  import numpy as np
20
 
21
+ # Disable tokenizer parallelism warning
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
  print("="*50)
25
  print("Knowledge Distillation: Qwen3-4B -> Qwen3-0.6B")
26
  print("Method: MiniLLM (Reversed KLD + Teacher Sampling)")
 
171
  self.alpha = alpha
172
  self.use_teacher_sampling = True # MiniLLM uses teacher sampling
173
 
174
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
175
  """
176
  MiniLLM Loss:
177
  1. Sample tokens from teacher distribution
 
294
  hub_private_repo=False,
295
 
296
  # Performance
297
+ dataloader_num_workers=0, # Avoid multiprocessing issues with tokenizers
298
  remove_unused_columns=False,
299
  )
300