Oleg Lavrovsky
committed on
Initial testing
Files changed:

- README.md +185 -3
- accelerate_config/fsdp2.yaml +25 -0
- distill.py +78 -0
- hello.py +6 -0
- main.py +160 -0
- pyproject.toml +9 -0
- requirements.txt +4 -0
- run.sh +10 -0
- uv.lock +0 -0
README.md
CHANGED
# Knowledge Distillation

Knowledge Distillation is a machine learning technique where a compact "student" model learns to replicate the behavior of a larger, more complex "teacher" model to achieve comparable performance with improved efficiency.

Model Optimizer's Distillation is a set of wrappers and utilities to easily perform Knowledge Distillation between teacher and student models. Given a pretrained teacher model, Distillation has the potential to train a smaller student model faster and/or with higher accuracy than the student model could achieve on its own.

This section demonstrates how to apply Model Optimizer to perform knowledge distillation with ease.
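At its core, logit distillation compares temperature-softened output distributions of student and teacher. The following is a minimal, framework-free sketch of that objective, not ModelOpt's implementation; the function names and the temperature value are illustrative:

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax with temperature scaling; higher T gives softer targets."""
    exps = [x / temperature for x in logits]
    m = max(exps)  # subtract max for numerical stability
    exps = [math.exp(e - m) for e in exps]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions,
    scaled by T^2 so gradient magnitudes stay comparable across temperatures."""
    p = softmax(teacher_logits, temperature)  # soft teacher targets
    q = softmax(student_logits, temperature)  # student predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return temperature * temperature * kl

# Identical logits give zero loss; diverging logits give a positive loss
print(kd_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
print(kd_loss([3.0, 2.0, 1.0], [1.0, 2.0, 3.0]) > 0)  # True
```

ModelOpt's loss criteria (see the Support Matrix below) implement this kind of objective in PyTorch with batching and reduction handled for you.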
<div align="center">

| **Section** | **Description** | **Link** | **Docs** |
| :------------: | :------------: | :------------: | :------------: |
| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
| Getting Started | Learn how to optimize your models using distillation to produce smaller, more intelligent models | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/4_distillation.html)\] |
| Support Matrix | View the support matrix to see compatibility and feature availability across different models | \[[Link](#support-matrix)\] | |
| Distillation with Megatron-LM | Learn how to distill your models with the Megatron-LM framework | \[[Link](#knowledge-distillation-kd-in-nvidia-megatron-lm-framework)\] | |
| Distillation with NeMo | Learn how to distill your models with the NeMo framework | \[[Link](#knowledge-distillation-kd-in-nvidia-nemo-framework)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/4_distillation.html)\] |
| Distillation with Huggingface | Learn how to distill your models with Hugging Face | \[[Link](#knowledge-distillation-kd-for-huggingface-models)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/4_distillation.html)\] |
| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |
| NeMo Prune + Distill Simplified Flow | Example script demonstrating end-to-end pruning plus distillation in NeMo | \[[Link](../nemo_run/prune_distill/README.md)\] | |

</div>

## Pre-Requisites

### Docker

For Hugging Face models, please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`).
For NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.09`), which has all the dependencies installed.
Visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.

Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install example-specific dependencies.

### Local Installation

For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example:

```bash
pip install -U "nvidia-modelopt[hf]"
pip install -r requirements.txt
```

## Getting Started

### Set up your base models

First, obtain both a pretrained model to act as the teacher and a (usually smaller) model to serve as the student.

```python
from transformers import AutoModelForCausalLM

# Define student & teacher
student_model = AutoModelForCausalLM.from_pretrained("student-model-id-or-path")
teacher_model = AutoModelForCausalLM.from_pretrained("teacher-model-id-or-path")
```

### Set up the meta model

As Knowledge Distillation involves (at least) two models, ModelOpt simplifies the integration process by wrapping both student and teacher into one meta model.

An example Distillation setup is shown below. It assumes the outputs of `teacher_model` and `student_model` are logits.

```python
import modelopt.torch.distill as mtd

distillation_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
    "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss is used
}

distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
```

The `teacher_model` can be either an `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed).

See [Distillation](https://nvidia.github.io/Model-Optimizer/guides/4_distillation.html) for more info.

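Conceptually, the meta model runs the teacher alongside the student on each forward pass and keeps both outputs around for the criterion. A toy, framework-free sketch of that idea (hypothetical class and callables, not ModelOpt's actual API):

```python
class KDMetaModel:
    """Toy meta model: wraps student and teacher and records both outputs
    so a distillation criterion can consume them later."""

    def __init__(self, student, teacher, criterion):
        self.student = student
        self.teacher = teacher
        self.criterion = criterion
        self._outputs = None

    def __call__(self, x):
        s_out = self.student(x)
        t_out = self.teacher(x)  # teacher runs alongside, no extra user code
        self._outputs = (s_out, t_out)
        return s_out  # the caller only sees the student output

    def compute_kd_loss(self, student_loss=0.0):
        return student_loss + self.criterion(*self._outputs)


# Toy "models" and an absolute-difference "criterion"
meta = KDMetaModel(lambda x: 2 * x, lambda x: 3 * x, lambda s, t: abs(s - t))
out = meta(5)                      # student output: 10
total = meta.compute_kd_loss(1.0)  # 1.0 + |10 - 15| = 6.0
print(out, total)  # 10 6.0
```

The real `mtd.convert` additionally preserves the student's module hierarchy and handles device placement, serialization, and the loss balancer.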
### Distill during training

To distill from teacher to student, simply use the meta model in the usual training loop, and call the meta model's `.compute_kd_loss()` method to compute the distillation loss in addition to the original user loss.

An example of Distillation training is given below:

```python
# Set up the data loader. For example:
train_dataloader = get_train_loader()

# Define the user loss function. For example:
loss_fn = get_user_loss_fn()

for inputs, labels in train_dataloader:
    distillation_model.zero_grad()
    # Forward through the wrapped models
    out = distillation_model(inputs)
    # Same loss as originally present
    loss = loss_fn(out, labels)
    # Combine distillation and user losses
    loss_total = distillation_model.compute_kd_loss(student_loss=loss)
    loss_total.backward()
```

> [!NOTE]
> DataParallel may break ModelOpt's Distillation feature. Note that the Hugging Face `Trainer` uses DataParallel by default.

### Export trained model

The model can easily be reverted to its original class for further use (i.e., deployment) without any ModelOpt modifications attached.

```python
model = mtd.export(distillation_model)
```

## Support Matrix

### Components available out of the box

Loss criteria:

- `mtd.LogitsDistillationLoss()` - Standard KL-divergence on output logits
- `mtd.MGDLoss()` - Masked Generative Distillation loss for 2D convolutional outputs
- `mtd.MFTLoss()` - KL-divergence loss with Minifinetuning threshold modification

Loss balancers:

- `mtd.StaticLossBalancer()` - Combines the original student loss and KD loss into a single weighted sum (with weights that do not change over time)

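A static balancer amounts to a fixed weighted sum of the two losses. A one-line illustrative sketch (the function name and the `kd_weight` values are hypothetical, not ModelOpt defaults):

```python
def static_loss_balance(student_loss, kd_loss, kd_weight=0.5):
    """Fixed weighted sum of the original student loss and the KD loss;
    the weight stays constant for the whole training run."""
    return (1.0 - kd_weight) * student_loss + kd_weight * kd_loss

print(static_loss_balance(2.0, 4.0, kd_weight=0.25))  # 2.5
```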
### Supported Models

> [!NOTE]
> The following models have been confirmed to run with ModelOpt distillation, but support is by no means limited to them.

| Model | Type | Confirmed compatible |
| :---: | :---: | :---: |
| Nemotron | gpt | ✅ |
| Llama 3 | llama | ✅ |
| Llama 4 | llama | ✅ |
| Gemma 2 | gemma | ✅ |
| Gemma 3 | gemma | ✅ |
| Phi 3 | phi | ✅ |
| Qwen 2 | qwen2 | ✅ |
| Qwen 3 | qwen3 | ✅ |
| Mamba | mamba | ✅ |

## Knowledge Distillation (KD) in NVIDIA Megatron-LM Framework

Check out the Knowledge Distillation example in the [Megatron-LM repository](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).

## Knowledge Distillation (KD) in NVIDIA NeMo Framework

Check out the stand-alone distillation script in the [NeMo documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation), which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B, step by step, in the NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently, as shown in the tutorial.

## Knowledge Distillation (KD) for HuggingFace Models

In this end-to-end example we finetune Llama-3.2 models on the [smol-smoltalk-Interaction-SFT](https://huggingface.co/datasets/ReactiveAI/smol-smoltalk-Interaction-SFT) dataset as a minimal example demonstrating a simple way of integrating Model Optimizer's KD feature.

We replace normal supervised finetuning (SFT) of a Llama-3.2-1B base model with distillation from Llama-3.2-3B-Instruct, which has already been instruction-finetuned.

> [!NOTE]
> We can fit the following in memory using [FSDP](https://huggingface.co/docs/accelerate/en/usage_guides/fsdp) on 8x RTX 6000 GPUs (~400 GB total VRAM):

```bash
accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
    main.py \
    --teacher_name_or_path 'meta-llama/Llama-3.2-3B-Instruct' \
    --student_name_or_path 'meta-llama/Llama-3.2-1B' \
    --output_dir ./llama3.2-distill \
    --max_length 2048 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 8 \
    --max_steps 200 \
    --logging_steps 5
```

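As a rough sanity check on the memory note above, the model weights alone are small relative to the ~400 GB available; gradients, optimizer states, and activations account for most of the rest. A back-of-envelope sketch, assuming bf16 weights (2 bytes per parameter) and nominal parameter counts:

```python
# Back-of-envelope VRAM for weights only; gradients, AdamW optimizer
# states, and activations add substantially more on top of this.
GB = 1024 ** 3
teacher_params = 3e9  # Llama-3.2-3B-Instruct (nominal)
student_params = 1e9  # Llama-3.2-1B (nominal)

weights_gb = 2 * (teacher_params + student_params) / GB  # 2 bytes/param in bf16
print(round(weights_gb, 1))  # 7.5
```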
## Resources

- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146)
- 📖 [Documentation](https://nvidia.github.io/Model-Optimizer)
- 🎯 [Benchmarks](../benchmark.md)
- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html)
- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md)
- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md)
accelerate_config/fsdp2.yaml
ADDED
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
distill.py
ADDED
```python
# Generated by Apertus on Public AI
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DistilBertForSequenceClassification

from smol import DistillationTrainer  # `smol` as declared in pyproject.toml

# Step 1: Load the large model (teacher model) and a tokenizer.
# NOTE: logit distillation only makes sense when teacher and student share a
# tokenizer/vocabulary and output space; the model IDs below are kept from the
# original draft and do not satisfy that requirement.
teacher_model = AutoModelForCausalLM.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")

# Step 2: Choose the smaller model (student model)
# Here, we use DistilBERT as an example
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")


# Define the distillation loss function (temperature-scaled KL divergence)
class DistillationLoss(nn.Module):
    def __init__(self, temperature, alpha):
        super().__init__()
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits):
        t = self.temperature
        soft_student = (student_logits / t).log_softmax(-1)
        soft_teacher = (teacher_logits / t).softmax(-1)
        # Scale by t^2 to keep gradient magnitudes comparable across temperatures
        return self.kl_loss(soft_student, soft_teacher) * self.alpha * t * t


# Define a simple training step
def train_step(student_model, teacher_model, batch, optimizer, loss_fn, device):
    # Tokenize the raw text inputs (the "text" field name is illustrative)
    inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt").to(device)
    labels = batch["labels"].to(device)

    # Forward pass with the teacher model (no gradients needed)
    with torch.no_grad():
        teacher_logits = teacher_model(**inputs).logits

    # Forward pass with the student model
    student_logits = student_model(**inputs).logits

    # Compute distillation loss
    distillation_loss = loss_fn(student_logits, teacher_logits)

    # Compute task loss (cross-entropy for classification)
    task_loss = F.cross_entropy(student_logits, labels)
    total_loss = distillation_loss + task_loss  # Combine both losses

    # Backward and optimize
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), student_logits, teacher_logits


# Initialize smol's DistillationTrainer (keyword arguments as in the original
# draft; check the smol documentation for the current signature)
trainer = DistillationTrainer(
    student_model,
    optimizer=AdamW(student_model.parameters(), lr=1e-5),  # Example learning rate
    loss_fn=DistillationLoss(temperature=1.0, alpha=0.5),  # Example distillation loss
    train_dataset=your_train_dataset,  # Your training dataset
    eval_dataset=your_eval_dataset,  # Your evaluation dataset
    device="cuda" if torch.cuda.is_available() else "cpu",  # Use GPU if available
    num_epochs=5,  # Number of epochs
    batch_size=16,  # Batch size
    log_dir="distillation_logs",  # Log directory
)

# Train the model
trainer.train()

# Alternatively, you can use smol's simplified training loop (as of smol 0.3.0, check the latest docs)
# trainer.train(steps=1000, evaluate_every=100, ...)
```
hello.py
ADDED
```python
def main():
    print("Hello from distill!")


if __name__ == "__main__":
    main()
```
main.py
ADDED
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.distributed
import transformers
from accelerate.logging import get_logger
from transformers import AutoTokenizer
from trl import SFTTrainer

import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

logger = get_logger(__name__, log_level="INFO")


@dataclass
class ModelArguments:
    teacher_name_or_path: str | None = None
    student_name_or_path: str | None = None


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    do_train: bool = True
    do_eval: bool = True
    save_strategy: str = "no"
    max_length: int = 1024
    optim: str = "adamw_torch"
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"
    dataloader_drop_last: bool = True
    dataset_num_proc: int = 8
    bf16: bool = True
    # tf32: bool = True


def _format_smoltalk_chat_template(sample, tokenizer):
    # The smol-smoltalk-Interaction-SFT dataset has "query" and "answer" fields.
    # Convert them to messages format and use the tokenizer's apply_chat_template.
    messages = [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)


class KDSFTTrainer(KDTrainer, SFTTrainer):
    pass


def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Enable automatic save/load of the modelopt state with HuggingFace checkpointing.
    # The modelopt state will be saved automatically to "modelopt_state.pth".
    mto.enable_huggingface_checkpointing()

    # Set the total batch size across all ranks to equal 64
    total_batch_size = 64
    num_accum_steps = total_batch_size / (
        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
    )
    if not num_accum_steps.is_integer():
        raise ValueError(
            f"`per_device_train_batch_size` * `world_size` must be a factor of {total_batch_size}"
        )
    training_args.gradient_accumulation_steps = int(num_accum_steps)
    logger.info(
        f"Using {int(num_accum_steps)} grad accumulation steps for effective batchsize of {total_batch_size}."
    )

    # Dataset
    logger.info("Loading dataset...")
    dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")
    dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
    dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
    logger.info("Dataset loaded.")

    # Tokenizer
    logger.info("Loading tokenizer...")
    model_path = model_args.teacher_name_or_path or model_args.student_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    logger.info("Tokenizer loaded.")

    # Model(s)
    logger.info("Loading student model...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )
    logger.info("Student loaded.")
    logger.info("Loading teacher model...")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )

    # Distillation configuration
    kd_config = {
        "teacher_model": teacher_model,
        "criterion": LMLogitsLoss(),
    }

    # Fix problematic settings that cause excessive warnings
    model.generation_config.temperature = None
    model.generation_config.top_p = None

    # Trainer
    trainer = KDSFTTrainer(
        model,
        training_args,
        distill_config=kd_config,
        train_dataset=dset_train,
        eval_dataset=dset_eval,
        formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
        processing_class=tokenizer,
    )

    # Do training
    if training_args.do_train:
        logger.info("Beginning training...")
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        logger.info("Training done.")

    # Do evaluation
    if training_args.do_eval:
        logger.info("Evaluating...")
        eval_results = trainer.evaluate()
        logger.info(eval_results)
        logger.info("Evaluation complete.")

    # Save checkpoint
    logger.info("Saving checkpoint...")
    trainer.save_state()
    trainer.save_model(trainer.args.output_dir)
    logger.info("Checkpoint saved.")


if __name__ == "__main__":
    train()
```
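The gradient-accumulation logic in `train()` can be verified with simple arithmetic. This standalone sketch plugs in the values from the README's example launch command (per-device batch size 4) and an assumed world size of 8 GPUs:

```python
# Mirror of main.py's accumulation-step derivation:
# total_batch_size = per_device_batch * world_size * grad_accum_steps
total_batch_size = 64
per_device_train_batch_size = 4
world_size = 8  # assumed number of GPUs, as in the README example

num_accum_steps = total_batch_size / (per_device_train_batch_size * world_size)
assert num_accum_steps.is_integer(), "per-device batch * world size must divide 64"
print(int(num_accum_steps))  # 2
```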
pyproject.toml
ADDED
```toml
[project]
name = "distill"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "smol>=0.5.7",
]
```
requirements.txt
ADDED
```text
pyarrow
torchao>=0.14.1
transformers<5.0
trl>=0.23.0
```
run.sh
ADDED
```bash
uv run accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
    main.py \
    --teacher_name_or_path 'swiss-ai/Apertus-8B-Instruct-2509' \
    --student_name_or_path 'HuggingFaceTB/SmolLM2-135M-Instruct' \
    --output_dir ./Apertus-8B-distill \
    --max_length 2048 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 8 \
    --max_steps 200 \
    --logging_steps 5
```
uv.lock
ADDED
The diff for this file is too large to render.