Upload folder using huggingface_hub

- README.md +41 -0
- config.json +2026 -0
- core/.ipynb_checkpoints/distill-checkpoint.py +184 -0
- core/.ipynb_checkpoints/finetune-checkpoint.py +267 -0
- core/.ipynb_checkpoints/profiler-checkpoint.py +236 -0
- core/.ipynb_checkpoints/proxy_cost-checkpoint.py +771 -0
- core/.ipynb_checkpoints/train-checkpoint.py +327 -0
- core/.ipynb_checkpoints/utils-checkpoint.py +190 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-310.pyc +0 -0
- core/__pycache__/distill.cpython-310.pyc +0 -0
- core/__pycache__/export.cpython-310.pyc +0 -0
- core/__pycache__/finetune.cpython-310.pyc +0 -0
- core/__pycache__/gates.cpython-310.pyc +0 -0
- core/__pycache__/profiler.cpython-310.pyc +0 -0
- core/__pycache__/proxy_cost.cpython-310.pyc +0 -0
- core/__pycache__/search_export.cpython-310.pyc +0 -0
- core/__pycache__/train.cpython-310.pyc +0 -0
- core/__pycache__/utils.cpython-310.pyc +0 -0
- core/distill.py +183 -0
- core/export.py +220 -0
- core/finetune.py +267 -0
- core/gates.py +389 -0
- core/profiler.py +236 -0
- core/proxy_cost.py +771 -0
- core/search_export.py +76 -0
- core/train.py +327 -0
- core/utils.py +190 -0
- custom_code.py +1 -0
- huggingface/.ipynb_checkpoints/llama-checkpoint.py +607 -0
- huggingface/.ipynb_checkpoints/vit-checkpoint.py +383 -0
- huggingface/__init__.py +0 -0
- huggingface/__pycache__/__init__.cpython-310.pyc +0 -0
- huggingface/__pycache__/vit.cpython-310.pyc +0 -0
- huggingface/llama.py +607 -0
- huggingface/registry.py +0 -0
- huggingface/vit.py +383 -0
- model_index.json +5 -0
- pytorch_model.bin +3 -0
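
The commit title says this folder was pushed with `huggingface_hub`; the complementary call for pulling the whole tree back down is `snapshot_download`. A minimal sketch, not taken from this repo's docs — the repo id is assumed from the model card below:

```python
from huggingface_hub import snapshot_download

# Repo id assumed from the model card below; pinning a revision is optional.
local_dir = snapshot_download(repo_id='hawada/vit-base-patch16-224-rtx4090-gated')
# local_dir now holds README.md, config.json, pytorch_model.bin, core/, huggingface/, ...
print(local_dir)
```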
README.md ADDED
@@ -0,0 +1,41 @@
+```yaml
+---
+library_name: pytorch
+tags:
+- resnet
+- pruning
+- knowledge-distillation
+- speedup
+license: apache-2.0
+dataset: imagenet-1k
+pipeline_tag: image-classification
+---
+```
+# hawada/vit-base-patch16-224-rtx4090-gated
+
+This repository contains two variants:
+- **Gated student** (with learned pruning gates) – requires custom code.
+- **Slim student** (post-prune/export) – loads with standard code (LLM) or bundled code (ResNet).
+
+## Inference (LLM, slim)
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+tok = AutoTokenizer.from_pretrained('hawada/vit-base-patch16-224-rtx4090-slim')
+mdl = AutoModelForCausalLM.from_pretrained('hawada/vit-base-patch16-224-rtx4090-slim', torch_dtype='auto').eval()
+x = tok('Hello', return_tensors='pt')
+print(tok.decode(mdl.generate(**x, max_new_tokens=16)[0]))
+```
+
+## Notes
+- The **gated** repo includes lightweight custom code (adapters/…, core/…) needed to attach/load gates.
+- The **slim** LLM is exported to standard HF architecture for out-of-the-box loading.
+- For ResNet, both repos include minimal custom code to define the module.
+
+## Training metadata
+```json
+{
+  "base_id": "google/vit-base-patch16-224",
+  "variant": "gated-student",
+  "repo_slim": "hawada/vit-base-patch16-224-rtx4090-slim"
+}
+```
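
The inference snippet in the README above covers the LLM case of this template. For this particular repo, whose metadata declares `pipeline_tag: image-classification`, the analogous load path would go through the image-classification auto classes. A minimal sketch, assuming the slim student really is a standard HF ViT as the Notes state; the repo id is taken from the card, and the gated variant would additionally need `trust_remote_code=True`:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

repo = 'hawada/vit-base-patch16-224-rtx4090-slim'  # assumed from the card above
processor = AutoImageProcessor.from_pretrained(repo)
model = AutoModelForImageClassification.from_pretrained(repo).eval()

image = Image.open('example.jpg').convert('RGB')   # any test image
inputs = processor(images=image, return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```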
config.json ADDED
@@ -0,0 +1,2026 @@
+{
+  "architectures": [
+    "ViTForImageClassification"
+  ],
+  "attention_probs_dropout_prob": 0.0,
+  "encoder_stride": 16,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.0,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "tench, Tinca tinca",
+    "1": "goldfish, Carassius auratus",
+    "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
+    "3": "tiger shark, Galeocerdo cuvieri",
+    "4": "hammerhead, hammerhead shark",
+    "5": "electric ray, crampfish, numbfish, torpedo",
+    "6": "stingray",
+    "7": "cock",
+    "8": "hen",
+    "9": "ostrich, Struthio camelus",
+    "10": "brambling, Fringilla montifringilla",
+    "11": "goldfinch, Carduelis carduelis",
+    "12": "house finch, linnet, Carpodacus mexicanus",
+    "13": "junco, snowbird",
+    "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
+    "15": "robin, American robin, Turdus migratorius",
+    "16": "bulbul",
+    "17": "jay",
+    "18": "magpie",
+    "19": "chickadee",
+    "20": "water ouzel, dipper",
+    "21": "kite",
+    "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
+    "23": "vulture",
+    "24": "great grey owl, great gray owl, Strix nebulosa",
+    "25": "European fire salamander, Salamandra salamandra",
+    "26": "common newt, Triturus vulgaris",
+    "27": "eft",
+    "28": "spotted salamander, Ambystoma maculatum",
+    "29": "axolotl, mud puppy, Ambystoma mexicanum",
+    "30": "bullfrog, Rana catesbeiana",
+    "31": "tree frog, tree-frog",
+    "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
+    "33": "loggerhead, loggerhead turtle, Caretta caretta",
+    "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
+    "35": "mud turtle",
+    "36": "terrapin",
+    "37": "box turtle, box tortoise",
+    "38": "banded gecko",
+    "39": "common iguana, iguana, Iguana iguana",
+    "40": "American chameleon, anole, Anolis carolinensis",
+    "41": "whiptail, whiptail lizard",
+    "42": "agama",
+    "43": "frilled lizard, Chlamydosaurus kingi",
+    "44": "alligator lizard",
+    "45": "Gila monster, Heloderma suspectum",
+    "46": "green lizard, Lacerta viridis",
+    "47": "African chameleon, Chamaeleo chamaeleon",
+    "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
+    "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
+    "50": "American alligator, Alligator mississipiensis",
+    "51": "triceratops",
+    "52": "thunder snake, worm snake, Carphophis amoenus",
+    "53": "ringneck snake, ring-necked snake, ring snake",
+    "54": "hognose snake, puff adder, sand viper",
+    "55": "green snake, grass snake",
+    "56": "king snake, kingsnake",
+    "57": "garter snake, grass snake",
+    "58": "water snake",
+    "59": "vine snake",
+    "60": "night snake, Hypsiglena torquata",
+    "61": "boa constrictor, Constrictor constrictor",
+    "62": "rock python, rock snake, Python sebae",
+    "63": "Indian cobra, Naja naja",
+    "64": "green mamba",
+    "65": "sea snake",
+    "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
+    "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
+    "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
+    "69": "trilobite",
+    "70": "harvestman, daddy longlegs, Phalangium opilio",
+    "71": "scorpion",
+    "72": "black and gold garden spider, Argiope aurantia",
+    "73": "barn spider, Araneus cavaticus",
+    "74": "garden spider, Aranea diademata",
+    "75": "black widow, Latrodectus mactans",
+    "76": "tarantula",
+    "77": "wolf spider, hunting spider",
+    "78": "tick",
+    "79": "centipede",
+    "80": "black grouse",
+    "81": "ptarmigan",
+    "82": "ruffed grouse, partridge, Bonasa umbellus",
+    "83": "prairie chicken, prairie grouse, prairie fowl",
+    "84": "peacock",
+    "85": "quail",
+    "86": "partridge",
+    "87": "African grey, African gray, Psittacus erithacus",
+    "88": "macaw",
+    "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
+    "90": "lorikeet",
+    "91": "coucal",
+    "92": "bee eater",
+    "93": "hornbill",
+    "94": "hummingbird",
+    "95": "jacamar",
+    "96": "toucan",
+    "97": "drake",
+    "98": "red-breasted merganser, Mergus serrator",
+    "99": "goose",
+    "100": "black swan, Cygnus atratus",
+    "101": "tusker",
+    "102": "echidna, spiny anteater, anteater",
+    "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
+    "104": "wallaby, brush kangaroo",
+    "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
+    "106": "wombat",
+    "107": "jellyfish",
+    "108": "sea anemone, anemone",
+    "109": "brain coral",
+    "110": "flatworm, platyhelminth",
+    "111": "nematode, nematode worm, roundworm",
+    "112": "conch",
+    "113": "snail",
+    "114": "slug",
+    "115": "sea slug, nudibranch",
+    "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
+    "117": "chambered nautilus, pearly nautilus, nautilus",
+    "118": "Dungeness crab, Cancer magister",
+    "119": "rock crab, Cancer irroratus",
+    "120": "fiddler crab",
+    "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
+    "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
+    "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
+    "124": "crayfish, crawfish, crawdad, crawdaddy",
+    "125": "hermit crab",
+    "126": "isopod",
+    "127": "white stork, Ciconia ciconia",
+    "128": "black stork, Ciconia nigra",
+    "129": "spoonbill",
+    "130": "flamingo",
+    "131": "little blue heron, Egretta caerulea",
+    "132": "American egret, great white heron, Egretta albus",
+    "133": "bittern",
+    "134": "crane",
+    "135": "limpkin, Aramus pictus",
+    "136": "European gallinule, Porphyrio porphyrio",
+    "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
+    "138": "bustard",
+    "139": "ruddy turnstone, Arenaria interpres",
+    "140": "red-backed sandpiper, dunlin, Erolia alpina",
+    "141": "redshank, Tringa totanus",
+    "142": "dowitcher",
+    "143": "oystercatcher, oyster catcher",
+    "144": "pelican",
+    "145": "king penguin, Aptenodytes patagonica",
+    "146": "albatross, mollymawk",
+    "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
+    "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
+    "149": "dugong, Dugong dugon",
+    "150": "sea lion",
+    "151": "Chihuahua",
+    "152": "Japanese spaniel",
+    "153": "Maltese dog, Maltese terrier, Maltese",
+    "154": "Pekinese, Pekingese, Peke",
+    "155": "Shih-Tzu",
+    "156": "Blenheim spaniel",
+    "157": "papillon",
+    "158": "toy terrier",
+    "159": "Rhodesian ridgeback",
+    "160": "Afghan hound, Afghan",
+    "161": "basset, basset hound",
+    "162": "beagle",
+    "163": "bloodhound, sleuthhound",
+    "164": "bluetick",
+    "165": "black-and-tan coonhound",
+    "166": "Walker hound, Walker foxhound",
+    "167": "English foxhound",
+    "168": "redbone",
+    "169": "borzoi, Russian wolfhound",
+    "170": "Irish wolfhound",
+    "171": "Italian greyhound",
+    "172": "whippet",
+    "173": "Ibizan hound, Ibizan Podenco",
+    "174": "Norwegian elkhound, elkhound",
+    "175": "otterhound, otter hound",
+    "176": "Saluki, gazelle hound",
+    "177": "Scottish deerhound, deerhound",
+    "178": "Weimaraner",
+    "179": "Staffordshire bullterrier, Staffordshire bull terrier",
+    "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
+    "181": "Bedlington terrier",
+    "182": "Border terrier",
+    "183": "Kerry blue terrier",
+    "184": "Irish terrier",
+    "185": "Norfolk terrier",
+    "186": "Norwich terrier",
+    "187": "Yorkshire terrier",
+    "188": "wire-haired fox terrier",
+    "189": "Lakeland terrier",
+    "190": "Sealyham terrier, Sealyham",
+    "191": "Airedale, Airedale terrier",
+    "192": "cairn, cairn terrier",
+    "193": "Australian terrier",
+    "194": "Dandie Dinmont, Dandie Dinmont terrier",
+    "195": "Boston bull, Boston terrier",
+    "196": "miniature schnauzer",
+    "197": "giant schnauzer",
+    "198": "standard schnauzer",
+    "199": "Scotch terrier, Scottish terrier, Scottie",
+    "200": "Tibetan terrier, chrysanthemum dog",
+    "201": "silky terrier, Sydney silky",
+    "202": "soft-coated wheaten terrier",
+    "203": "West Highland white terrier",
+    "204": "Lhasa, Lhasa apso",
+    "205": "flat-coated retriever",
+    "206": "curly-coated retriever",
+    "207": "golden retriever",
+    "208": "Labrador retriever",
+    "209": "Chesapeake Bay retriever",
+    "210": "German short-haired pointer",
+    "211": "vizsla, Hungarian pointer",
+    "212": "English setter",
+    "213": "Irish setter, red setter",
+    "214": "Gordon setter",
+    "215": "Brittany spaniel",
+    "216": "clumber, clumber spaniel",
+    "217": "English springer, English springer spaniel",
+    "218": "Welsh springer spaniel",
+    "219": "cocker spaniel, English cocker spaniel, cocker",
+    "220": "Sussex spaniel",
+    "221": "Irish water spaniel",
+    "222": "kuvasz",
+    "223": "schipperke",
+    "224": "groenendael",
+    "225": "malinois",
+    "226": "briard",
+    "227": "kelpie",
+    "228": "komondor",
+    "229": "Old English sheepdog, bobtail",
+    "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
+    "231": "collie",
+    "232": "Border collie",
+    "233": "Bouvier des Flandres, Bouviers des Flandres",
+    "234": "Rottweiler",
+    "235": "German shepherd, German shepherd dog, German police dog, alsatian",
+    "236": "Doberman, Doberman pinscher",
+    "237": "miniature pinscher",
+    "238": "Greater Swiss Mountain dog",
+    "239": "Bernese mountain dog",
+    "240": "Appenzeller",
+    "241": "EntleBucher",
+    "242": "boxer",
+    "243": "bull mastiff",
+    "244": "Tibetan mastiff",
+    "245": "French bulldog",
+    "246": "Great Dane",
+    "247": "Saint Bernard, St Bernard",
+    "248": "Eskimo dog, husky",
+    "249": "malamute, malemute, Alaskan malamute",
+    "250": "Siberian husky",
+    "251": "dalmatian, coach dog, carriage dog",
+    "252": "affenpinscher, monkey pinscher, monkey dog",
+    "253": "basenji",
+    "254": "pug, pug-dog",
+    "255": "Leonberg",
+    "256": "Newfoundland, Newfoundland dog",
+    "257": "Great Pyrenees",
+    "258": "Samoyed, Samoyede",
+    "259": "Pomeranian",
+    "260": "chow, chow chow",
+    "261": "keeshond",
+    "262": "Brabancon griffon",
+    "263": "Pembroke, Pembroke Welsh corgi",
+    "264": "Cardigan, Cardigan Welsh corgi",
+    "265": "toy poodle",
+    "266": "miniature poodle",
+    "267": "standard poodle",
+    "268": "Mexican hairless",
+    "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
+    "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
+    "271": "red wolf, maned wolf, Canis rufus, Canis niger",
+    "272": "coyote, prairie wolf, brush wolf, Canis latrans",
+    "273": "dingo, warrigal, warragal, Canis dingo",
+    "274": "dhole, Cuon alpinus",
+    "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
+    "276": "hyena, hyaena",
+    "277": "red fox, Vulpes vulpes",
+    "278": "kit fox, Vulpes macrotis",
+    "279": "Arctic fox, white fox, Alopex lagopus",
+    "280": "grey fox, gray fox, Urocyon cinereoargenteus",
+    "281": "tabby, tabby cat",
+    "282": "tiger cat",
+    "283": "Persian cat",
+    "284": "Siamese cat, Siamese",
+    "285": "Egyptian cat",
+    "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
+    "287": "lynx, catamount",
+    "288": "leopard, Panthera pardus",
+    "289": "snow leopard, ounce, Panthera uncia",
+    "290": "jaguar, panther, Panthera onca, Felis onca",
+    "291": "lion, king of beasts, Panthera leo",
+    "292": "tiger, Panthera tigris",
+    "293": "cheetah, chetah, Acinonyx jubatus",
+    "294": "brown bear, bruin, Ursus arctos",
+    "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
+    "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
+    "297": "sloth bear, Melursus ursinus, Ursus ursinus",
+    "298": "mongoose",
+    "299": "meerkat, mierkat",
+    "300": "tiger beetle",
+    "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
+    "302": "ground beetle, carabid beetle",
+    "303": "long-horned beetle, longicorn, longicorn beetle",
+    "304": "leaf beetle, chrysomelid",
+    "305": "dung beetle",
+    "306": "rhinoceros beetle",
+    "307": "weevil",
+    "308": "fly",
+    "309": "bee",
+    "310": "ant, emmet, pismire",
+    "311": "grasshopper, hopper",
+    "312": "cricket",
+    "313": "walking stick, walkingstick, stick insect",
+    "314": "cockroach, roach",
+    "315": "mantis, mantid",
+    "316": "cicada, cicala",
+    "317": "leafhopper",
+    "318": "lacewing, lacewing fly",
+    "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+    "320": "damselfly",
+    "321": "admiral",
+    "322": "ringlet, ringlet butterfly",
+    "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
+    "324": "cabbage butterfly",
+    "325": "sulphur butterfly, sulfur butterfly",
+    "326": "lycaenid, lycaenid butterfly",
+    "327": "starfish, sea star",
+    "328": "sea urchin",
+    "329": "sea cucumber, holothurian",
+    "330": "wood rabbit, cottontail, cottontail rabbit",
+    "331": "hare",
+    "332": "Angora, Angora rabbit",
+    "333": "hamster",
+    "334": "porcupine, hedgehog",
+    "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
+    "336": "marmot",
+    "337": "beaver",
+    "338": "guinea pig, Cavia cobaya",
+    "339": "sorrel",
+    "340": "zebra",
+    "341": "hog, pig, grunter, squealer, Sus scrofa",
+    "342": "wild boar, boar, Sus scrofa",
+    "343": "warthog",
+    "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
+    "345": "ox",
+    "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
+    "347": "bison",
+    "348": "ram, tup",
+    "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
+    "350": "ibex, Capra ibex",
+    "351": "hartebeest",
+    "352": "impala, Aepyceros melampus",
+    "353": "gazelle",
+    "354": "Arabian camel, dromedary, Camelus dromedarius",
+    "355": "llama",
+    "356": "weasel",
+    "357": "mink",
+    "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
+    "359": "black-footed ferret, ferret, Mustela nigripes",
+    "360": "otter",
+    "361": "skunk, polecat, wood pussy",
+    "362": "badger",
+    "363": "armadillo",
+    "364": "three-toed sloth, ai, Bradypus tridactylus",
+    "365": "orangutan, orang, orangutang, Pongo pygmaeus",
+    "366": "gorilla, Gorilla gorilla",
+    "367": "chimpanzee, chimp, Pan troglodytes",
+    "368": "gibbon, Hylobates lar",
+    "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
+    "370": "guenon, guenon monkey",
+    "371": "patas, hussar monkey, Erythrocebus patas",
+    "372": "baboon",
+    "373": "macaque",
+    "374": "langur",
+    "375": "colobus, colobus monkey",
+    "376": "proboscis monkey, Nasalis larvatus",
+    "377": "marmoset",
+    "378": "capuchin, ringtail, Cebus capucinus",
+    "379": "howler monkey, howler",
+    "380": "titi, titi monkey",
+    "381": "spider monkey, Ateles geoffroyi",
+    "382": "squirrel monkey, Saimiri sciureus",
+    "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
+    "384": "indri, indris, Indri indri, Indri brevicaudatus",
+    "385": "Indian elephant, Elephas maximus",
+    "386": "African elephant, Loxodonta africana",
+    "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
+    "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
+    "389": "barracouta, snoek",
+    "390": "eel",
+    "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
+    "392": "rock beauty, Holocanthus tricolor",
+    "393": "anemone fish",
+    "394": "sturgeon",
+    "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
+    "396": "lionfish",
+    "397": "puffer, pufferfish, blowfish, globefish",
+    "398": "abacus",
+    "399": "abaya",
+    "400": "academic gown, academic robe, judge's robe",
+    "401": "accordion, piano accordion, squeeze box",
+    "402": "acoustic guitar",
+    "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+    "404": "airliner",
+    "405": "airship, dirigible",
+    "406": "altar",
+    "407": "ambulance",
+    "408": "amphibian, amphibious vehicle",
+    "409": "analog clock",
+    "410": "apiary, bee house",
+    "411": "apron",
+    "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+    "413": "assault rifle, assault gun",
+    "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+    "415": "bakery, bakeshop, bakehouse",
+    "416": "balance beam, beam",
+    "417": "balloon",
+    "418": "ballpoint, ballpoint pen, ballpen, Biro",
+    "419": "Band Aid",
+    "420": "banjo",
+    "421": "bannister, banister, balustrade, balusters, handrail",
+    "422": "barbell",
+    "423": "barber chair",
+    "424": "barbershop",
+    "425": "barn",
+    "426": "barometer",
+    "427": "barrel, cask",
+    "428": "barrow, garden cart, lawn cart, wheelbarrow",
+    "429": "baseball",
+    "430": "basketball",
+    "431": "bassinet",
+    "432": "bassoon",
+    "433": "bathing cap, swimming cap",
+    "434": "bath towel",
+    "435": "bathtub, bathing tub, bath, tub",
+    "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
+    "437": "beacon, lighthouse, beacon light, pharos",
+    "438": "beaker",
+    "439": "bearskin, busby, shako",
+    "440": "beer bottle",
+    "441": "beer glass",
+    "442": "bell cote, bell cot",
+    "443": "bib",
+    "444": "bicycle-built-for-two, tandem bicycle, tandem",
+    "445": "bikini, two-piece",
+    "446": "binder, ring-binder",
+    "447": "binoculars, field glasses, opera glasses",
+    "448": "birdhouse",
+    "449": "boathouse",
+    "450": "bobsled, bobsleigh, bob",
+    "451": "bolo tie, bolo, bola tie, bola",
+    "452": "bonnet, poke bonnet",
+    "453": "bookcase",
+    "454": "bookshop, bookstore, bookstall",
+    "455": "bottlecap",
+    "456": "bow",
+    "457": "bow tie, bow-tie, bowtie",
+    "458": "brass, memorial tablet, plaque",
+    "459": "brassiere, bra, bandeau",
+    "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
+    "461": "breastplate, aegis, egis",
+    "462": "broom",
+    "463": "bucket, pail",
+    "464": "buckle",
+    "465": "bulletproof vest",
+    "466": "bullet train, bullet",
+    "467": "butcher shop, meat market",
+    "468": "cab, hack, taxi, taxicab",
+    "469": "caldron, cauldron",
+    "470": "candle, taper, wax light",
+    "471": "cannon",
+    "472": "canoe",
+    "473": "can opener, tin opener",
+    "474": "cardigan",
+    "475": "car mirror",
+    "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
+    "477": "carpenter's kit, tool kit",
+    "478": "carton",
+    "479": "car wheel",
+    "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
+    "481": "cassette",
+    "482": "cassette player",
+    "483": "castle",
+    "484": "catamaran",
+    "485": "CD player",
+    "486": "cello, violoncello",
+    "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+    "488": "chain",
+    "489": "chainlink fence",
+    "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
+    "491": "chain saw, chainsaw",
+    "492": "chest",
+    "493": "chiffonier, commode",
+    "494": "chime, bell, gong",
+    "495": "china cabinet, china closet",
+    "496": "Christmas stocking",
+    "497": "church, church building",
+    "498": "cinema, movie theater, movie theatre, movie house, picture palace",
+    "499": "cleaver, meat cleaver, chopper",
+    "500": "cliff dwelling",
+    "501": "cloak",
+    "502": "clog, geta, patten, sabot",
+    "503": "cocktail shaker",
+    "504": "coffee mug",
+    "505": "coffeepot",
+    "506": "coil, spiral, volute, whorl, helix",
+    "507": "combination lock",
+    "508": "computer keyboard, keypad",
+    "509": "confectionery, confectionary, candy store",
+    "510": "container ship, containership, container vessel",
+    "511": "convertible",
+    "512": "corkscrew, bottle screw",
+    "513": "cornet, horn, trumpet, trump",
+    "514": "cowboy boot",
+    "515": "cowboy hat, ten-gallon hat",
+    "516": "cradle",
+    "517": "crane",
+    "518": "crash helmet",
+    "519": "crate",
+    "520": "crib, cot",
+    "521": "Crock Pot",
+    "522": "croquet ball",
+    "523": "crutch",
+    "524": "cuirass",
+    "525": "dam, dike, dyke",
+    "526": "desk",
+    "527": "desktop computer",
+    "528": "dial telephone, dial phone",
+    "529": "diaper, nappy, napkin",
+    "530": "digital clock",
+    "531": "digital watch",
+    "532": "dining table, board",
+    "533": "dishrag, dishcloth",
+    "534": "dishwasher, dish washer, dishwashing machine",
+    "535": "disk brake, disc brake",
+    "536": "dock, dockage, docking facility",
+    "537": "dogsled, dog sled, dog sleigh",
+    "538": "dome",
+    "539": "doormat, welcome mat",
+    "540": "drilling platform, offshore rig",
+    "541": "drum, membranophone, tympan",
+    "542": "drumstick",
+    "543": "dumbbell",
+    "544": "Dutch oven",
+    "545": "electric fan, blower",
+    "546": "electric guitar",
+    "547": "electric locomotive",
+    "548": "entertainment center",
+    "549": "envelope",
+    "550": "espresso maker",
+    "551": "face powder",
+    "552": "feather boa, boa",
+    "553": "file, file cabinet, filing cabinet",
+    "554": "fireboat",
+    "555": "fire engine, fire truck",
+    "556": "fire screen, fireguard",
+    "557": "flagpole, flagstaff",
+    "558": "flute, transverse flute",
+    "559": "folding chair",
+    "560": "football helmet",
+    "561": "forklift",
+    "562": "fountain",
+    "563": "fountain pen",
+    "564": "four-poster",
+    "565": "freight car",
+    "566": "French horn, horn",
+    "567": "frying pan, frypan, skillet",
+    "568": "fur coat",
+    "569": "garbage truck, dustcart",
+    "570": "gasmask, respirator, gas helmet",
+    "571": "gas pump, gasoline pump, petrol pump, island dispenser",
+    "572": "goblet",
+    "573": "go-kart",
+    "574": "golf ball",
+    "575": "golfcart, golf cart",
+    "576": "gondola",
+    "577": "gong, tam-tam",
+    "578": "gown",
+    "579": "grand piano, grand",
+    "580": "greenhouse, nursery, glasshouse",
+    "581": "grille, radiator grille",
+    "582": "grocery store, grocery, food market, market",
+    "583": "guillotine",
+    "584": "hair slide",
+    "585": "hair spray",
+    "586": "half track",
+    "587": "hammer",
+    "588": "hamper",
+    "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
+    "590": "hand-held computer, hand-held microcomputer",
+    "591": "handkerchief, hankie, hanky, hankey",
+    "592": "hard disc, hard disk, fixed disk",
+    "593": "harmonica, mouth organ, harp, mouth harp",
+    "594": "harp",
+    "595": "harvester, reaper",
+    "596": "hatchet",
+    "597": "holster",
+    "598": "home theater, home theatre",
+    "599": "honeycomb",
+    "600": "hook, claw",
+    "601": "hoopskirt, crinoline",
+    "602": "horizontal bar, high bar",
+    "603": "horse cart, horse-cart",
+    "604": "hourglass",
+    "605": "iPod",
+    "606": "iron, smoothing iron",
+    "607": "jack-o'-lantern",
+    "608": "jean, blue jean, denim",
+    "609": "jeep, landrover",
+    "610": "jersey, T-shirt, tee shirt",
+    "611": "jigsaw puzzle",
+    "612": "jinrikisha, ricksha, rickshaw",
+    "613": "joystick",
+    "614": "kimono",
+    "615": "knee pad",
+    "616": "knot",
+    "617": "lab coat, laboratory coat",
+    "618": "ladle",
+    "619": "lampshade, lamp shade",
+    "620": "laptop, laptop computer",
+    "621": "lawn mower, mower",
+    "622": "lens cap, lens cover",
+    "623": "letter opener, paper knife, paperknife",
+    "624": "library",
+    "625": "lifeboat",
+    "626": "lighter, light, igniter, ignitor",
+    "627": "limousine, limo",
+    "628": "liner, ocean liner",
+    "629": "lipstick, lip rouge",
+    "630": "Loafer",
+    "631": "lotion",
+    "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+    "633": "loupe, jeweler's loupe",
+    "634": "lumbermill, sawmill",
+    "635": "magnetic compass",
+    "636": "mailbag, postbag",
+    "637": "mailbox, letter box",
+    "638": "maillot",
+    "639": "maillot, tank suit",
+    "640": "manhole cover",
+    "641": "maraca",
+    "642": "marimba, xylophone",
+    "643": "mask",
+    "644": "matchstick",
+    "645": "maypole",
+    "646": "maze, labyrinth",
+    "647": "measuring cup",
+    "648": "medicine chest, medicine cabinet",
+    "649": "megalith, megalithic structure",
+    "650": "microphone, mike",
+    "651": "microwave, microwave oven",
+    "652": "military uniform",
+    "653": "milk can",
+    "654": "minibus",
+    "655": "miniskirt, mini",
+    "656": "minivan",
+    "657": "missile",
+    "658": "mitten",
+    "659": "mixing bowl",
+    "660": "mobile home, manufactured home",
+    "661": "Model T",
+    "662": "modem",
+    "663": "monastery",
+    "664": "monitor",
+    "665": "moped",
+    "666": "mortar",
+    "667": "mortarboard",
+    "668": "mosque",
+    "669": "mosquito net",
+    "670": "motor scooter, scooter",
+    "671": "mountain bike, all-terrain bike, off-roader",
+    "672": "mountain tent",
+    "673": "mouse, computer mouse",
+    "674": "mousetrap",
+    "675": "moving van",
+    "676": "muzzle",
+    "677": "nail",
+    "678": "neck brace",
+    "679": "necklace",
+    "680": "nipple",
+    "681": "notebook, notebook computer",
+    "682": "obelisk",
+    "683": "oboe, hautboy, hautbois",
+    "684": "ocarina, sweet potato",
+    "685": "odometer, hodometer, mileometer, milometer",
+    "686": "oil filter",
+    "687": "organ, pipe organ",
+    "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
+    "689": "overskirt",
+    "690": "oxcart",
+    "691": "oxygen mask",
+    "692": "packet",
+    "693": "paddle, boat paddle",
+    "694": "paddlewheel, paddle wheel",
+    "695": "padlock",
+    "696": "paintbrush",
+    "697": "pajama, pyjama, pj's, jammies",
+    "698": "palace",
+    "699": "panpipe, pandean pipe, syrinx",
+    "700": "paper towel",
+    "701": "parachute, chute",
+    "702": "parallel bars, bars",
+    "703": "park bench",
+    "704": "parking meter",
+    "705": "passenger car, coach, carriage",
+    "706": "patio, terrace",
+    "707": "pay-phone, pay-station",
+    "708": "pedestal, plinth, footstall",
+    "709": "pencil box, pencil case",
+    "710": "pencil sharpener",
+    "711": "perfume, essence",
+    "712": "Petri dish",
+    "713": "photocopier",
+    "714": "pick, plectrum, plectron",
+    "715": "pickelhaube",
+    "716": "picket fence, paling",
+    "717": "pickup, pickup truck",
+    "718": "pier",
+    "719": "piggy bank, penny bank",
+    "720": "pill bottle",
+    "721": "pillow",
+    "722": "ping-pong ball",
+    "723": "pinwheel",
+    "724": "pirate, pirate ship",
+    "725": "pitcher, ewer",
+    "726": "plane, carpenter's plane, woodworking plane",
+    "727": "planetarium",
+    "728": "plastic bag",
+    "729": "plate rack",
+    "730": "plow, plough",
+    "731": "plunger, plumber's helper",
+    "732": "Polaroid camera, Polaroid Land camera",
+    "733": "pole",
+    "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
+    "735": "poncho",
+    "736": "pool table, billiard table, snooker table",
+    "737": "pop bottle, soda bottle",
+    "738": "pot, flowerpot",
+    "739": "potter's wheel",
+    "740": "power drill",
+    "741": "prayer rug, prayer mat",
+    "742": "printer",
+    "743": "prison, prison house",
+    "744": "projectile, missile",
+    "745": "projector",
+    "746": "puck, hockey puck",
+    "747": "punching bag, punch bag, punching ball, punchball",
+    "748": "purse",
+    "749": "quill, quill pen",
+    "750": "quilt, comforter, comfort, puff",
+    "751": "racer, race car, racing car",
+    "752": "racket, racquet",
+    "753": "radiator",
+    "754": "radio, wireless",
+    "755": "radio telescope, radio reflector",
+    "756": "rain barrel",
+    "757": "recreational vehicle, RV, R.V.",
+    "758": "reel",
+    "759": "reflex camera",
+    "760": "refrigerator, icebox",
+    "761": "remote control, remote",
+    "762": "restaurant, eating house, eating place, eatery",
+    "763": "revolver, six-gun, six-shooter",
+    "764": "rifle",
+    "765": "rocking chair, rocker",
+    "766": "rotisserie",
+    "767": "rubber eraser, rubber, pencil eraser",
+    "768": "rugby ball",
+    "769": "rule, ruler",
+    "770": "running shoe",
+    "771": "safe",
+    "772": "safety pin",
+    "773": "saltshaker, salt shaker",
+    "774": "sandal",
+    "775": "sarong",
+    "776": "sax, saxophone",
+    "777": "scabbard",
+    "778": "scale, weighing machine",
+    "779": "school bus",
+    "780": "schooner",
+    "781": "scoreboard",
+    "782": "screen, CRT screen",
+    "783": "screw",
+    "784": "screwdriver",
+    "785": "seat belt, seatbelt",
+    "786": "sewing machine",
+    "787": "shield, buckler",
+    "788": "shoe shop, shoe-shop, shoe store",
+    "789": "shoji",
+    "790": "shopping basket",
+    "791": "shopping cart",
+    "792": "shovel",
+    "793": "shower cap",
+    "794": "shower curtain",
+    "795": "ski",
+    "796": "ski mask",
+    "797": "sleeping bag",
+    "798": "slide rule, slipstick",
+    "799": "sliding door",
+    "800": "slot, one-armed bandit",
+    "801": "snorkel",
+    "802": "snowmobile",
+    "803": "snowplow, snowplough",
+    "804": "soap dispenser",
+    "805": "soccer ball",
+    "806": "sock",
+    "807": "solar dish, solar collector, solar furnace",
+    "808": "sombrero",
+    "809": "soup bowl",
+    "810": "space bar",
+    "811": "space heater",
+    "812": "space shuttle",
+    "813": "spatula",
+    "814": "speedboat",
+    "815": "spider web, spider's web",
+    "816": "spindle",
+    "817": "sports car, sport car",
+    "818": "spotlight, spot",
+    "819": "stage",
+    "820": "steam locomotive",
+    "821": "steel arch bridge",
+    "822": "steel drum",
+    "823": "stethoscope",
+    "824": "stole",
+    "825": "stone wall",
+    "826": "stopwatch, stop watch",
+    "827": "stove",
+    "828": "strainer",
+    "829": "streetcar, tram, tramcar, trolley, trolley car",
+    "830": "stretcher",
+    "831": "studio couch, day bed",
+    "832": "stupa, tope",
+    "833": "submarine, pigboat, sub, U-boat",
+    "834": "suit, suit of clothes",
+    "835": "sundial",
+    "836": "sunglass",
+    "837": "sunglasses, dark glasses, shades",
+    "838": "sunscreen, sunblock, sun blocker",
+    "839": "suspension bridge",
+    "840": "swab, swob, mop",
+    "841": "sweatshirt",
+    "842": "swimming trunks, bathing trunks",
+    "843": "swing",
+    "844": "switch, electric switch, electrical switch",
+    "845": "syringe",
+    "846": "table lamp",
+    "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
+    "848": "tape player",
+    "849": "teapot",
+    "850": "teddy, teddy bear",
+    "851": "television, television system",
+    "852": "tennis ball",
+    "853": "thatch, thatched roof",
+    "854": "theater curtain, theatre curtain",
+    "855": "thimble",
+    "856": "thresher, thrasher, threshing machine",
+    "857": "throne",
+    "858": "tile roof",
+    "859": "toaster",
+    "860": "tobacco shop, tobacconist shop, tobacconist",
+    "861": "toilet seat",
+    "862": "torch",
+    "863": "totem pole",
+    "864": "tow truck, tow car, wrecker",
+    "865": "toyshop",
+    "866": "tractor",
+    "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
+    "868": "tray",
+    "869": "trench coat",
+    "870": "tricycle, trike, velocipede",
+    "871": "trimaran",
+    "872": "tripod",
+    "873": "triumphal arch",
+    "874": "trolleybus, trolley coach, trackless trolley",
+    "875": "trombone",
+    "876": "tub, vat",
+    "877": "turnstile",
+    "878": "typewriter keyboard",
+    "879": "umbrella",
+    "880": "unicycle, monocycle",
+    "881": "upright, upright piano",
+    "882": "vacuum, vacuum cleaner",
+    "883": "vase",
+    "884": "vault",
+    "885": "velvet",
+    "886": "vending machine",
+    "887": "vestment",
+    "888": "viaduct",
+    "889": "violin, fiddle",
+    "890": "volleyball",
+    "891": "waffle iron",
+    "892": "wall clock",
+    "893": "wallet, billfold, notecase, pocketbook",
+    "894": "wardrobe, closet, press",
+    "895": "warplane, military plane",
+    "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
+    "897": "washer, automatic washer, washing machine",
+    "898": "water bottle",
+    "899": "water jug",
+    "900": "water tower",
+    "901": "whiskey jug",
+    "902": "whistle",
+    "903": "wig",
+    "904": "window screen",
+    "905": "window shade",
+    "906": "Windsor tie",
+    "907": "wine bottle",
+    "908": "wing",
+    "909": "wok",
+    "910": "wooden spoon",
+    "911": "wool, woolen, woollen",
+    "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
+    "913": "wreck",
+    "914": "yawl",
+    "915": "yurt",
+    "916": "web site, website, internet site, site",
+    "917": "comic book",
+    "918": "crossword puzzle, crossword",
+    "919": "street sign",
+    "920": "traffic light, traffic signal, stoplight",
+    "921": "book jacket, dust cover, dust jacket, dust wrapper",
+    "922": "menu",
+    "923": "plate",
+    "924": "guacamole",
+    "925": "consomme",
+    "926": "hot pot, hotpot",
+    "927": "trifle",
+    "928": "ice cream, icecream",
+    "929": "ice lolly, lolly, lollipop, popsicle",
+    "930": "French loaf",
+    "931": "bagel, beigel",
+    "932": "pretzel",
+    "933": "cheeseburger",
+    "934": "hotdog, hot dog, red hot",
+    "935": "mashed potato",
+    "936": "head cabbage",
+    "937": "broccoli",
+    "938": "cauliflower",
+    "939": "zucchini, courgette",
+    "940": "spaghetti squash",
+    "941": "acorn squash",
+    "942": "butternut squash",
+    "943": "cucumber, cuke",
+    "944": "artichoke, globe artichoke",
+    "945": "bell pepper",
+    "946": "cardoon",
+    "947": "mushroom",
+    "948": "Granny Smith",
+    "949": "strawberry",
+    "950": "orange",
+    "951": "lemon",
+    "952": "fig",
+    "953": "pineapple, ananas",
+    "954": "banana",
+    "955": "jackfruit, jak, jack",
+    "956": "custard apple",
+    "957": "pomegranate",
+    "958": "hay",
+    "959": "carbonara",
+    "960": "chocolate sauce, chocolate syrup",
+    "961": "dough",
+    "962": "meat loaf, meatloaf",
+    "963": "pizza, pizza pie",
+    "964": "potpie",
+    "965": "burrito",
+    "966": "red wine",
+    "967": "espresso",
+    "968": "cup",
+    "969": "eggnog",
+    "970": "alp",
+    "971": "bubble",
+    "972": "cliff, drop, drop-off",
+    "973": "coral reef",
+    "974": "geyser",
+    "975": "lakeside, lakeshore",
+    "976": "promontory, headland, head, foreland",
+    "977": "sandbar, sand bar",
+    "978": "seashore, coast, seacoast, sea-coast",
+    "979": "valley, vale",
+    "980": "volcano",
+    "981": "ballplayer, baseball player",
+    "982": "groom, bridegroom",
+    "983": "scuba diver",
+    "984": "rapeseed",
+    "985": "daisy",
+    "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+    "987": "corn",
+    "988": "acorn",
+    "989": "hip, rose hip, rosehip",
+    "990": "buckeye, horse chestnut, conker",
+    "991": "coral fungus",
+    "992": "agaric",
+    "993": "gyromitra",
+    "994": "stinkhorn, carrion fungus",
+    "995": "earthstar",
+    "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
+    "997": "bolete",
+    "998": "ear, spike, capitulum",
+    "999": "toilet tissue, toilet paper, bathroom tissue"
+  },
+  "image_size": 224,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "Afghan hound, Afghan": 160,
+    "African chameleon, Chamaeleo chamaeleon": 47,
+    "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
+    "African elephant, Loxodonta africana": 386,
+    "African grey, African gray, Psittacus erithacus": 87,
+    "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
+    "Airedale, Airedale terrier": 191,
+    "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
+    "American alligator, Alligator mississipiensis": 50,
+    "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
| 1026 |
+
"American chameleon, anole, Anolis carolinensis": 40,
|
| 1027 |
+
"American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
|
| 1028 |
+
"American egret, great white heron, Egretta albus": 132,
|
| 1029 |
+
"American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
|
| 1030 |
+
"Angora, Angora rabbit": 332,
|
| 1031 |
+
"Appenzeller": 240,
|
| 1032 |
+
"Arabian camel, dromedary, Camelus dromedarius": 354,
|
| 1033 |
+
"Arctic fox, white fox, Alopex lagopus": 279,
|
| 1034 |
+
"Australian terrier": 193,
|
| 1035 |
+
"Band Aid": 419,
|
| 1036 |
+
"Bedlington terrier": 181,
|
| 1037 |
+
"Bernese mountain dog": 239,
|
| 1038 |
+
"Blenheim spaniel": 156,
|
| 1039 |
+
"Border collie": 232,
|
| 1040 |
+
"Border terrier": 182,
|
| 1041 |
+
"Boston bull, Boston terrier": 195,
|
| 1042 |
+
"Bouvier des Flandres, Bouviers des Flandres": 233,
|
| 1043 |
+
"Brabancon griffon": 262,
|
| 1044 |
+
"Brittany spaniel": 215,
|
| 1045 |
+
"CD player": 485,
|
| 1046 |
+
"Cardigan, Cardigan Welsh corgi": 264,
|
| 1047 |
+
"Chesapeake Bay retriever": 209,
|
| 1048 |
+
"Chihuahua": 151,
|
| 1049 |
+
"Christmas stocking": 496,
|
| 1050 |
+
"Crock Pot": 521,
|
| 1051 |
+
"Dandie Dinmont, Dandie Dinmont terrier": 194,
|
| 1052 |
+
"Doberman, Doberman pinscher": 236,
|
| 1053 |
+
"Dungeness crab, Cancer magister": 118,
|
| 1054 |
+
"Dutch oven": 544,
|
| 1055 |
+
"Egyptian cat": 285,
|
| 1056 |
+
"English foxhound": 167,
|
| 1057 |
+
"English setter": 212,
|
| 1058 |
+
"English springer, English springer spaniel": 217,
|
| 1059 |
+
"EntleBucher": 241,
|
| 1060 |
+
"Eskimo dog, husky": 248,
|
| 1061 |
+
"European fire salamander, Salamandra salamandra": 25,
|
| 1062 |
+
"European gallinule, Porphyrio porphyrio": 136,
|
| 1063 |
+
"French bulldog": 245,
|
| 1064 |
+
"French horn, horn": 566,
|
| 1065 |
+
"French loaf": 930,
|
| 1066 |
+
"German shepherd, German shepherd dog, German police dog, alsatian": 235,
|
| 1067 |
+
"German short-haired pointer": 210,
|
| 1068 |
+
"Gila monster, Heloderma suspectum": 45,
|
| 1069 |
+
"Gordon setter": 214,
|
| 1070 |
+
"Granny Smith": 948,
|
| 1071 |
+
"Great Dane": 246,
|
| 1072 |
+
"Great Pyrenees": 257,
|
| 1073 |
+
"Greater Swiss Mountain dog": 238,
|
| 1074 |
+
"Ibizan hound, Ibizan Podenco": 173,
|
| 1075 |
+
"Indian cobra, Naja naja": 63,
|
| 1076 |
+
"Indian elephant, Elephas maximus": 385,
|
| 1077 |
+
"Irish setter, red setter": 213,
|
| 1078 |
+
"Irish terrier": 184,
|
| 1079 |
+
"Irish water spaniel": 221,
|
| 1080 |
+
"Irish wolfhound": 170,
|
| 1081 |
+
"Italian greyhound": 171,
|
| 1082 |
+
"Japanese spaniel": 152,
|
| 1083 |
+
"Kerry blue terrier": 183,
|
| 1084 |
+
"Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
|
| 1085 |
+
"Labrador retriever": 208,
|
| 1086 |
+
"Lakeland terrier": 189,
|
| 1087 |
+
"Leonberg": 255,
|
| 1088 |
+
"Lhasa, Lhasa apso": 204,
|
| 1089 |
+
"Loafer": 630,
|
| 1090 |
+
"Madagascar cat, ring-tailed lemur, Lemur catta": 383,
|
| 1091 |
+
"Maltese dog, Maltese terrier, Maltese": 153,
|
| 1092 |
+
"Mexican hairless": 268,
|
| 1093 |
+
"Model T": 661,
|
| 1094 |
+
"Newfoundland, Newfoundland dog": 256,
|
| 1095 |
+
"Norfolk terrier": 185,
|
| 1096 |
+
"Norwegian elkhound, elkhound": 174,
|
| 1097 |
+
"Norwich terrier": 186,
|
| 1098 |
+
"Old English sheepdog, bobtail": 229,
|
| 1099 |
+
"Pekinese, Pekingese, Peke": 154,
|
| 1100 |
+
"Pembroke, Pembroke Welsh corgi": 263,
|
| 1101 |
+
"Persian cat": 283,
|
| 1102 |
+
"Petri dish": 712,
|
| 1103 |
+
"Polaroid camera, Polaroid Land camera": 732,
|
| 1104 |
+
"Pomeranian": 259,
|
| 1105 |
+
"Rhodesian ridgeback": 159,
|
| 1106 |
+
"Rottweiler": 234,
|
| 1107 |
+
"Saint Bernard, St Bernard": 247,
|
| 1108 |
+
"Saluki, gazelle hound": 176,
|
| 1109 |
+
"Samoyed, Samoyede": 258,
|
| 1110 |
+
"Scotch terrier, Scottish terrier, Scottie": 199,
|
| 1111 |
+
"Scottish deerhound, deerhound": 177,
|
| 1112 |
+
"Sealyham terrier, Sealyham": 190,
|
| 1113 |
+
"Shetland sheepdog, Shetland sheep dog, Shetland": 230,
|
| 1114 |
+
"Shih-Tzu": 155,
|
| 1115 |
+
"Siamese cat, Siamese": 284,
|
| 1116 |
+
"Siberian husky": 250,
|
| 1117 |
+
"Staffordshire bullterrier, Staffordshire bull terrier": 179,
|
| 1118 |
+
"Sussex spaniel": 220,
|
| 1119 |
+
"Tibetan mastiff": 244,
|
| 1120 |
+
"Tibetan terrier, chrysanthemum dog": 200,
|
| 1121 |
+
"Walker hound, Walker foxhound": 166,
|
| 1122 |
+
"Weimaraner": 178,
|
| 1123 |
+
"Welsh springer spaniel": 218,
|
| 1124 |
+
"West Highland white terrier": 203,
|
| 1125 |
+
"Windsor tie": 906,
|
| 1126 |
+
"Yorkshire terrier": 187,
|
| 1127 |
+
"abacus": 398,
|
| 1128 |
+
"abaya": 399,
|
| 1129 |
+
"academic gown, academic robe, judge's robe": 400,
|
| 1130 |
+
"accordion, piano accordion, squeeze box": 401,
|
| 1131 |
+
"acorn": 988,
|
| 1132 |
+
"acorn squash": 941,
|
| 1133 |
+
"acoustic guitar": 402,
|
| 1134 |
+
"admiral": 321,
|
| 1135 |
+
"affenpinscher, monkey pinscher, monkey dog": 252,
|
| 1136 |
+
"agama": 42,
|
| 1137 |
+
"agaric": 992,
|
| 1138 |
+
"aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
|
| 1139 |
+
"airliner": 404,
|
| 1140 |
+
"airship, dirigible": 405,
|
| 1141 |
+
"albatross, mollymawk": 146,
|
| 1142 |
+
"alligator lizard": 44,
|
| 1143 |
+
"alp": 970,
|
| 1144 |
+
"altar": 406,
|
| 1145 |
+
"ambulance": 407,
|
| 1146 |
+
"amphibian, amphibious vehicle": 408,
|
| 1147 |
+
"analog clock": 409,
|
| 1148 |
+
"anemone fish": 393,
|
| 1149 |
+
"ant, emmet, pismire": 310,
|
| 1150 |
+
"apiary, bee house": 410,
|
| 1151 |
+
"apron": 411,
|
| 1152 |
+
"armadillo": 363,
|
| 1153 |
+
"artichoke, globe artichoke": 944,
|
| 1154 |
+
"ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
|
| 1155 |
+
"assault rifle, assault gun": 413,
|
| 1156 |
+
"axolotl, mud puppy, Ambystoma mexicanum": 29,
|
| 1157 |
+
"baboon": 372,
|
| 1158 |
+
"backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
|
| 1159 |
+
"badger": 362,
|
| 1160 |
+
"bagel, beigel": 931,
|
| 1161 |
+
"bakery, bakeshop, bakehouse": 415,
|
| 1162 |
+
"balance beam, beam": 416,
|
| 1163 |
+
"bald eagle, American eagle, Haliaeetus leucocephalus": 22,
|
| 1164 |
+
"balloon": 417,
|
| 1165 |
+
"ballplayer, baseball player": 981,
|
| 1166 |
+
"ballpoint, ballpoint pen, ballpen, Biro": 418,
|
| 1167 |
+
"banana": 954,
|
| 1168 |
+
"banded gecko": 38,
|
| 1169 |
+
"banjo": 420,
|
| 1170 |
+
"bannister, banister, balustrade, balusters, handrail": 421,
|
| 1171 |
+
"barbell": 422,
|
| 1172 |
+
"barber chair": 423,
|
| 1173 |
+
"barbershop": 424,
|
| 1174 |
+
"barn": 425,
|
| 1175 |
+
"barn spider, Araneus cavaticus": 73,
|
| 1176 |
+
"barometer": 426,
|
| 1177 |
+
"barracouta, snoek": 389,
|
| 1178 |
+
"barrel, cask": 427,
|
| 1179 |
+
"barrow, garden cart, lawn cart, wheelbarrow": 428,
|
| 1180 |
+
"baseball": 429,
|
| 1181 |
+
"basenji": 253,
|
| 1182 |
+
"basketball": 430,
|
| 1183 |
+
"basset, basset hound": 161,
|
| 1184 |
+
"bassinet": 431,
|
| 1185 |
+
"bassoon": 432,
|
| 1186 |
+
"bath towel": 434,
|
| 1187 |
+
"bathing cap, swimming cap": 433,
|
| 1188 |
+
"bathtub, bathing tub, bath, tub": 435,
|
| 1189 |
+
"beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
|
| 1190 |
+
"beacon, lighthouse, beacon light, pharos": 437,
|
| 1191 |
+
"beagle": 162,
|
| 1192 |
+
"beaker": 438,
|
| 1193 |
+
"bearskin, busby, shako": 439,
|
| 1194 |
+
"beaver": 337,
|
| 1195 |
+
"bee": 309,
|
| 1196 |
+
"bee eater": 92,
|
| 1197 |
+
"beer bottle": 440,
|
| 1198 |
+
"beer glass": 441,
|
| 1199 |
+
"bell cote, bell cot": 442,
|
| 1200 |
+
"bell pepper": 945,
|
| 1201 |
+
"bib": 443,
|
| 1202 |
+
"bicycle-built-for-two, tandem bicycle, tandem": 444,
|
| 1203 |
+
"bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
|
| 1204 |
+
"bikini, two-piece": 445,
|
| 1205 |
+
"binder, ring-binder": 446,
|
| 1206 |
+
"binoculars, field glasses, opera glasses": 447,
|
| 1207 |
+
"birdhouse": 448,
|
| 1208 |
+
"bison": 347,
|
| 1209 |
+
"bittern": 133,
|
| 1210 |
+
"black and gold garden spider, Argiope aurantia": 72,
|
| 1211 |
+
"black grouse": 80,
|
| 1212 |
+
"black stork, Ciconia nigra": 128,
|
| 1213 |
+
"black swan, Cygnus atratus": 100,
|
| 1214 |
+
"black widow, Latrodectus mactans": 75,
|
| 1215 |
+
"black-and-tan coonhound": 165,
|
| 1216 |
+
"black-footed ferret, ferret, Mustela nigripes": 359,
|
| 1217 |
+
"bloodhound, sleuthhound": 163,
|
| 1218 |
+
"bluetick": 164,
|
| 1219 |
+
"boa constrictor, Constrictor constrictor": 61,
|
| 1220 |
+
"boathouse": 449,
|
| 1221 |
+
"bobsled, bobsleigh, bob": 450,
|
| 1222 |
+
"bolete": 997,
|
| 1223 |
+
"bolo tie, bolo, bola tie, bola": 451,
|
| 1224 |
+
"bonnet, poke bonnet": 452,
|
| 1225 |
+
"book jacket, dust cover, dust jacket, dust wrapper": 921,
|
| 1226 |
+
"bookcase": 453,
|
| 1227 |
+
"bookshop, bookstore, bookstall": 454,
|
| 1228 |
+
"borzoi, Russian wolfhound": 169,
|
| 1229 |
+
"bottlecap": 455,
|
| 1230 |
+
"bow": 456,
|
| 1231 |
+
"bow tie, bow-tie, bowtie": 457,
|
| 1232 |
+
"box turtle, box tortoise": 37,
|
| 1233 |
+
"boxer": 242,
|
| 1234 |
+
"brain coral": 109,
|
| 1235 |
+
"brambling, Fringilla montifringilla": 10,
|
| 1236 |
+
"brass, memorial tablet, plaque": 458,
|
| 1237 |
+
"brassiere, bra, bandeau": 459,
|
| 1238 |
+
"breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
|
| 1239 |
+
"breastplate, aegis, egis": 461,
|
| 1240 |
+
"briard": 226,
|
| 1241 |
+
"broccoli": 937,
|
| 1242 |
+
"broom": 462,
|
| 1243 |
+
"brown bear, bruin, Ursus arctos": 294,
|
| 1244 |
+
"bubble": 971,
|
| 1245 |
+
"bucket, pail": 463,
|
| 1246 |
+
"buckeye, horse chestnut, conker": 990,
|
| 1247 |
+
"buckle": 464,
|
| 1248 |
+
"bulbul": 16,
|
| 1249 |
+
"bull mastiff": 243,
|
| 1250 |
+
"bullet train, bullet": 466,
|
| 1251 |
+
"bulletproof vest": 465,
|
| 1252 |
+
"bullfrog, Rana catesbeiana": 30,
|
| 1253 |
+
"burrito": 965,
|
| 1254 |
+
"bustard": 138,
|
| 1255 |
+
"butcher shop, meat market": 467,
|
| 1256 |
+
"butternut squash": 942,
|
| 1257 |
+
"cab, hack, taxi, taxicab": 468,
|
| 1258 |
+
"cabbage butterfly": 324,
|
| 1259 |
+
"cairn, cairn terrier": 192,
|
| 1260 |
+
"caldron, cauldron": 469,
|
| 1261 |
+
"can opener, tin opener": 473,
|
| 1262 |
+
"candle, taper, wax light": 470,
|
| 1263 |
+
"cannon": 471,
|
| 1264 |
+
"canoe": 472,
|
| 1265 |
+
"capuchin, ringtail, Cebus capucinus": 378,
|
| 1266 |
+
"car mirror": 475,
|
| 1267 |
+
"car wheel": 479,
|
| 1268 |
+
"carbonara": 959,
|
| 1269 |
+
"cardigan": 474,
|
| 1270 |
+
"cardoon": 946,
|
| 1271 |
+
"carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
|
| 1272 |
+
"carpenter's kit, tool kit": 477,
|
| 1273 |
+
"carton": 478,
|
| 1274 |
+
"cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
|
| 1275 |
+
"cassette": 481,
|
| 1276 |
+
"cassette player": 482,
|
| 1277 |
+
"castle": 483,
|
| 1278 |
+
"catamaran": 484,
|
| 1279 |
+
"cauliflower": 938,
|
| 1280 |
+
"cello, violoncello": 486,
|
| 1281 |
+
"cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
|
| 1282 |
+
"centipede": 79,
|
| 1283 |
+
"chain": 488,
|
| 1284 |
+
"chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
|
| 1285 |
+
"chain saw, chainsaw": 491,
|
| 1286 |
+
"chainlink fence": 489,
|
| 1287 |
+
"chambered nautilus, pearly nautilus, nautilus": 117,
|
| 1288 |
+
"cheeseburger": 933,
|
| 1289 |
+
"cheetah, chetah, Acinonyx jubatus": 293,
|
| 1290 |
+
"chest": 492,
|
| 1291 |
+
"chickadee": 19,
|
| 1292 |
+
"chiffonier, commode": 493,
|
| 1293 |
+
"chime, bell, gong": 494,
|
| 1294 |
+
"chimpanzee, chimp, Pan troglodytes": 367,
|
| 1295 |
+
"china cabinet, china closet": 495,
|
| 1296 |
+
"chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
|
| 1297 |
+
"chocolate sauce, chocolate syrup": 960,
|
| 1298 |
+
"chow, chow chow": 260,
|
| 1299 |
+
"church, church building": 497,
|
| 1300 |
+
"cicada, cicala": 316,
|
| 1301 |
+
"cinema, movie theater, movie theatre, movie house, picture palace": 498,
|
| 1302 |
+
"cleaver, meat cleaver, chopper": 499,
|
| 1303 |
+
"cliff dwelling": 500,
|
| 1304 |
+
"cliff, drop, drop-off": 972,
|
| 1305 |
+
"cloak": 501,
|
| 1306 |
+
"clog, geta, patten, sabot": 502,
|
| 1307 |
+
"clumber, clumber spaniel": 216,
|
| 1308 |
+
"cock": 7,
|
| 1309 |
+
"cocker spaniel, English cocker spaniel, cocker": 219,
|
| 1310 |
+
"cockroach, roach": 314,
|
| 1311 |
+
"cocktail shaker": 503,
|
| 1312 |
+
"coffee mug": 504,
|
| 1313 |
+
"coffeepot": 505,
|
| 1314 |
+
"coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
|
| 1315 |
+
"coil, spiral, volute, whorl, helix": 506,
|
| 1316 |
+
"collie": 231,
|
| 1317 |
+
"colobus, colobus monkey": 375,
|
| 1318 |
+
"combination lock": 507,
|
| 1319 |
+
"comic book": 917,
|
| 1320 |
+
"common iguana, iguana, Iguana iguana": 39,
|
| 1321 |
+
"common newt, Triturus vulgaris": 26,
|
| 1322 |
+
"computer keyboard, keypad": 508,
|
| 1323 |
+
"conch": 112,
|
| 1324 |
+
"confectionery, confectionary, candy store": 509,
|
| 1325 |
+
"consomme": 925,
|
| 1326 |
+
"container ship, containership, container vessel": 510,
|
| 1327 |
+
"convertible": 511,
|
| 1328 |
+
"coral fungus": 991,
|
| 1329 |
+
"coral reef": 973,
|
| 1330 |
+
"corkscrew, bottle screw": 512,
|
| 1331 |
+
"corn": 987,
|
| 1332 |
+
"cornet, horn, trumpet, trump": 513,
|
| 1333 |
+
"coucal": 91,
|
| 1334 |
+
"cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
|
| 1335 |
+
"cowboy boot": 514,
|
| 1336 |
+
"cowboy hat, ten-gallon hat": 515,
|
| 1337 |
+
"coyote, prairie wolf, brush wolf, Canis latrans": 272,
|
| 1338 |
+
"cradle": 516,
|
| 1339 |
+
"crane": 517,
|
| 1340 |
+
"crash helmet": 518,
|
| 1341 |
+
"crate": 519,
|
| 1342 |
+
"crayfish, crawfish, crawdad, crawdaddy": 124,
|
| 1343 |
+
"crib, cot": 520,
|
| 1344 |
+
"cricket": 312,
|
| 1345 |
+
"croquet ball": 522,
|
| 1346 |
+
"crossword puzzle, crossword": 918,
|
| 1347 |
+
"crutch": 523,
|
| 1348 |
+
"cucumber, cuke": 943,
|
| 1349 |
+
"cuirass": 524,
|
| 1350 |
+
"cup": 968,
|
| 1351 |
+
"curly-coated retriever": 206,
|
| 1352 |
+
"custard apple": 956,
|
| 1353 |
+
"daisy": 985,
|
| 1354 |
+
"dalmatian, coach dog, carriage dog": 251,
|
| 1355 |
+
"dam, dike, dyke": 525,
|
| 1356 |
+
"damselfly": 320,
|
| 1357 |
+
"desk": 526,
|
| 1358 |
+
"desktop computer": 527,
|
| 1359 |
+
"dhole, Cuon alpinus": 274,
|
| 1360 |
+
"dial telephone, dial phone": 528,
|
| 1361 |
+
"diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
|
| 1362 |
+
"diaper, nappy, napkin": 529,
|
| 1363 |
+
"digital clock": 530,
|
| 1364 |
+
"digital watch": 531,
|
| 1365 |
+
"dingo, warrigal, warragal, Canis dingo": 273,
|
| 1366 |
+
"dining table, board": 532,
|
| 1367 |
+
"dishrag, dishcloth": 533,
|
| 1368 |
+
"dishwasher, dish washer, dishwashing machine": 534,
|
| 1369 |
+
"disk brake, disc brake": 535,
|
| 1370 |
+
"dock, dockage, docking facility": 536,
|
| 1371 |
+
"dogsled, dog sled, dog sleigh": 537,
|
| 1372 |
+
"dome": 538,
|
| 1373 |
+
"doormat, welcome mat": 539,
|
| 1374 |
+
"dough": 961,
|
| 1375 |
+
"dowitcher": 142,
|
| 1376 |
+
"dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
|
| 1377 |
+
"drake": 97,
|
| 1378 |
+
"drilling platform, offshore rig": 540,
|
| 1379 |
+
"drum, membranophone, tympan": 541,
|
| 1380 |
+
"drumstick": 542,
|
| 1381 |
+
"dugong, Dugong dugon": 149,
|
| 1382 |
+
"dumbbell": 543,
|
| 1383 |
+
"dung beetle": 305,
|
| 1384 |
+
"ear, spike, capitulum": 998,
|
| 1385 |
+
"earthstar": 995,
|
| 1386 |
+
"echidna, spiny anteater, anteater": 102,
|
| 1387 |
+
"eel": 390,
|
| 1388 |
+
"eft": 27,
|
| 1389 |
+
"eggnog": 969,
|
| 1390 |
+
"electric fan, blower": 545,
|
| 1391 |
+
"electric guitar": 546,
|
| 1392 |
+
"electric locomotive": 547,
|
| 1393 |
+
"electric ray, crampfish, numbfish, torpedo": 5,
|
| 1394 |
+
"entertainment center": 548,
|
| 1395 |
+
"envelope": 549,
|
| 1396 |
+
"espresso": 967,
|
| 1397 |
+
"espresso maker": 550,
|
| 1398 |
+
"face powder": 551,
|
| 1399 |
+
"feather boa, boa": 552,
|
| 1400 |
+
"fiddler crab": 120,
|
| 1401 |
+
"fig": 952,
|
| 1402 |
+
"file, file cabinet, filing cabinet": 553,
|
| 1403 |
+
"fire engine, fire truck": 555,
|
| 1404 |
+
"fire screen, fireguard": 556,
|
| 1405 |
+
"fireboat": 554,
|
| 1406 |
+
"flagpole, flagstaff": 557,
|
| 1407 |
+
"flamingo": 130,
|
| 1408 |
+
"flat-coated retriever": 205,
|
| 1409 |
+
"flatworm, platyhelminth": 110,
|
| 1410 |
+
"flute, transverse flute": 558,
|
| 1411 |
+
"fly": 308,
|
| 1412 |
+
"folding chair": 559,
|
| 1413 |
+
"football helmet": 560,
|
| 1414 |
+
"forklift": 561,
|
| 1415 |
+
"fountain": 562,
|
| 1416 |
+
"fountain pen": 563,
|
| 1417 |
+
"four-poster": 564,
|
| 1418 |
+
"fox squirrel, eastern fox squirrel, Sciurus niger": 335,
|
| 1419 |
+
"freight car": 565,
|
| 1420 |
+
"frilled lizard, Chlamydosaurus kingi": 43,
|
| 1421 |
+
"frying pan, frypan, skillet": 567,
|
| 1422 |
+
"fur coat": 568,
|
| 1423 |
+
"gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
|
| 1424 |
+
"garbage truck, dustcart": 569,
|
| 1425 |
+
"garden spider, Aranea diademata": 74,
|
| 1426 |
+
"garter snake, grass snake": 57,
|
| 1427 |
+
"gas pump, gasoline pump, petrol pump, island dispenser": 571,
|
| 1428 |
+
"gasmask, respirator, gas helmet": 570,
|
| 1429 |
+
"gazelle": 353,
|
| 1430 |
+
"geyser": 974,
|
| 1431 |
+
"giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
|
| 1432 |
+
"giant schnauzer": 197,
|
| 1433 |
+
"gibbon, Hylobates lar": 368,
|
| 1434 |
+
"go-kart": 573,
|
| 1435 |
+
"goblet": 572,
|
| 1436 |
+
"golden retriever": 207,
|
| 1437 |
+
"goldfinch, Carduelis carduelis": 11,
|
| 1438 |
+
"goldfish, Carassius auratus": 1,
|
| 1439 |
+
"golf ball": 574,
|
| 1440 |
+
"golfcart, golf cart": 575,
|
| 1441 |
+
"gondola": 576,
|
| 1442 |
+
"gong, tam-tam": 577,
|
| 1443 |
+
"goose": 99,
|
| 1444 |
+
"gorilla, Gorilla gorilla": 366,
|
| 1445 |
+
"gown": 578,
|
| 1446 |
+
"grand piano, grand": 579,
|
| 1447 |
+
"grasshopper, hopper": 311,
|
| 1448 |
+
"great grey owl, great gray owl, Strix nebulosa": 24,
|
| 1449 |
+
"great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
|
| 1450 |
+
"green lizard, Lacerta viridis": 46,
|
| 1451 |
+
"green mamba": 64,
|
| 1452 |
+
"green snake, grass snake": 55,
|
| 1453 |
+
"greenhouse, nursery, glasshouse": 580,
|
| 1454 |
+
"grey fox, gray fox, Urocyon cinereoargenteus": 280,
|
| 1455 |
+
"grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
|
| 1456 |
+
"grille, radiator grille": 581,
|
| 1457 |
+
"grocery store, grocery, food market, market": 582,
|
| 1458 |
+
"groenendael": 224,
|
| 1459 |
+
"groom, bridegroom": 982,
|
| 1460 |
+
"ground beetle, carabid beetle": 302,
|
| 1461 |
+
"guacamole": 924,
|
| 1462 |
+
"guenon, guenon monkey": 370,
|
| 1463 |
+
"guillotine": 583,
|
| 1464 |
+
"guinea pig, Cavia cobaya": 338,
|
| 1465 |
+
"gyromitra": 993,
|
| 1466 |
+
"hair slide": 584,
|
| 1467 |
+
"hair spray": 585,
|
| 1468 |
+
"half track": 586,
|
| 1469 |
+
"hammer": 587,
|
| 1470 |
+
"hammerhead, hammerhead shark": 4,
|
| 1471 |
+
"hamper": 588,
|
| 1472 |
+
"hamster": 333,
|
| 1473 |
+
"hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
|
| 1474 |
+
"hand-held computer, hand-held microcomputer": 590,
|
| 1475 |
+
"handkerchief, hankie, hanky, hankey": 591,
|
| 1476 |
+
"hard disc, hard disk, fixed disk": 592,
|
| 1477 |
+
"hare": 331,
|
| 1478 |
+
"harmonica, mouth organ, harp, mouth harp": 593,
|
| 1479 |
+
"harp": 594,
|
| 1480 |
+
"hartebeest": 351,
|
| 1481 |
+
"harvester, reaper": 595,
|
| 1482 |
+
"harvestman, daddy longlegs, Phalangium opilio": 70,
|
| 1483 |
+
"hatchet": 596,
|
| 1484 |
+
"hay": 958,
|
| 1485 |
+
"head cabbage": 936,
|
| 1486 |
+
"hen": 8,
|
| 1487 |
+
"hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
|
| 1488 |
+
"hermit crab": 125,
|
| 1489 |
+
"hip, rose hip, rosehip": 989,
|
| 1490 |
+
"hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
|
| 1491 |
+
"hog, pig, grunter, squealer, Sus scrofa": 341,
|
| 1492 |
+
"hognose snake, puff adder, sand viper": 54,
|
| 1493 |
+
"holster": 597,
|
| 1494 |
+
"home theater, home theatre": 598,
|
| 1495 |
+
"honeycomb": 599,
|
| 1496 |
+
"hook, claw": 600,
|
| 1497 |
+
"hoopskirt, crinoline": 601,
|
| 1498 |
+
"horizontal bar, high bar": 602,
|
| 1499 |
+
"hornbill": 93,
|
| 1500 |
+
"horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
|
| 1501 |
+
"horse cart, horse-cart": 603,
|
| 1502 |
+
"hot pot, hotpot": 926,
|
| 1503 |
+
"hotdog, hot dog, red hot": 934,
|
| 1504 |
+
"hourglass": 604,
|
| 1505 |
+
"house finch, linnet, Carpodacus mexicanus": 12,
|
| 1506 |
+
"howler monkey, howler": 379,
|
| 1507 |
+
"hummingbird": 94,
|
| 1508 |
+
"hyena, hyaena": 276,
|
| 1509 |
+
"iPod": 605,
|
| 1510 |
+
"ibex, Capra ibex": 350,
|
| 1511 |
+
"ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
|
| 1512 |
+
"ice cream, icecream": 928,
|
| 1513 |
+
"ice lolly, lolly, lollipop, popsicle": 929,
|
| 1514 |
+
"impala, Aepyceros melampus": 352,
|
| 1515 |
+
"indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
|
| 1516 |
+
"indri, indris, Indri indri, Indri brevicaudatus": 384,
|
| 1517 |
+
"iron, smoothing iron": 606,
|
| 1518 |
+
"isopod": 126,
|
| 1519 |
+
"jacamar": 95,
|
| 1520 |
+
"jack-o'-lantern": 607,
|
| 1521 |
+
"jackfruit, jak, jack": 955,
|
| 1522 |
+
"jaguar, panther, Panthera onca, Felis onca": 290,
|
| 1523 |
+
"jay": 17,
|
| 1524 |
+
"jean, blue jean, denim": 608,
|
| 1525 |
+
"jeep, landrover": 609,
|
| 1526 |
+
"jellyfish": 107,
|
| 1527 |
+
"jersey, T-shirt, tee shirt": 610,
|
| 1528 |
+
"jigsaw puzzle": 611,
|
| 1529 |
+
"jinrikisha, ricksha, rickshaw": 612,
|
| 1530 |
+
"joystick": 613,
|
| 1531 |
+
"junco, snowbird": 13,
|
| 1532 |
+
"keeshond": 261,
|
| 1533 |
+
"kelpie": 227,
|
| 1534 |
+
"killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
|
| 1535 |
+
"kimono": 614,
|
| 1536 |
+
"king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
|
| 1537 |
+
"king penguin, Aptenodytes patagonica": 145,
|
| 1538 |
+
"king snake, kingsnake": 56,
|
| 1539 |
+
"kit fox, Vulpes macrotis": 278,
|
| 1540 |
+
"kite": 21,
|
| 1541 |
+
"knee pad": 615,
|
| 1542 |
+
"knot": 616,
|
| 1543 |
+
"koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
|
| 1544 |
+
"komondor": 228,
|
| 1545 |
+
"kuvasz": 222,
|
| 1546 |
+
"lab coat, laboratory coat": 617,
|
| 1547 |
+
"lacewing, lacewing fly": 318,
|
| 1548 |
+
"ladle": 618,
|
| 1549 |
+
"ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
|
| 1550 |
+
"lakeside, lakeshore": 975,
|
| 1551 |
+
"lampshade, lamp shade": 619,
|
| 1552 |
+
"langur": 374,
|
| 1553 |
+
"laptop, laptop computer": 620,
|
| 1554 |
+
"lawn mower, mower": 621,
|
| 1555 |
+
"leaf beetle, chrysomelid": 304,
|
| 1556 |
+
"leafhopper": 317,
|
| 1557 |
+
"leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
|
| 1558 |
+
"lemon": 951,
|
| 1559 |
+
"lens cap, lens cover": 622,
|
| 1560 |
+
"leopard, Panthera pardus": 288,
|
| 1561 |
+
"lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
|
| 1562 |
+
"letter opener, paper knife, paperknife": 623,
|
| 1563 |
+
"library": 624,
|
| 1564 |
+
"lifeboat": 625,
|
| 1565 |
+
"lighter, light, igniter, ignitor": 626,
|
| 1566 |
+
"limousine, limo": 627,
|
| 1567 |
+
"limpkin, Aramus pictus": 135,
|
| 1568 |
+
"liner, ocean liner": 628,
|
| 1569 |
+
"lion, king of beasts, Panthera leo": 291,
|
| 1570 |
+
"lionfish": 396,
|
| 1571 |
+
"lipstick, lip rouge": 629,
|
| 1572 |
+
"little blue heron, Egretta caerulea": 131,
|
| 1573 |
+
"llama": 355,
|
| 1574 |
+
"loggerhead, loggerhead turtle, Caretta caretta": 33,
|
| 1575 |
+
"long-horned beetle, longicorn, longicorn beetle": 303,
|
| 1576 |
+
"lorikeet": 90,
|
| 1577 |
+
"lotion": 631,
|
| 1578 |
+
"loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
|
| 1579 |
+
"loupe, jeweler's loupe": 633,
|
| 1580 |
+
"lumbermill, sawmill": 634,
|
| 1581 |
+
"lycaenid, lycaenid butterfly": 326,
|
| 1582 |
+
"lynx, catamount": 287,
|
| 1583 |
+
"macaque": 373,
|
| 1584 |
+
"macaw": 88,
|
| 1585 |
+
"magnetic compass": 635,
|
| 1586 |
+
"magpie": 18,
|
| 1587 |
+
"mailbag, postbag": 636,
|
| 1588 |
+
"mailbox, letter box": 637,
|
| 1589 |
+
"maillot": 638,
|
| 1590 |
+
"maillot, tank suit": 639,
|
| 1591 |
+
"malamute, malemute, Alaskan malamute": 249,
|
| 1592 |
+
"malinois": 225,
|
| 1593 |
+
"manhole cover": 640,
|
| 1594 |
+
"mantis, mantid": 315,
|
| 1595 |
+
"maraca": 641,
|
| 1596 |
+
"marimba, xylophone": 642,
|
| 1597 |
+
"marmoset": 377,
|
| 1598 |
+
"marmot": 336,
|
| 1599 |
+
"mashed potato": 935,
|
| 1600 |
+
"mask": 643,
|
| 1601 |
+
"matchstick": 644,
|
| 1602 |
+
"maypole": 645,
|
| 1603 |
+
"maze, labyrinth": 646,
|
| 1604 |
+
"measuring cup": 647,
|
| 1605 |
+
"meat loaf, meatloaf": 962,
|
| 1606 |
+
"medicine chest, medicine cabinet": 648,
|
| 1607 |
+
"meerkat, mierkat": 299,
|
| 1608 |
+
"megalith, megalithic structure": 649,
|
| 1609 |
+
"menu": 922,
|
| 1610 |
+
"microphone, mike": 650,
|
| 1611 |
+
"microwave, microwave oven": 651,
|
| 1612 |
+
"military uniform": 652,
|
| 1613 |
+
"milk can": 653,
|
| 1614 |
+
"miniature pinscher": 237,
|
| 1615 |
+
"miniature poodle": 266,
|
| 1616 |
+
"miniature schnauzer": 196,
|
| 1617 |
+
"minibus": 654,
|
| 1618 |
+
"miniskirt, mini": 655,
|
| 1619 |
+
"minivan": 656,
|
| 1620 |
+
"mink": 357,
|
| 1621 |
+
"missile": 657,
|
| 1622 |
+
"mitten": 658,
|
| 1623 |
+
"mixing bowl": 659,
|
| 1624 |
+
"mobile home, manufactured home": 660,
|
| 1625 |
+
"modem": 662,
|
| 1626 |
+
"monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
|
| 1627 |
+
"monastery": 663,
|
| 1628 |
+
"mongoose": 298,
|
| 1629 |
+
"monitor": 664,
|
| 1630 |
+
"moped": 665,
|
| 1631 |
+
"mortar": 666,
|
| 1632 |
+
"mortarboard": 667,
|
| 1633 |
+
"mosque": 668,
|
| 1634 |
+
"mosquito net": 669,
|
| 1635 |
+
"motor scooter, scooter": 670,
|
| 1636 |
+
"mountain bike, all-terrain bike, off-roader": 671,
|
| 1637 |
+
"mountain tent": 672,
|
| 1638 |
+
"mouse, computer mouse": 673,
|
| 1639 |
+
"mousetrap": 674,
|
| 1640 |
+
"moving van": 675,
|
| 1641 |
+
"mud turtle": 35,
|
| 1642 |
+
"mushroom": 947,
|
| 1643 |
+
"muzzle": 676,
|
| 1644 |
+
"nail": 677,
|
| 1645 |
+
"neck brace": 678,
|
| 1646 |
+
"necklace": 679,
|
| 1647 |
+
"nematode, nematode worm, roundworm": 111,
|
| 1648 |
+
"night snake, Hypsiglena torquata": 60,
|
| 1649 |
+
"nipple": 680,
|
| 1650 |
+
"notebook, notebook computer": 681,
|
| 1651 |
+
"obelisk": 682,
|
| 1652 |
+
"oboe, hautboy, hautbois": 683,
|
| 1653 |
+
"ocarina, sweet potato": 684,
|
| 1654 |
+
"odometer, hodometer, mileometer, milometer": 685,
|
| 1655 |
+
"oil filter": 686,
|
| 1656 |
+
"orange": 950,
|
| 1657 |
+
"orangutan, orang, orangutang, Pongo pygmaeus": 365,
|
| 1658 |
+
"organ, pipe organ": 687,
|
| 1659 |
+
"oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
|
| 1660 |
+
"ostrich, Struthio camelus": 9,
|
| 1661 |
+
"otter": 360,
|
| 1662 |
+
"otterhound, otter hound": 175,
|
| 1663 |
+
"overskirt": 689,
|
| 1664 |
+
"ox": 345,
|
| 1665 |
+
"oxcart": 690,
|
| 1666 |
+
"oxygen mask": 691,
|
| 1667 |
+
"oystercatcher, oyster catcher": 143,
|
| 1668 |
+
"packet": 692,
|
| 1669 |
+
"paddle, boat paddle": 693,
|
| 1670 |
+
"paddlewheel, paddle wheel": 694,
|
| 1671 |
+
"padlock": 695,
|
| 1672 |
+
"paintbrush": 696,
|
| 1673 |
+
"pajama, pyjama, pj's, jammies": 697,
|
| 1674 |
+
"palace": 698,
|
| 1675 |
+
"panpipe, pandean pipe, syrinx": 699,
|
| 1676 |
+
"paper towel": 700,
|
| 1677 |
+
"papillon": 157,
|
| 1678 |
+
"parachute, chute": 701,
|
| 1679 |
+
"parallel bars, bars": 702,
|
| 1680 |
+
"park bench": 703,
|
| 1681 |
+
"parking meter": 704,
|
| 1682 |
+
"partridge": 86,
|
| 1683 |
+
"passenger car, coach, carriage": 705,
|
| 1684 |
+
"patas, hussar monkey, Erythrocebus patas": 371,
|
| 1685 |
+
"patio, terrace": 706,
|
| 1686 |
+
"pay-phone, pay-station": 707,
|
| 1687 |
+
"peacock": 84,
|
| 1688 |
+
"pedestal, plinth, footstall": 708,
|
| 1689 |
+
"pelican": 144,
|
| 1690 |
+
"pencil box, pencil case": 709,
|
| 1691 |
+
"pencil sharpener": 710,
|
| 1692 |
+
"perfume, essence": 711,
|
| 1693 |
+
"photocopier": 713,
|
| 1694 |
+
"pick, plectrum, plectron": 714,
|
| 1695 |
+
"pickelhaube": 715,
|
| 1696 |
+
"picket fence, paling": 716,
|
| 1697 |
+
"pickup, pickup truck": 717,
|
| 1698 |
+
"pier": 718,
|
| 1699 |
+
"piggy bank, penny bank": 719,
|
| 1700 |
+
"pill bottle": 720,
|
| 1701 |
+
"pillow": 721,
|
| 1702 |
+
"pineapple, ananas": 953,
|
| 1703 |
+
"ping-pong ball": 722,
|
| 1704 |
+
"pinwheel": 723,
|
| 1705 |
+
"pirate, pirate ship": 724,
|
| 1706 |
+
"pitcher, ewer": 725,
|
| 1707 |
+
"pizza, pizza pie": 963,
|
| 1708 |
+
"plane, carpenter's plane, woodworking plane": 726,
|
| 1709 |
+
"planetarium": 727,
|
| 1710 |
+
"plastic bag": 728,
|
| 1711 |
+
"plate": 923,
|
| 1712 |
+
"plate rack": 729,
|
| 1713 |
+
"platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
|
| 1714 |
+
"plow, plough": 730,
|
| 1715 |
+
"plunger, plumber's helper": 731,
|
| 1716 |
+
"pole": 733,
|
| 1717 |
+
"polecat, fitch, foulmart, foumart, Mustela putorius": 358,
|
| 1718 |
+
"police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
|
| 1719 |
+
"pomegranate": 957,
|
| 1720 |
+
"poncho": 735,
|
| 1721 |
+
"pool table, billiard table, snooker table": 736,
|
| 1722 |
+
"pop bottle, soda bottle": 737,
|
| 1723 |
+
"porcupine, hedgehog": 334,
|
| 1724 |
+
"pot, flowerpot": 738,
|
| 1725 |
+
"potpie": 964,
|
| 1726 |
+
"potter's wheel": 739,
|
| 1727 |
+
"power drill": 740,
|
| 1728 |
+
"prairie chicken, prairie grouse, prairie fowl": 83,
|
| 1729 |
+
"prayer rug, prayer mat": 741,
|
| 1730 |
+
"pretzel": 932,
|
| 1731 |
+
"printer": 742,
|
| 1732 |
+
"prison, prison house": 743,
|
| 1733 |
+
"proboscis monkey, Nasalis larvatus": 376,
|
| 1734 |
+
"projectile, missile": 744,
|
| 1735 |
+
"projector": 745,
|
| 1736 |
+
"promontory, headland, head, foreland": 976,
|
| 1737 |
+
"ptarmigan": 81,
|
| 1738 |
+
"puck, hockey puck": 746,
|
| 1739 |
+
"puffer, pufferfish, blowfish, globefish": 397,
|
| 1740 |
+
"pug, pug-dog": 254,
|
| 1741 |
+
"punching bag, punch bag, punching ball, punchball": 747,
|
| 1742 |
+
"purse": 748,
|
| 1743 |
+
"quail": 85,
|
| 1744 |
+
"quill, quill pen": 749,
|
| 1745 |
+
"quilt, comforter, comfort, puff": 750,
|
| 1746 |
+
"racer, race car, racing car": 751,
|
| 1747 |
+
"racket, racquet": 752,
|
| 1748 |
+
"radiator": 753,
|
| 1749 |
+
"radio telescope, radio reflector": 755,
|
| 1750 |
+
"radio, wireless": 754,
|
| 1751 |
+
"rain barrel": 756,
|
| 1752 |
+
"ram, tup": 348,
|
| 1753 |
+
"rapeseed": 984,
|
| 1754 |
+
"recreational vehicle, RV, R.V.": 757,
|
| 1755 |
+
"red fox, Vulpes vulpes": 277,
|
| 1756 |
+
"red wine": 966,
|
| 1757 |
+
"red wolf, maned wolf, Canis rufus, Canis niger": 271,
|
| 1758 |
+
"red-backed sandpiper, dunlin, Erolia alpina": 140,
|
| 1759 |
+
"red-breasted merganser, Mergus serrator": 98,
|
| 1760 |
+
"redbone": 168,
|
| 1761 |
+
"redshank, Tringa totanus": 141,
|
| 1762 |
+
"reel": 758,
|
| 1763 |
+
"reflex camera": 759,
|
| 1764 |
+
"refrigerator, icebox": 760,
|
| 1765 |
+
"remote control, remote": 761,
|
| 1766 |
+
"restaurant, eating house, eating place, eatery": 762,
|
| 1767 |
+
"revolver, six-gun, six-shooter": 763,
|
| 1768 |
+
"rhinoceros beetle": 306,
|
| 1769 |
+
"rifle": 764,
|
| 1770 |
+
"ringlet, ringlet butterfly": 322,
|
| 1771 |
+
"ringneck snake, ring-necked snake, ring snake": 53,
|
| 1772 |
+
"robin, American robin, Turdus migratorius": 15,
|
| 1773 |
+
"rock beauty, Holocanthus tricolor": 392,
|
| 1774 |
+
"rock crab, Cancer irroratus": 119,
|
| 1775 |
+
"rock python, rock snake, Python sebae": 62,
|
| 1776 |
+
"rocking chair, rocker": 765,
|
| 1777 |
+
"rotisserie": 766,
|
| 1778 |
+
"rubber eraser, rubber, pencil eraser": 767,
|
| 1779 |
+
"ruddy turnstone, Arenaria interpres": 139,
|
| 1780 |
+
"ruffed grouse, partridge, Bonasa umbellus": 82,
|
| 1781 |
+
"rugby ball": 768,
|
| 1782 |
+
"rule, ruler": 769,
|
| 1783 |
+
"running shoe": 770,
|
| 1784 |
+
"safe": 771,
|
| 1785 |
+
"safety pin": 772,
|
| 1786 |
+
"saltshaker, salt shaker": 773,
|
| 1787 |
+
"sandal": 774,
|
| 1788 |
+
"sandbar, sand bar": 977,
|
| 1789 |
+
"sarong": 775,
|
| 1790 |
+
"sax, saxophone": 776,
|
| 1791 |
+
"scabbard": 777,
|
| 1792 |
+
"scale, weighing machine": 778,
|
| 1793 |
+
"schipperke": 223,
|
| 1794 |
+
"school bus": 779,
|
| 1795 |
+
"schooner": 780,
|
| 1796 |
+
"scoreboard": 781,
|
| 1797 |
+
"scorpion": 71,
|
| 1798 |
+
"screen, CRT screen": 782,
|
| 1799 |
+
"screw": 783,
|
| 1800 |
+
"screwdriver": 784,
|
| 1801 |
+
"scuba diver": 983,
|
| 1802 |
+
"sea anemone, anemone": 108,
|
| 1803 |
+
"sea cucumber, holothurian": 329,
|
| 1804 |
+
"sea lion": 150,
|
| 1805 |
+
"sea slug, nudibranch": 115,
|
| 1806 |
+
"sea snake": 65,
|
| 1807 |
+
"sea urchin": 328,
|
| 1808 |
+
"seashore, coast, seacoast, sea-coast": 978,
|
| 1809 |
+
"seat belt, seatbelt": 785,
|
| 1810 |
+
"sewing machine": 786,
|
| 1811 |
+
"shield, buckler": 787,
|
| 1812 |
+
"shoe shop, shoe-shop, shoe store": 788,
|
| 1813 |
+
"shoji": 789,
|
| 1814 |
+
"shopping basket": 790,
|
| 1815 |
+
"shopping cart": 791,
|
| 1816 |
+
"shovel": 792,
|
| 1817 |
+
"shower cap": 793,
|
| 1818 |
+
"shower curtain": 794,
|
| 1819 |
+
"siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
|
| 1820 |
+
"sidewinder, horned rattlesnake, Crotalus cerastes": 68,
|
| 1821 |
+
"silky terrier, Sydney silky": 201,
|
| 1822 |
+
"ski": 795,
|
| 1823 |
+
"ski mask": 796,
|
| 1824 |
+
"skunk, polecat, wood pussy": 361,
|
| 1825 |
+
"sleeping bag": 797,
|
| 1826 |
+
"slide rule, slipstick": 798,
|
| 1827 |
+
"sliding door": 799,
|
| 1828 |
+
"slot, one-armed bandit": 800,
|
| 1829 |
+
"sloth bear, Melursus ursinus, Ursus ursinus": 297,
|
| 1830 |
+
"slug": 114,
|
| 1831 |
+
"snail": 113,
|
| 1832 |
+
"snorkel": 801,
|
| 1833 |
+
"snow leopard, ounce, Panthera uncia": 289,
|
| 1834 |
+
"snowmobile": 802,
|
| 1835 |
+
"snowplow, snowplough": 803,
|
| 1836 |
+
"soap dispenser": 804,
|
| 1837 |
+
"soccer ball": 805,
|
| 1838 |
+
"sock": 806,
|
| 1839 |
+
"soft-coated wheaten terrier": 202,
|
| 1840 |
+
"solar dish, solar collector, solar furnace": 807,
|
| 1841 |
+
"sombrero": 808,
|
| 1842 |
+
"sorrel": 339,
|
| 1843 |
+
"soup bowl": 809,
|
| 1844 |
+
"space bar": 810,
|
| 1845 |
+
"space heater": 811,
|
| 1846 |
+
"space shuttle": 812,
|
| 1847 |
+
"spaghetti squash": 940,
|
| 1848 |
+
"spatula": 813,
|
| 1849 |
+
"speedboat": 814,
|
| 1850 |
+
"spider monkey, Ateles geoffroyi": 381,
|
| 1851 |
+
"spider web, spider's web": 815,
|
| 1852 |
+
"spindle": 816,
|
| 1853 |
+
"spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
|
| 1854 |
+
"spoonbill": 129,
|
| 1855 |
+
"sports car, sport car": 817,
|
| 1856 |
+
"spotlight, spot": 818,
|
| 1857 |
+
"spotted salamander, Ambystoma maculatum": 28,
|
| 1858 |
+
"squirrel monkey, Saimiri sciureus": 382,
|
| 1859 |
+
"stage": 819,
|
| 1860 |
+
"standard poodle": 267,
|
| 1861 |
+
"standard schnauzer": 198,
|
| 1862 |
+
"starfish, sea star": 327,
|
| 1863 |
+
"steam locomotive": 820,
|
| 1864 |
+
"steel arch bridge": 821,
|
| 1865 |
+
"steel drum": 822,
|
| 1866 |
+
"stethoscope": 823,
|
| 1867 |
+
"stingray": 6,
|
| 1868 |
+
"stinkhorn, carrion fungus": 994,
|
| 1869 |
+
"stole": 824,
|
| 1870 |
+
"stone wall": 825,
|
| 1871 |
+
"stopwatch, stop watch": 826,
|
| 1872 |
+
"stove": 827,
|
| 1873 |
+
"strainer": 828,
|
| 1874 |
+
"strawberry": 949,
|
| 1875 |
+
"street sign": 919,
|
| 1876 |
+
"streetcar, tram, tramcar, trolley, trolley car": 829,
|
| 1877 |
+
"stretcher": 830,
|
| 1878 |
+
"studio couch, day bed": 831,
|
| 1879 |
+
"stupa, tope": 832,
|
| 1880 |
+
"sturgeon": 394,
|
| 1881 |
+
"submarine, pigboat, sub, U-boat": 833,
|
| 1882 |
+
"suit, suit of clothes": 834,
|
| 1883 |
+
"sulphur butterfly, sulfur butterfly": 325,
|
| 1884 |
+
"sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
|
| 1885 |
+
"sundial": 835,
|
| 1886 |
+
"sunglass": 836,
|
| 1887 |
+
"sunglasses, dark glasses, shades": 837,
|
| 1888 |
+
"sunscreen, sunblock, sun blocker": 838,
|
| 1889 |
+
"suspension bridge": 839,
|
| 1890 |
+
"swab, swob, mop": 840,
|
| 1891 |
+
"sweatshirt": 841,
|
| 1892 |
+
"swimming trunks, bathing trunks": 842,
|
| 1893 |
+
"swing": 843,
|
| 1894 |
+
"switch, electric switch, electrical switch": 844,
|
| 1895 |
+
"syringe": 845,
|
| 1896 |
+
"tabby, tabby cat": 281,
|
| 1897 |
+
"table lamp": 846,
|
| 1898 |
+
"tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
|
| 1899 |
+
"tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
|
| 1900 |
+
"tape player": 848,
|
| 1901 |
+
"tarantula": 76,
|
| 1902 |
+
"teapot": 849,
|
| 1903 |
+
"teddy, teddy bear": 850,
|
| 1904 |
+
"television, television system": 851,
|
| 1905 |
+
"tench, Tinca tinca": 0,
|
| 1906 |
+
"tennis ball": 852,
|
| 1907 |
+
"terrapin": 36,
|
| 1908 |
+
"thatch, thatched roof": 853,
|
| 1909 |
+
"theater curtain, theatre curtain": 854,
|
| 1910 |
+
"thimble": 855,
|
| 1911 |
+
"three-toed sloth, ai, Bradypus tridactylus": 364,
|
| 1912 |
+
"thresher, thrasher, threshing machine": 856,
|
| 1913 |
+
"throne": 857,
|
| 1914 |
+
"thunder snake, worm snake, Carphophis amoenus": 52,
|
| 1915 |
+
"tick": 78,
|
| 1916 |
+
"tiger beetle": 300,
|
| 1917 |
+
"tiger cat": 282,
|
| 1918 |
+
"tiger shark, Galeocerdo cuvieri": 3,
|
| 1919 |
+
"tiger, Panthera tigris": 292,
|
| 1920 |
+
"tile roof": 858,
|
| 1921 |
+
"timber wolf, grey wolf, gray wolf, Canis lupus": 269,
|
| 1922 |
+
"titi, titi monkey": 380,
|
| 1923 |
+
"toaster": 859,
|
| 1924 |
+
"tobacco shop, tobacconist shop, tobacconist": 860,
|
| 1925 |
+
"toilet seat": 861,
|
| 1926 |
+
"toilet tissue, toilet paper, bathroom tissue": 999,
|
| 1927 |
+
"torch": 862,
|
| 1928 |
+
"totem pole": 863,
|
| 1929 |
+
"toucan": 96,
|
| 1930 |
+
"tow truck, tow car, wrecker": 864,
|
| 1931 |
+
"toy poodle": 265,
|
| 1932 |
+
"toy terrier": 158,
|
| 1933 |
+
"toyshop": 865,
|
| 1934 |
+
"tractor": 866,
|
| 1935 |
+
"traffic light, traffic signal, stoplight": 920,
|
| 1936 |
+
"trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
|
| 1937 |
+
"tray": 868,
|
| 1938 |
+
"tree frog, tree-frog": 31,
|
| 1939 |
+
"trench coat": 869,
|
| 1940 |
+
"triceratops": 51,
|
| 1941 |
+
"tricycle, trike, velocipede": 870,
|
| 1942 |
+
"trifle": 927,
|
| 1943 |
+
"trilobite": 69,
|
| 1944 |
+
"trimaran": 871,
|
| 1945 |
+
"tripod": 872,
|
| 1946 |
+
"triumphal arch": 873,
|
| 1947 |
+
"trolleybus, trolley coach, trackless trolley": 874,
|
| 1948 |
+
"trombone": 875,
|
| 1949 |
+
"tub, vat": 876,
|
| 1950 |
+
"turnstile": 877,
|
| 1951 |
+
"tusker": 101,
|
| 1952 |
+
"typewriter keyboard": 878,
|
| 1953 |
+
"umbrella": 879,
|
| 1954 |
+
"unicycle, monocycle": 880,
|
| 1955 |
+
"upright, upright piano": 881,
|
| 1956 |
+
"vacuum, vacuum cleaner": 882,
|
| 1957 |
+
"valley, vale": 979,
|
| 1958 |
+
"vase": 883,
|
| 1959 |
+
"vault": 884,
|
| 1960 |
+
"velvet": 885,
|
| 1961 |
+
"vending machine": 886,
|
| 1962 |
+
"vestment": 887,
|
| 1963 |
+
"viaduct": 888,
|
| 1964 |
+
"vine snake": 59,
|
| 1965 |
+
"violin, fiddle": 889,
|
| 1966 |
+
"vizsla, Hungarian pointer": 211,
|
| 1967 |
+
"volcano": 980,
|
| 1968 |
+
"volleyball": 890,
|
| 1969 |
+
"vulture": 23,
|
| 1970 |
+
"waffle iron": 891,
|
| 1971 |
+
"walking stick, walkingstick, stick insect": 313,
|
| 1972 |
+
"wall clock": 892,
|
| 1973 |
+
"wallaby, brush kangaroo": 104,
|
| 1974 |
+
"wallet, billfold, notecase, pocketbook": 893,
|
| 1975 |
+
"wardrobe, closet, press": 894,
|
| 1976 |
+
"warplane, military plane": 895,
|
| 1977 |
+
"warthog": 343,
|
| 1978 |
+
"washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
|
| 1979 |
+
"washer, automatic washer, washing machine": 897,
|
| 1980 |
+
"water bottle": 898,
|
| 1981 |
+
"water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
|
| 1982 |
+
"water jug": 899,
|
| 1983 |
+
"water ouzel, dipper": 20,
|
| 1984 |
+
"water snake": 58,
|
| 1985 |
+
"water tower": 900,
|
| 1986 |
+
"weasel": 356,
|
| 1987 |
+
"web site, website, internet site, site": 916,
|
| 1988 |
+
"weevil": 307,
|
| 1989 |
+
"whippet": 172,
|
| 1990 |
+
"whiptail, whiptail lizard": 41,
|
| 1991 |
+
"whiskey jug": 901,
|
| 1992 |
+
"whistle": 902,
|
| 1993 |
+
"white stork, Ciconia ciconia": 127,
|
| 1994 |
+
"white wolf, Arctic wolf, Canis lupus tundrarum": 270,
|
| 1995 |
+
"wig": 903,
|
| 1996 |
+
"wild boar, boar, Sus scrofa": 342,
|
| 1997 |
+
"window screen": 904,
|
| 1998 |
+
"window shade": 905,
|
| 1999 |
+
"wine bottle": 907,
|
| 2000 |
+
"wing": 908,
|
| 2001 |
+
"wire-haired fox terrier": 188,
|
| 2002 |
+
"wok": 909,
|
| 2003 |
+
"wolf spider, hunting spider": 77,
|
| 2004 |
+
"wombat": 106,
|
| 2005 |
+
"wood rabbit, cottontail, cottontail rabbit": 330,
|
| 2006 |
+
"wooden spoon": 910,
|
| 2007 |
+
"wool, woolen, woollen": 911,
|
| 2008 |
+
"worm fence, snake fence, snake-rail fence, Virginia fence": 912,
|
| 2009 |
+
"wreck": 913,
|
| 2010 |
+
"yawl": 914,
|
| 2011 |
+
"yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
|
| 2012 |
+
"yurt": 915,
|
| 2013 |
+
"zebra": 340,
|
| 2014 |
+
"zucchini, courgette": 939
|
| 2015 |
+
},
|
| 2016 |
+
"layer_norm_eps": 1e-12,
|
| 2017 |
+
"model_type": "vit",
|
| 2018 |
+
"num_attention_heads": 12,
|
| 2019 |
+
"num_channels": 3,
|
| 2020 |
+
"num_hidden_layers": 12,
|
| 2021 |
+
"patch_size": 16,
|
| 2022 |
+
"pooler_act": "tanh",
|
| 2023 |
+
"pooler_output_size": 768,
|
| 2024 |
+
"qkv_bias": true,
|
| 2025 |
+
"transformers_version": "4.57.1"
|
| 2026 |
+
}
|
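The `id2label` and `label2id` blocks above are mutually inverse ImageNet-1k class maps; `transformers` uses them to turn logit indices back into class names at inference. A minimal sketch, assuming a local copy of this repo at `./checkpoint` (hypothetical path) and a sample image `cat.jpg`; everything else is standard `transformers`/PIL API:

```python
# Sketch: decoding an image-classification prediction with the id2label map
# from the config above. `./checkpoint` and `cat.jpg` are assumed placeholders.
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

proc = AutoImageProcessor.from_pretrained("./checkpoint")
model = AutoModelForImageClassification.from_pretrained("./checkpoint").eval()

img = Image.open("cat.jpg")
inputs = proc(images=img, return_tensors="pt")
logits = model(**inputs).logits          # shape: (1, 1000)
idx = int(logits.argmax(-1))
print(model.config.id2label[idx])        # e.g., "tabby, tabby cat"
```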
core/.ipynb_checkpoints/distill-checkpoint.py
ADDED
@@ -0,0 +1,184 @@
"""Knowledge-distillation utilities (model-family agnostic).
|
| 2 |
+
|
| 3 |
+
This module provides:
|
| 4 |
+
- Losses: KL distillation, soft cross-entropy, cosine feature loss
|
| 5 |
+
- Helper to obtain logits from models with/without built-in heads
|
| 6 |
+
- Lightweight classification head for backbone models (e.g., ViTModel)
|
| 7 |
+
- Simple evaluators (agreement %, KL) and diagnostics
|
| 8 |
+
|
| 9 |
+
Adapters may override `adapter_get_logits(model, x)` if a family needs a
|
| 10 |
+
custom extraction (e.g., language models with past_key_values).
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Callable, Optional, Protocol, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# -----------------------------------------------------------------------------
|
| 23 |
+
# Config
|
| 24 |
+
# -----------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class KDConfig:
|
| 28 |
+
temperature: float = 2.0
|
| 29 |
+
alpha: float = 1.0 # multiplier for KL term; task loss handled outside
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# -----------------------------------------------------------------------------
|
| 33 |
+
# Losses
|
| 34 |
+
# -----------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
def kl_divergence(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
|
| 37 |
+
"""Batchmean KL(student/ T || teacher/ T) scaled by T^2 (Hinton-style)."""
|
| 38 |
+
p_s = F.log_softmax(student_logits / T, dim=-1)
|
| 39 |
+
p_t = F.softmax(teacher_logits / T, dim=-1)
|
| 40 |
+
return F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, cfg: KDConfig) -> torch.Tensor:
|
| 44 |
+
return cfg.alpha * kl_divergence(student_logits, teacher_logits, T=cfg.temperature)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def soft_ce(student_logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
"""Soft cross-entropy: expects `soft_targets` already normalized."""
|
| 49 |
+
logp = F.log_softmax(student_logits, dim=-1)
|
| 50 |
+
return -(soft_targets * logp).sum(dim=-1).mean()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def cosine_feature_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
|
| 54 |
+
"""1 - cosine similarity averaged over batch and time/patch dims."""
|
| 55 |
+
s = F.normalize(student_feats, dim=-1)
|
| 56 |
+
t = F.normalize(teacher_feats, dim=-1)
|
| 57 |
+
return (1.0 - (s * t).sum(dim=-1)).mean()
|
| 58 |
+
|
| 59 |
+
def mse_reg(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
|
| 60 |
+
mse = F.mse_loss(student_logits,teacher_logits, reduction="mean")
|
| 61 |
+
return mse * (T * T)
|
| 62 |
+
|
| 63 |
+
# -----------------------------------------------------------------------------
|
| 64 |
+
# Logit extraction
|
| 65 |
+
# -----------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
class LogitsProvider(Protocol):
|
| 68 |
+
def __call__(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor: ...
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ClsHead(nn.Module):
|
| 72 |
+
"""Minimal classification head: LN + Linear.
|
| 73 |
+
|
| 74 |
+
Useful when the backbone outputs hidden states (e.g., ViTModel) and you
|
| 75 |
+
want logits comparable to a teacher with a classification head.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, hidden_size: int, num_classes: int = 1000, base_head: Optional[nn.Module] = None):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.norm = nn.LayerNorm(hidden_size)
|
| 81 |
+
self.fc = nn.Linear(hidden_size, num_classes)
|
| 82 |
+
if base_head is not None:
|
| 83 |
+
# Try to load weights if shapes match (e.g., from HF classifier)
|
| 84 |
+
try:
|
| 85 |
+
self.load_state_dict(base_head.state_dict(), strict=False)
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
return self.fc(self.norm(cls_token))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torch.no_grad()
|
| 94 |
+
def infer_hidden_size(model: nn.Module, sample: torch.Tensor) -> int:
|
| 95 |
+
# Run a tiny forward to inspect hidden size when unknown
|
| 96 |
+
model.eval()
|
| 97 |
+
out = model(pixel_values=sample)
|
| 98 |
+
if hasattr(out, "last_hidden_state"):
|
| 99 |
+
return int(out.last_hidden_state.shape[-1])
|
| 100 |
+
if hasattr(out, "logits"):
|
| 101 |
+
return int(out.logits.shape[-1])
|
| 102 |
+
raise RuntimeError("Cannot infer hidden size; provide explicitly.")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def default_get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
|
| 106 |
+
"""Family-agnostic logits extractor.
|
| 107 |
+
|
| 108 |
+
- If model output has `.logits`, return it.
|
| 109 |
+
- Else expects `.last_hidden_state` and uses [CLS] via provided `head`.
|
| 110 |
+
"""
|
| 111 |
+
out = model(pixel_values=x)
|
| 112 |
+
if hasattr(out, "logits"):
|
| 113 |
+
return out.logits
|
| 114 |
+
if hasattr(out, "last_hidden_state"):
|
| 115 |
+
if head is None:
|
| 116 |
+
raise ValueError("Backbone returned hidden states; supply a classification head.")
|
| 117 |
+
cls_tok = out.last_hidden_state[:, 0, :]
|
| 118 |
+
return head(cls_tok)
|
| 119 |
+
raise ValueError("Model output lacks logits and last_hidden_state.")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# -----------------------------------------------------------------------------
|
| 123 |
+
# Evaluators & diagnostics
|
| 124 |
+
# -----------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
@torch.inference_mode()
|
| 127 |
+
def logits_std(model: nn.Module, loader, *, get_logits: LogitsProvider, batches: int = 10, device: str = "cuda") -> Tuple[float, int]:
|
| 128 |
+
s = 0.0
|
| 129 |
+
k = 0
|
| 130 |
+
for x in loader:
|
| 131 |
+
if k >= batches:
|
| 132 |
+
break
|
| 133 |
+
x = x.to(device)
|
| 134 |
+
y = get_logits(model, x)
|
| 135 |
+
s += y.std().item()
|
| 136 |
+
k += 1
|
| 137 |
+
return (s / max(1, k), k)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@torch.inference_mode()
|
| 141 |
+
def agreement_metrics(
|
| 142 |
+
student: nn.Module,
|
| 143 |
+
teacher: nn.Module,
|
| 144 |
+
loader,
|
| 145 |
+
*,
|
| 146 |
+
get_student_logits: LogitsProvider,
|
| 147 |
+
get_teacher_logits: LogitsProvider,
|
| 148 |
+
batches: int = 20,
|
| 149 |
+
T: float = 1.0,
|
| 150 |
+
device: str = "cuda",
|
| 151 |
+
) -> dict:
|
| 152 |
+
kl_sum = 0.0
|
| 153 |
+
n = 0
|
| 154 |
+
top1 = 0
|
| 155 |
+
tot = 0
|
| 156 |
+
for i, x in enumerate(loader):
|
| 157 |
+
if i >= batches:
|
| 158 |
+
break
|
| 159 |
+
x = x.to(device)
|
| 160 |
+
t = get_teacher_logits(teacher, x)
|
| 161 |
+
s = get_student_logits(student, x)
|
| 162 |
+
p_s = F.log_softmax(s / T, dim=-1)
|
| 163 |
+
p_t = F.softmax(t / T, dim=-1)
|
| 164 |
+
kl_sum += (F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)).item()
|
| 165 |
+
top1 += (s.argmax(-1) == t.argmax(-1)).sum().item()
|
| 166 |
+
tot += x.size(0)
|
| 167 |
+
n += 1
|
| 168 |
+
return {"kl_TT": kl_sum / max(1, n), "top1_agreement": top1 / max(1, tot)}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# -----------------------------------------------------------------------------
|
| 172 |
+
# Small trainer helpers
|
| 173 |
+
# -----------------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
class DualEMA:
|
| 176 |
+
"""Simple exponential moving average for a scalar (e.g., lambda or latency)."""
|
| 177 |
+
|
| 178 |
+
def __init__(self, beta: float = 0.9, value: float = 0.0):
|
| 179 |
+
self.beta = float(beta)
|
| 180 |
+
self.value = float(value)
|
| 181 |
+
|
| 182 |
+
def update(self, x: float) -> float:
|
| 183 |
+
self.value = self.beta * self.value + (1 - self.beta) * float(x)
|
| 184 |
+
return self.value
|
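The pieces above compose into a single distillation step: extract logits from both models, compare them with the temperature-scaled KL, and backprop through the student only. A minimal sketch under stated assumptions: `student`, `teacher` (ViT-style models taking `pixel_values`), `loader` (batches of image tensors), and `head` (only needed when the student is a bare backbone) are placeholders, not objects defined in this repo:

```python
# Sketch of one KD training epoch built from core/distill.py above.
# `student`, `teacher`, `loader`, and `head` are assumed placeholders.
import torch
from core.distill import KDConfig, kd_loss, default_get_logits

cfg = KDConfig(temperature=2.0, alpha=1.0)
opt = torch.optim.AdamW(student.parameters(), lr=3e-4)

teacher.eval()
for x in loader:                                  # x: (B, 3, 224, 224)
    x = x.to("cuda")
    with torch.no_grad():
        t_logits = default_get_logits(teacher, x)  # teacher has a built-in head
    s_logits = default_get_logits(student, x, head=head)
    loss = kd_loss(s_logits, t_logits, cfg)        # alpha * T^2 * KL(teacher||student)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
```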
core/.ipynb_checkpoints/finetune-checkpoint.py
ADDED
@@ -0,0 +1,267 @@
# core/finetune.py
|
| 2 |
+
"""Post-pruning fine-tuning utilities (distillation)."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Callable, Optional, Tuple, Iterable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from core.distill import KDConfig, kd_loss, mse_reg
|
| 12 |
+
from core.utils import ensure_trainable_parameters
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class FinetuneConfig:
|
| 19 |
+
epochs: int = 5
|
| 20 |
+
lr: float = 3e-4
|
| 21 |
+
wd: float = 0.0
|
| 22 |
+
kd: KDConfig = KDConfig(temperature=2.0, alpha=1.0)
|
| 23 |
+
amp: bool = True
|
| 24 |
+
# "auto" -> bf16 if supported else fp16; "bf16" | "fp16" | "off" also allowed
|
| 25 |
+
amp_dtype: str = "auto"
|
| 26 |
+
device: str = "cuda"
|
| 27 |
+
log_every: int = 200
|
| 28 |
+
# diagnostics
|
| 29 |
+
grad_check_every: int = 50
|
| 30 |
+
grad_warn_if_zero_steps: int = 2 # consecutive checks with zero grad -> warn
|
| 31 |
+
mse_weight: float = 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _autocast_and_scaler(amp: bool, amp_dtype: str) -> Tuple[torch.autocast, Optional[torch.amp.GradScaler], bool, str]:
|
| 35 |
+
"""
|
| 36 |
+
Returns (autocast_ctx, scaler_or_None, use_scaler_bool, amp_mode_str)
|
| 37 |
+
- BF16 -> autocast(bfloat16), NO GradScaler
|
| 38 |
+
- FP16 -> autocast(float16), GradScaler ENABLED
|
| 39 |
+
- OFF -> disabled autocast, NO GradScaler
|
| 40 |
+
"""
|
| 41 |
+
if not amp or amp_dtype == "off":
|
| 42 |
+
ctx = torch.amp.autocast(device_type="cuda", enabled=False)
|
| 43 |
+
return ctx, None, False, "OFF"
|
| 44 |
+
|
| 45 |
+
if amp_dtype == "auto":
|
| 46 |
+
use_bf16 = torch.cuda.is_bf16_supported()
|
| 47 |
+
elif amp_dtype == "bf16":
|
| 48 |
+
use_bf16 = True
|
| 49 |
+
elif amp_dtype == "fp16":
|
| 50 |
+
use_bf16 = False
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")
|
| 53 |
+
|
| 54 |
+
if use_bf16:
|
| 55 |
+
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
|
| 56 |
+
return ctx, None, False, "BF16"
|
| 57 |
+
else:
|
| 58 |
+
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True)
|
| 59 |
+
try:
|
| 60 |
+
scaler = torch.amp.GradScaler("cuda", enabled=True)
|
| 61 |
+
except TypeError:
|
| 62 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
| 63 |
+
return ctx, scaler, True, "FP16"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
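# Illustrative sketch of the helper above (comments only, not original code):
#
#   ctx, scaler, use_scaler, mode = _autocast_and_scaler(amp=True, amp_dtype="auto")
#   # On GPUs with bfloat16 support this yields a BF16 autocast context and no
#   # GradScaler; otherwise an FP16 context with a GradScaler (use_scaler=True).
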
def _images_from_batch(batch):
    if isinstance(batch, dict):
        # Explicit None check: `a or b` on a multi-element tensor raises
        # "Boolean value of Tensor is ambiguous".
        v = batch.get("pixel_values")
        return v if v is not None else batch.get("input")
    if isinstance(batch, (tuple, list)):
        return batch[0]
    return batch


def _param_iter_trainable(model: nn.Module) -> Iterable[torch.nn.Parameter]:
    for p in model.parameters():
        if p.requires_grad:
            yield p


def _grad_norm_and_nonzero(params: Iterable[torch.nn.Parameter]) -> Tuple[float, int]:
    total_sq, nonzero = 0.0, 0
    for p in params:
        g = p.grad
        if g is None:
            continue
        if g.is_sparse:
            g = g.coalesce().values()
        gn = float(g.detach().norm().cpu())
        if gn > 0.0:
            nonzero += 1
        total_sq += gn * gn
    return (total_sq ** 0.5), nonzero


@torch.no_grad()
def recalibrate_bn_stats(model, loader, max_batches=200, device="cuda"):
    model.train()  # use training mode to update running stats
    seen = 0
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        if not torch.is_tensor(x):
            continue
        x = x.to(device, non_blocking=True)
        model(x)
        seen += x.size(0)
    return seen


def finetune_student(
    student: nn.Module,
    teacher: nn.Module,
    train_loader,
    *,
    get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
    get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
    cfg: Optional[FinetuneConfig] = None,  # None avoids a shared mutable default instance
    val_loader=None,
    on_step: Optional[Callable[[int, float], None]] = None,
    save_best=False,
) -> nn.Module:
    """Fine-tune a pruned student against a frozen teacher using KD."""
    cfg = cfg if cfg is not None else FinetuneConfig()
    dev = cfg.device
    student = student.to(dev)
    teacher = teacher.to(dev).eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    for p in student.parameters():
        p.requires_grad_(True)

    # Make sure we can actually train
    ensure_trainable_parameters(student, requires_grad=True)
    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    if trainable == 0:
        raise RuntimeError("No trainable parameters in student — cannot finetune.")

    opt = torch.optim.AdamW(
        _param_iter_trainable(student),
        lr=cfg.lr,
        weight_decay=cfg.wd,
    )
    # T_max is measured in optimizer steps, so the schedule is advanced per step below.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * len(train_loader), eta_min=3e-5)

    autocast_ctx, scaler, use_scaler, amp_mode = _autocast_and_scaler(cfg.amp, cfg.amp_dtype)
    print(f"[AMP] Mode={amp_mode} | GradScaler={'ON' if use_scaler else 'OFF'} | "
          f"KD: T={cfg.kd.temperature} alpha={cfg.kd.alpha} | LR={cfg.lr} WD={cfg.wd} | Trainable params={trainable:,}")

    zero_grad_streak = 0
    global_step = 0

    T_max = cfg.kd.temperature
    T_min = 2.0
    kd_conf = cfg.kd

    best_state = None
    best_val = float("inf")

    for ep in range(cfg.epochs):
        student.train()
        running, seen = 0.0, 0

        for i, batch in enumerate(train_loader):

            step = ep * len(train_loader) + i  # global step for T scheduling
            max_steps = cfg.epochs * len(train_loader)
            kd_conf.temperature = T_max - (step / max_steps) * (T_max - T_min)

            # print(f"Step {step}/{max_steps}, T_min={T_min}, T={kd_conf.temperature}, T_max={T_max}")

            x = _images_from_batch(batch)
            if not torch.is_tensor(x):
                raise ValueError("Train loader must yield tensors or (tensor, target) tuples.")
            x = x.to(dev, non_blocking=True)

            with torch.no_grad():
                t = get_teacher_logits(teacher, x)
                # Force numerically stable dtype for the loss
                t = t.float()

            # ---- forward student under autocast
            with autocast_ctx:
                s = get_student_logits(student, x)

            # ---- compute KD loss in FP32 (outside autocast) for stability
            s32 = s.float()
            mse = cfg.mse_weight * mse_reg(s32, t, kd_conf.temperature)
            loss = kd_loss(s32, t, kd_conf) + mse

            opt.zero_grad(set_to_none=True)
            if use_scaler:
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()
            scheduler.step()  # cosine schedule advances once per optimizer step

            # ---- diagnostics
            bs = x.size(0)
            running += float(loss.detach()) * bs
            seen += bs
            global_step += 1

            if cfg.grad_check_every and (global_step % cfg.grad_check_every == 0):
                gnorm, n_nonzero = _grad_norm_and_nonzero(_param_iter_trainable(student))
                if n_nonzero == 0 or gnorm == 0.0:
                    zero_grad_streak += 1
                    if zero_grad_streak >= cfg.grad_warn_if_zero_steps:
                        print(f"[WARN] Step {global_step}: zero gradients detected "
                              f"(nonzero={n_nonzero}, grad_norm={gnorm:.3e}). "
                              f"Check get_student_logits, requires_grad, AMP settings, and data pipeline.")
                else:
                    zero_grad_streak = 0

            if cfg.log_every and (i + 1) % cfg.log_every == 0:
                print(f"Step {i+1}/{len(train_loader)} (ep {ep+1}/{cfg.epochs}): "
                      f"running loss = {running / max(1, seen):.4f}")

            if on_step is not None:
                on_step(global_step, float(loss.detach()))

            # free ASAP
            del s, s32, t, loss

        # ---- validation
        if val_loader is not None:
            _ = recalibrate_bn_stats(student, train_loader, max_batches=1000, device=cfg.device)
            student.eval()
            val_loss, vseen = 0.0, 0
            with torch.no_grad():
                for vbatch in val_loader:
                    vx = _images_from_batch(vbatch)
                    if not torch.is_tensor(vx):
                        raise ValueError("Val loader must yield tensors or (tensor, target) tuples.")
                    vx = vx.to(dev, non_blocking=True)

                    vt = get_teacher_logits(teacher, vx).float()
                    with autocast_ctx:
                        vs = get_student_logits(student, vx)

                    vs32 = vs.float()
                    vmse = cfg.mse_weight * mse_reg(vs32, vt, kd_conf.temperature)
                    vloss = kd_loss(vs32, vt, kd_conf) + vmse
                    val_loss += float(vloss.detach()) * vx.size(0)
                    vseen += vx.size(0)

            mean_val = val_loss / max(1, vseen)
            print("\n------------------------------------------------")
            print(f"Epoch {ep+1}/{cfg.epochs}: T={kd_conf.temperature:.2f}, train={running / max(1, seen):.6f}, "
                  f"val={mean_val:.6f}")

            if save_best and (mean_val < best_val):
                best_val = mean_val
                best_state = copy.deepcopy(student.state_dict())

            print("------------------------------------------------\n")

        else:
            print(f"Epoch {ep+1}/{cfg.epochs}: train={running / max(1, seen):.6f}")

    if save_best and val_loader is not None and best_state is not None:
        student.load_state_dict(best_state)

    student.eval()
    return student
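
A hedged usage sketch for `finetune_student` (the loaders and the `.logits` extractors are assumptions about the surrounding training script, not part of this file):

    cfg = FinetuneConfig(epochs=3, lr=1e-4, amp=True, amp_dtype="auto",
                         kd=KDConfig(temperature=4.0, alpha=1.0), mse_weight=0.1)
    student = finetune_student(
        student, teacher, train_loader,
        get_student_logits=lambda m, x: m(pixel_values=x).logits,
        get_teacher_logits=lambda m, x: m(pixel_values=x).logits,
        cfg=cfg, val_loader=val_loader, save_best=True,
    )
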
core/.ipynb_checkpoints/profiler-checkpoint.py
ADDED
@@ -0,0 +1,236 @@
"""Simple, robust latency measurement utilities.

This module provides GPU-friendly profilers with warmup, multiple repeats,
median/percentile reporting, and optional outlier rejection via MAD.

Design goals:
- Family-agnostic: take a callable `forward(model, x)` or rely on HF `.forward`
- Deterministic when desired; avoids autograd by default
- Works with CUDA or CPU; uses `torch.cuda.Event` for accurate GPU timing

Key APIs:
- measure_latency_ms(model, input_shape | input_tensor, ...)
- profile(model, sample, settings) -> {mean, p50, p90, p95, p99}
- LatencyProfiler(settings).measure(...)
- profile_many_shapes(model, shapes, settings)
"""
from __future__ import annotations

from dataclasses import dataclass
from statistics import median
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple

import contextlib
import math
import time

import torch
import torch.nn as nn


# -----------------------------------------------------------------------------
# Settings
# -----------------------------------------------------------------------------

@dataclass
class ProfileSettings:
    warmup: int = 10
    iters: int = 50
    percentile: Sequence[int] = (50, 90, 95, 99)
    sync_each_iter: bool = True
    use_inference_mode: bool = True
    cuda_graph: bool = False  # advanced users can enable with static shapes
    reject_outliers_mad: float = 0.0  # e.g., 3.5 to drop extreme spikes
    cudnn_benchmark: bool = True
    deterministic: bool = False  # sets cudnn.deterministic


# -----------------------------------------------------------------------------
# Context helpers
# -----------------------------------------------------------------------------

@contextlib.contextmanager
def _torch_backend_ctx(settings: ProfileSettings):
    prev_bench = torch.backends.cudnn.benchmark
    prev_det = torch.backends.cudnn.deterministic
    try:
        torch.backends.cudnn.benchmark = bool(settings.cudnn_benchmark)
        torch.backends.cudnn.deterministic = bool(settings.deterministic)
        yield
    finally:
        torch.backends.cudnn.benchmark = prev_bench
        torch.backends.cudnn.deterministic = prev_det


def _percentiles(sorted_vals: Sequence[float], qs: Sequence[int]) -> Dict[int, float]:
    n = len(sorted_vals)
    if n == 0:
        return {q: float("nan") for q in qs}
    out = {}
    for q in qs:
        if n == 1:
            out[q] = sorted_vals[0]
            continue
        k = (q / 100.0) * (n - 1)
        f = math.floor(k)
        c = min(n - 1, f + 1)
        if f == c:
            out[q] = sorted_vals[int(k)]
        else:
            # linear interpolation between the two bracketing samples
            d0 = sorted_vals[f] * (c - k)
            d1 = sorted_vals[c] * (k - f)
            out[q] = d0 + d1
    return out


def _apply_mad_filter(vals: Sequence[float], thresh: float) -> Sequence[float]:
    if thresh <= 0 or len(vals) < 5:
        return vals
    med = median(vals)
    dev = [abs(v - med) for v in vals]
    mad = median(dev) or 1e-12
    keep = [v for v, d in zip(vals, dev) if (d / mad) <= thresh]
    return keep if keep else vals

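# Worked examples for the two helpers above (comments only, not original code):
#
#   _percentiles([1.0, 2.0, 3.0, 4.0], (50, 90))
#       -> {50: 2.5, 90: 3.7}    # linear interpolation between samples
#   _apply_mad_filter([5.0, 5.1, 4.9, 5.0, 50.0], 3.5)
#       -> [5.0, 5.1, 4.9, 5.0]  # the 50.0 spike exceeds 3.5 MADs and is dropped
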
# -----------------------------------------------------------------------------
# Core measurement
# -----------------------------------------------------------------------------

def _default_forward(mod: nn.Module, inp: torch.Tensor):
    """HF-style models take `pixel_values=`; plain modules take a positional tensor."""
    try:
        return mod(pixel_values=inp)
    except TypeError:
        return mod(inp)


@torch.inference_mode()
def measure_latency_ms(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Tuple[float, float]:
    """Return (mean_ms, p95_ms) over `iters` measurements.

    If `sample` is a shape tuple, a random tensor is created on-device.
    The default forward calls `model(pixel_values=x)` if accepted, else `model(x)`.
    """
    cfg = settings or ProfileSettings()

    with _torch_backend_ctx(cfg):
        m = model.to(device).eval()
        if isinstance(sample, torch.Tensor):
            x = sample.to(device)
        else:
            x = torch.randn(*sample, device=device)

        fn = forward_fn or _default_forward

        # Warmup
        for _ in range(cfg.warmup):
            _ = fn(m, x)
        if torch.cuda.is_available() and device.startswith("cuda"):
            torch.cuda.synchronize()

        times: list[float] = []
        if torch.cuda.is_available() and device.startswith("cuda"):
            for _ in range(cfg.iters):
                t0 = torch.cuda.Event(enable_timing=True)
                t1 = torch.cuda.Event(enable_timing=True)
                t0.record()
                _ = fn(m, x)
                t1.record()
                if cfg.sync_each_iter:
                    torch.cuda.synchronize()
                times.append(t0.elapsed_time(t1))  # milliseconds
        else:
            for _ in range(cfg.iters):
                t0 = time.perf_counter()
                _ = fn(m, x)
                if cfg.sync_each_iter and torch.cuda.is_available():
                    torch.cuda.synchronize()
                t1 = time.perf_counter()
                times.append((t1 - t0) * 1000.0)

    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
    mean_ms = sum(times) / max(1, len(times))
    p = _percentiles(times, cfg.percentile)
    p95 = p.get(95, times[int(0.95 * (len(times) - 1))] if times else float("nan"))
    return mean_ms, p95


# Higher level wrapper returning multiple percentiles
@torch.inference_mode()
def profile(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[str, float]:
    cfg = settings or ProfileSettings()
    # Warm up (and cudnn autotune) via the low-level helper, then re-run the
    # timing loop here so all requested percentiles come from one set of samples.
    _ = measure_latency_ms(model, sample, settings=cfg, device=device, forward_fn=forward_fn)

    m = model.to(device).eval()
    if isinstance(sample, torch.Tensor):
        x = sample.to(device)
    else:
        x = torch.randn(*sample, device=device)

    fn = forward_fn or _default_forward

    times = []
    if torch.cuda.is_available() and device.startswith("cuda"):
        for _ in range(cfg.iters):
            t0 = torch.cuda.Event(enable_timing=True)
            t1 = torch.cuda.Event(enable_timing=True)
            t0.record()
            _ = fn(m, x)
            t1.record()
            if cfg.sync_each_iter:
                torch.cuda.synchronize()
            times.append(t0.elapsed_time(t1))
    else:
        for _ in range(cfg.iters):
            t0 = time.perf_counter()
            _ = fn(m, x)
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)

    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
    percs = _percentiles(times, cfg.percentile)
    out = {"mean": sum(times) / max(1, len(times))}
    out.update({f"p{q}": v for q, v in percs.items()})
    return out


class LatencyProfiler:
    """Reusable profiler with fixed settings."""

    def __init__(self, settings: Optional[ProfileSettings] = None, device: str = "cuda"):
        self.settings = settings or ProfileSettings()
        self.device = device

    def measure(self, model: nn.Module, sample: torch.Tensor | Tuple[int, ...], *, forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None) -> Tuple[float, float]:
        return measure_latency_ms(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)

    def profile(self, model: nn.Module, sample: torch.Tensor | Tuple[int, ...], *, forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None) -> Dict[str, float]:
        return profile(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)


@torch.inference_mode()
def profile_many_shapes(
    model: nn.Module,
    shapes: Iterable[Tuple[int, ...]],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[Tuple[int, ...], Dict[str, float]]:
    out: Dict[Tuple[int, ...], Dict[str, float]] = {}
    for shp in shapes:
        out[tuple(shp)] = profile(model, shp, settings=settings, device=device, forward_fn=forward_fn)
    return out
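
A short usage sketch (illustrative; `model` is any module whose forward accepts `pixel_values=` or a positional tensor):

    prof = LatencyProfiler(ProfileSettings(warmup=20, iters=100, reject_outliers_mad=3.5))
    mean_ms, p95_ms = prof.measure(model, (1, 3, 224, 224))
    stats = profile_many_shapes(model, [(1, 3, 224, 224), (8, 3, 224, 224)],
                                settings=prof.settings)
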
core/.ipynb_checkpoints/proxy_cost-checkpoint.py
ADDED
@@ -0,0 +1,771 @@
# core/proxy_cost.py
"""Latency proxy models and a tiny LUT for hardware correction.

This file defines a family-agnostic interface plus concrete proxies (ViT, ResNet, LLM)
that estimate latency from *soft structure* (gates) and input size. All proxies accept
the trainer's `(model, batch) -> ms` call signature directly (batches may be dict/tuple/tensor).
A small, in-memory LUT can be populated from real measurements during training to correct
analytic estimates.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union, List

import torch
import torch.nn as nn

from .gates import iter_gates, _as_like  # _as_like is used by ViT proxy


# -----------------------------------------------------------------------------
# Small batch helpers (shared)
# -----------------------------------------------------------------------------

TensorOrBatch = Union[torch.Tensor, Tuple, List, Dict[str, Any]]

def _first_tensor(batch: TensorOrBatch) -> torch.Tensor:
    """Find the first tensor inside a batch-like structure."""
    if torch.is_tensor(batch):
        return batch
    if isinstance(batch, dict):
        # Common keys across tasks
        for k in ("input_ids", "pixel_values", "images", "x"):
            v = batch.get(k, None)
            if torch.is_tensor(v):
                return v
        # fallback: first tensor value
        for v in batch.values():
            if torch.is_tensor(v):
                return v
        raise ValueError("Batch dict has no tensor field I recognize.")
    if isinstance(batch, (list, tuple)):
        for v in batch:
            if torch.is_tensor(v):
                return v
        # torchvision pattern: ([aug1, aug2], label)
        if len(batch) and isinstance(batch[0], (list, tuple)):
            for v in batch[0]:
                if torch.is_tensor(v):
                    return v
    raise ValueError("Cannot find a tensor in the provided batch.")

def _ids_from_batch(batch: TensorOrBatch) -> torch.Tensor:
    """Return a 2D [B,S] tensor representing token ids for LLMs."""
    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
        return batch["input_ids"]
    t = _first_tensor(batch)
    if t.dim() >= 2:
        return t
    raise ValueError("Cannot infer [B,S] from batch; need 'input_ids' or a 2D tensor.")

def _nchw_from_batch(batch: TensorOrBatch) -> Tuple[int, int, int, int]:
    """Return NCHW shape from a batch or an explicit (N,C,H,W) tuple/list/tensor."""
    if isinstance(batch, (tuple, list)) and len(batch) == 4 and all(isinstance(x, int) for x in batch):
        return tuple(batch)  # type: ignore[return-value]
    x = _first_tensor(batch)
    if x.dim() != 4:
        raise ValueError(f"Expected NCHW tensor for CNN proxy; got tensor with shape {tuple(x.shape)}")
    N, C, H, W = map(int, x.shape)
    return (N, C, H, W)

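# Illustrative examples of the batch helpers (comments only, not original code):
#
#   _first_tensor({"input_ids": ids})              # -> ids (first recognized tensor key)
#   _ids_from_batch((ids, labels))                 # -> ids, provided ids is [B,S]
#   _nchw_from_batch(torch.randn(8, 3, 224, 224))  # -> (8, 3, 224, 224)
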
# -----------------------------------------------------------------------------
# Base proxy + LUT
# -----------------------------------------------------------------------------

class LatencyProxy(nn.Module):
    """Abstract proxy producing a scalar latency-like value (ms).

    Subclasses implement `_predict_raw` and may define `_signature` keys used by
    a LUT to refine estimates with real measurements. Proxies accept either a
    batch-like object (dict/tuple/tensor) or an explicit shape tuple.
    """

    def __init__(self):
        super().__init__()

    def predict(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Batch-friendly entry point. `sample` may be a batch or explicit shape."""
        return self._predict_raw(model, sample, policy=policy, step=step, **kwargs)

    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:  # pragma: no cover - abstract
        raise NotImplementedError

    def signature(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None
    ) -> Tuple:
        """Return a hashable signature describing the workload shape."""
        if torch.is_tensor(sample):
            shp = tuple(sample.shape)
        elif isinstance(sample, (tuple, list)):
            shp = tuple(sample)
        elif isinstance(sample, dict):
            # summarize the shapes of any tensors in dict
            shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
        else:
            shp = (str(type(sample)),)
        return (type(self).__name__, shp)


class LatencyLUT:
    """Tiny LUT mapping `(signature) -> measured_ms`."""

    def __init__(self):
        self._table: Dict[Tuple[Any, ...], float] = {}

    def update(self, signature: Tuple[Any, ...], measured_ms: float) -> None:
        self._table[signature] = float(measured_ms)

    def get(self, signature: Tuple[Any, ...]) -> Optional[float]:
        return self._table.get(signature)

    def blend(self, raw_estimate: torch.Tensor, signature: Tuple[Any, ...]) -> torch.Tensor:
        val = self.get(signature)
        if val is None:
            return raw_estimate
        # Put on same device/dtype as raw_estimate
        return _as_like(raw_estimate, val)

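# Illustrative LUT usage (comments only; the signature tuple is hypothetical):
#
#   lut = LatencyLUT()
#   sig = ("ViT", (1, 3, 224, 224), 12, 768, 3072)
#   lut.update(sig, 6.2)                         # record a measured 6.2 ms
#   est = lut.blend(torch.tensor(5.0), sig)      # -> tensor(6.2000)
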
# -----------------------------------------------------------------------------
# ViT proxy (analytic + gates), with scale and per-term weights
# -----------------------------------------------------------------------------

@dataclass
class ViTProxyConfig:
    scale_ms: float = 1.0
    alpha_qkv: float = 1.0
    alpha_scores: float = 1.0
    alpha_out: float = 1.0
    alpha_mlp: float = 1.0

def _vit_layers(m):
    enc = getattr(m, "encoder", None)
    if enc is not None and hasattr(enc, "layer"):
        return enc.layer
    vit = getattr(m, "vit", None)
    if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
        return vit.encoder.layer
    raise TypeError("Expected a HF ViT with *.encoder.layer (ViTModel or ViTForImageClassification).")


class ViTLatencyProxy(LatencyProxy):
    """Latency proxy for ViT models. Accepts batches or (N,C,H,W) tuples."""

    def __init__(self, cfg: Optional[ViTProxyConfig] = None, lut: Optional[LatencyLUT] = None):
        super().__init__()
        self.cfg = cfg or ViTProxyConfig()
        self.lut = lut or LatencyLUT()

    # ---- helpers -------------------------------------------------------------
    @staticmethod
    def _input_spec(sample: TensorOrBatch) -> Tuple[int, int, int]:
        if isinstance(sample, (tuple, list)) and len(sample) == 4 and all(isinstance(x, int) for x in sample):
            B, C, H, W = sample
            return int(B), int(H), int(W)
        x = _first_tensor(sample)
        if x.dim() != 4:
            raise ValueError("ViTLatencyProxy expects a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
        B, C, H, W = x.shape
        return int(B), int(H), int(W)

    @staticmethod
    def _patch_hw(cfg) -> Tuple[int, int]:
        patch = getattr(cfg, "patch_size", 16)
        if isinstance(patch, (tuple, list)):
            return int(patch[0]), int(patch[1])
        return int(patch), int(patch)

    @staticmethod
    def _soft_heads_from_block(blk) -> Optional[torch.Tensor]:
        # Prefer a nested attention with kept_heads_soft()
        attn = getattr(getattr(blk, "attention", None), "attention", None)
        if attn is not None and hasattr(attn, "kept_heads_soft"):
            return attn.kept_heads_soft()
        return None

    @staticmethod
    def _find_ffn_gate(blk):
        inter = getattr(blk, "intermediate", None)
        if inter is None:
            return None
        # Common attribute names
        for nm in ("neuron_gate", "gate", "ffn_gate"):
            g = getattr(inter, nm, None)
            if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
                return g
        # Last resort: scan children
        for m in blk.modules():
            if hasattr(m, "logits") and hasattr(m, "tau"):
                return m
        return None

    # ---- proxy ---------------------------------------------------------------
    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None
    ) -> torch.Tensor:
        anchor = next((p for p in model.parameters()), torch.tensor(0.0))

        B, H_img, W_img = self._input_spec(sample)
        cfg = getattr(model, "config", None)
        if cfg is None:
            raise ValueError("Model must expose a HuggingFace-like .config for ViT proxy")
        ph, pw = self._patch_hw(cfg)

        S = _as_like(anchor, 1 + (H_img // ph) * (W_img // pw))
        D = _as_like(anchor, int(getattr(cfg, "hidden_size", 768)))
        Hh = _as_like(anchor, int(getattr(cfg, "num_attention_heads", 12)))
        Dh = D // Hh

        warm = False
        if policy is not None and step is not None:
            warm = (step < int(getattr(policy, "warmup_steps", 0)))

        total_qkv = _as_like(anchor, 0.0)
        total_scores = _as_like(anchor, 0.0)
        total_out = _as_like(anchor, 0.0)
        total_mlp = _as_like(anchor, 0.0)

        default_hidden = _as_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        layers = _vit_layers(model)
        for blk in layers:
            # During warmup pretend all heads are kept; `x or y` is safe here
            # because soft head counts are 0-dim tensors.
            heads_soft = Hh if warm else (self._soft_heads_from_block(blk) or Hh)

            # FFN hidden expectation
            if warm:
                hidden_soft = default_hidden
            else:
                g = self._find_ffn_gate(blk)
                if g is None:
                    hidden_soft = default_hidden
                else:
                    probs = torch.sigmoid(g.logits / g.tau)
                    group = int(getattr(g, "group", getattr(g, "group_size", 16)))
                    hidden_soft = probs.sum() * _as_like(anchor, group)

            D_kept = heads_soft * Dh

            total_qkv += 3 * S * D * D_kept
            total_scores += (S * S) * heads_soft * Dh
            total_out += S * D_kept * D
            total_mlp += 2 * S * D * hidden_soft

        raw = (
            self.cfg.alpha_qkv * total_qkv
            + self.cfg.alpha_scores * total_scores
            + self.cfg.alpha_out * total_out
            + self.cfg.alpha_mlp * total_mlp
        )
        raw_ms = raw * _as_like(anchor, float(self.cfg.scale_ms))

        # optional LUT correction
        sig = self.signature(model, sample, policy=policy, step=step)
        return self.lut.blend(raw_ms, sig)

    # A reasonable default signature for ViT workloads
    def signature(self, model: nn.Module, sample, *, policy=None, step: Optional[int] = None) -> Tuple:
        if torch.is_tensor(sample):
            shp = tuple(sample.shape)
        elif isinstance(sample, (tuple, list)):
            shp = tuple(sample)
        elif isinstance(sample, dict):
            shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
        else:
            shp = (str(type(sample)),)
        cfg = getattr(model, "config", None)
        heads = int(getattr(cfg, "num_attention_heads", 12))
        hidden = int(getattr(cfg, "hidden_size", 768))
        inter = int(getattr(cfg, "intermediate_size", 3072))
        return ("ViT", shp, heads, hidden, inter)

    @torch.no_grad()
    def calibrate(self, model: nn.Module, shape: tuple, measure_fn, *, device: str = "cuda") -> float:
        """Set proxy scale so that keep-all student matches measured ms.

        `measure_fn(model, shape_or_tensor)` should return `(mean_ms, p95_ms)`.
        """
        sample_t = torch.randn(shape, device=device)
        model = model.to(device).eval()
        mean_ms, _ = measure_fn(model, shape, device=device)
        soft_ms = self.predict(model, sample_t).item()
        # `soft_ms` already includes the current scale, so multiply rather than
        # assign; repeated calibration then converges instead of compounding.
        self.cfg.scale_ms = float(self.cfg.scale_ms * mean_ms / max(soft_ms, 1e-9))
        return self.cfg.scale_ms

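# Illustrative calibration flow for the ViT proxy (comments only; relies on
# measure_latency_ms from core/profiler.py):
#
#   proxy = ViTLatencyProxy()
#   proxy.calibrate(model, (1, 3, 224, 224), measure_latency_ms, device="cuda")
#   est_ms = proxy.predict(model, torch.randn(1, 3, 224, 224, device="cuda"))
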
# ------------------------------ ResNet Proxy ------------------------------

@dataclass
class ResNetProxyConfig:
    scale_ms: float = 1.0
    alpha_conv: float = 1.0  # weight for conv FLOPs term


def _as_const_like_resnet(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


def _find_anchor_param(model: nn.Module) -> torch.Tensor:
    # Prefer any gate-like parameter; otherwise any parameter; else cpu scalar
    for m in model.modules():
        for nm in ("logits", "head_gate"):
            t = getattr(m, nm, None)
            if isinstance(t, torch.Tensor):
                return t
    for p in model.parameters():
        return p
    return torch.tensor(0.0)


def _kept_from_gate(module, anchor: torch.Tensor) -> Optional[torch.Tensor]:
    """Return expected kept channels for a BN gate: probs.sum() * group_size.
    If no gate is found, return None.
    """
    g = None
    for nm in ("gate", "neuron_gate", "channel_gate", "bn_gate"):
        if hasattr(module, nm):
            g = getattr(module, nm)
            break
    if g is None and hasattr(module, "logits") and hasattr(module, "tau"):
        g = module

    if g is None or not hasattr(g, "logits"):
        return None
    logits = g.logits
    tau = float(getattr(g, "tau", 1.5))
    group = int(getattr(g, "group", getattr(g, "group_size", 1)))
    if group <= 0:
        group = 1
    probs = torch.sigmoid(logits / tau)
    return probs.sum() * _as_const_like_resnet(anchor, group)


class ResNetLatencyProxy(LatencyProxy):
    """Latency proxy for ResNet-like backbones with BN gates.

    Approximates latency with a FLOPs-style sum over convs, using the *expected*
    kept channels after each BN gate (probs.sum()*group_size). Falls back to the
    full channel count when a gate is not found.

    Accepts a batch or an explicit (N,C,H,W) shape.
    """

    def __init__(self, cfg: Optional[ResNetProxyConfig] = None):
        super().__init__()
        self.cfg = cfg or ResNetProxyConfig()

    def _add_cost(self, cost_like: torch.Tensor, oc, ic, k, stride, H, W):
        alpha = _as_const_like_resnet(cost_like, self.cfg.alpha_conv)
        # update spatial dims with conv stride (roughly, ignoring padding effects)
        H = (H + stride - 1) // stride
        W = (W + stride - 1) // stride
        flops = _as_const_like_resnet(cost_like, oc) * _as_const_like_resnet(cost_like, ic) * (k * k) * _as_const_like_resnet(cost_like, H) * _as_const_like_resnet(cost_like, W)
        return cost_like + alpha * flops, H, W

    def _predict_raw(self, model: nn.Module, sample: TensorOrBatch, **_) -> torch.Tensor:
        N, C_in, H0, W0 = _nchw_from_batch(sample)
        anchor = _find_anchor_param(model)
        cost = _as_const_like_resnet(anchor, 0.0)
        H = _as_const_like_resnet(anchor, int(H0))
        W = _as_const_like_resnet(anchor, int(W0))

        # Stem
        conv1 = getattr(model, "conv1")
        bn1 = getattr(model, "bn1", None)
        k = conv1.kernel_size[0]
        s = conv1.stride[0]
        kept_out = None
        if bn1 is not None:
            kept = _kept_from_gate(bn1, anchor)
            if kept is not None:
                kept_out = kept
        oc_eff = kept_out if kept_out is not None else _as_const_like_resnet(anchor, conv1.out_channels)
        cost, H, W = self._add_cost(cost, oc_eff, _as_const_like_resnet(anchor, C_in), k, s, H, W)
        in_ch = oc_eff

        def _block_cost(block, in_ch, H, W, cost):
            # conv1 -> bn1 (explicit None check: a soft count must not be
            # truth-tested, and a missing gate falls back to full width)
            c1 = block.conv1
            b1 = block.bn1 if hasattr(block, "bn1") else None
            k1, s1 = c1.kernel_size[0], c1.stride[0]
            kept1 = _kept_from_gate(b1, anchor) if b1 is not None else None
            oc1_eff = kept1 if kept1 is not None else _as_const_like_resnet(anchor, c1.out_channels)
            cost, H, W = self._add_cost(cost, oc1_eff, in_ch, k1, s1, H, W)

            # conv2 -> bn2
            c2 = block.conv2
            b2 = block.bn2 if hasattr(block, "bn2") else None
            k2, s2 = c2.kernel_size[0], c2.stride[0]
            kept2 = _kept_from_gate(b2, anchor) if b2 is not None else None
            oc2_eff = kept2 if kept2 is not None else _as_const_like_resnet(anchor, c2.out_channels)
            cost, H, W = self._add_cost(cost, oc2_eff, oc1_eff, k2, s2, H, W)

            return oc2_eff, H, W, cost

        # Layers
        for lname in ("layer1", "layer2", "layer3", "layer4"):
            layer = getattr(model, lname, None)
            if layer is None:
                continue
            for blk in layer:
                in_ch, H, W, cost = _block_cost(blk, in_ch, H, W, cost)

        scale = _as_const_like_resnet(anchor, self.cfg.scale_ms)
        return cost * scale

    @torch.no_grad()
    def calibrate(self, model: nn.Module, keepall_export_fn, profiler_fn, sample: TensorOrBatch, device: str = "cuda") -> float:
        """Calibrate `scale_ms` so proxy(model_keepall) ~= real latency in ms."""
        keep = keepall_export_fn(model)
        sample_shape = _nchw_from_batch(sample)
        mean_ms, _ = profiler_fn(keep, sample_shape, device=device)
        soft = float(self.predict(model, sample).detach().cpu())
        # `soft` already includes the current scale, so multiply rather than assign.
        self.cfg.scale_ms = self.cfg.scale_ms * mean_ms / max(soft, 1e-9)
        return mean_ms

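# Illustrative usage for a gated torchvision-style ResNet (comments only;
# `export_keepall` is an assumed helper that materializes the keep-all student):
#
#   proxy = ResNetLatencyProxy()
#   proxy.calibrate(gated_resnet, export_keepall, measure_latency_ms,
#                   sample=torch.randn(1, 3, 224, 224))
#   est_ms = proxy.predict(gated_resnet, torch.randn(1, 3, 224, 224))
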
# -----------------------------------------------------------------------------
# LLM proxy
# -----------------------------------------------------------------------------

"""
LatencyProxyLLM
---------------
A lightweight latency proxy for decoder-only HF LLMs (LLaMA/Mistral style).

- Estimates end-to-end latency (ms-like scalar) for a given (B, S, T):
  * Prefill on S tokens (build KV cache)
  * Cached decode for T steps
- Uses soft gate expectations:
  * Attention heads (HeadGate on GatedSelfAttentionLLM)
  * FFN hidden (SwiGLUWidthGate via .mlp.neuron_gate)
- Calibrate .scale_ms so proxy ≈ real latency of a keep-all model.

Public API
----------
- LatencyProxyLLM(...).predict(model, batch_or_shape)   # trainer entry
- LatencyProxyLLM(...).predict(model, B=?, S=?, T=?)    # explicit entry
- LatencyProxyLLM(...).debug_layer_view(...)
- calibrate_proxy_llm(...), calibrate_proxy_llm_from_batch(...)
"""

# ------------------------------------------------------------
# Shared tiny utils (device/dtype-safe constants)
# ------------------------------------------------------------
def _find_gate_param_or_fallback(model: nn.Module) -> torch.Tensor:
    """
    Return a tensor to anchor device/dtype for proxy constants.
    Prefer gate logits; else any parameter; else CPU fp32 scalar.
    """
    for m in model.modules():
        if hasattr(m, "head_gate") and hasattr(getattr(m, "head_gate"), "logits"):
            return m.head_gate.logits
        if hasattr(m, "neuron_gate") and hasattr(m.neuron_gate, "logits"):
            return m.neuron_gate.logits
        if hasattr(m, "logits") and isinstance(getattr(m, "logits"), torch.Tensor):
            return m.logits
    for p in model.parameters():
        return p
    return torch.tensor(0.0)

def _as_const_like(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


# ------------------------------------------------------------
# Proxy
# ------------------------------------------------------------
@dataclass
class _WarmupOnlyPolicy:
    """Tiny policy shim so you can pass warmup_steps to .predict()."""
    warmup_steps: int = 0

class LatencyProxyLLM(LatencyProxy):
    """
    LLM latency proxy (ms ~ weighted FLOPs/bandwidth terms) for prefill + cached decode.
    Accepts either a batch or explicit B,S,T.
    """

    def __init__(
        self,
        *,
        scale_ms: float = 1.0,
        alpha_qkv: float = 1.0,
        alpha_scores: float = 1.0,
        alpha_out: float = 1.0,
        alpha_mlp: float = 1.0,
        gate_kv_in_proxy: bool = False,
        default_T: int = 128,
    ):
        super().__init__()
        self.scale_ms = float(scale_ms)
        self.alpha_qkv = float(alpha_qkv)
        self.alpha_scores = float(alpha_scores)
        self.alpha_out = float(alpha_out)
        self.alpha_mlp = float(alpha_mlp)
        self.gate_kv_in_proxy = bool(gate_kv_in_proxy)
        self.default_T = int(default_T)

    # ---------- gate discovery ----------
    @staticmethod
    def _soft_heads_from_block_llm(blk) -> Optional[torch.Tensor]:
        attn = getattr(blk, "self_attn", None)
        if attn is None:
            return None
        if hasattr(attn, "kept_heads_soft") and callable(attn.kept_heads_soft):
            return attn.kept_heads_soft()
        logits, tau = None, None
        if hasattr(attn, "head_gate") and hasattr(attn.head_gate, "logits"):
            logits = attn.head_gate.logits
            tau = float(getattr(attn.head_gate, "tau", getattr(attn, "tau", 1.5)))
        elif hasattr(attn, "logits"):
            logits = attn.logits
            tau = float(getattr(attn, "tau", 1.5))
        if logits is None:
            return None
        return torch.sigmoid(logits / tau).sum()

    @staticmethod
    def _find_ffn_gate_llm(blk):
        mlp = getattr(blk, "mlp", None)
        g = getattr(mlp, "neuron_gate", None) if mlp is not None else None
        if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
            return g
        return None

    def _soft_hidden_from_block_llm(self, blk, default_hidden, anchor, warm=False):
        if warm:
            return default_hidden
        g = self._find_ffn_gate_llm(blk)
        if g is None:
            return default_hidden
        probs = torch.sigmoid(g.logits / float(g.tau))  # [#groups]
        group = int(getattr(g, "group", getattr(g, "group_size", 128)))
        kept_hidden = probs.sum() * _as_const_like(anchor, group)
        return kept_hidden

    # ---------- main ----------
    def predict(  # trainer entry and explicit-shape entry unified
        self,
        model: nn.Module,
        sample: Optional[TensorOrBatch] = None,
        *,
        B: Optional[int] = None,
        S: Optional[int] = None,
        T: Optional[int] = None,
        policy: Optional[object] = None,
        step: Optional[int] = None,
        return_terms: bool = False,
    ):
        # Allow explicit B,S,(T) path
        if B is not None and S is not None:
            ids_B, ids_S = int(B), int(S)
            ids_T = int(T) if T is not None else int(self.default_T)
        else:
            if sample is None:
                raise ValueError("LatencyProxyLLM.predict needs either a batch sample or explicit B,S.")
            if isinstance(sample, (tuple, list)) and len(sample) in (2, 3) and all(isinstance(x, int) for x in sample):
                # explicit (B,S) or (B,S,T)
                ids_B, ids_S = int(sample[0]), int(sample[1])
                ids_T = int(sample[2]) if len(sample) == 3 else int(self.default_T)
            else:
                ids = _ids_from_batch(sample)
                ids_B, ids_S = int(ids.size(0)), int(ids.size(1))
                ids_T = int(self.default_T) if T is None else int(T)

        anchor = _find_gate_param_or_fallback(model)

        # scalar tensors (same device/dtype)
        B_t = _as_const_like(anchor, ids_B)
        S_t = _as_const_like(anchor, ids_S)
        T_t = _as_const_like(anchor, ids_T)

        cfg = model.config
        D = _as_const_like(anchor, int(cfg.hidden_size))
        Hh = _as_const_like(anchor, int(cfg.num_attention_heads))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hh))))
        Dh = D // Hh

        warmup_steps = int(getattr(policy, "warmup_steps", 0)) if policy is not None else 0
        warm = bool(step is not None and step < warmup_steps)

        total_qkv = anchor.new_zeros(())
        total_scores = anchor.new_zeros(())
        total_out = anchor.new_zeros(())
        total_mlp = anchor.new_zeros(())

        default_hidden = _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        layers = getattr(getattr(model, "model", model), "layers", [])
        for blk in layers:
            heads_soft = Hh if warm else (self._soft_heads_from_block_llm(blk) or Hh)
            Dq = heads_soft * Dh
            # K/V effective width
            if self.gate_kv_in_proxy:
                Dkv = heads_soft * Dh
            else:
                Dkv = Hkv * Dh
            hidden_soft = self._soft_hidden_from_block_llm(blk, default_hidden, anchor, warm=warm)

            # Prefill + decode (simplified aggregation)
            Seff = S_t + T_t

            # q/k/v linear FLOP-like terms
            total_qkv = total_qkv + (
                # q
                B_t * Seff * D * Dq +
                # k + v
                2 * B_t * Seff * D * Dkv
            )
            # attention scores (prefill SxS + decode triangular)
            total_scores = total_scores + (
                B_t * (S_t * S_t) * heads_soft * Dh +
                B_t * heads_soft * Dh * (T_t * S_t + (T_t * (T_t + 1)) // 2)
            )
            # out proj
            total_out = total_out + B_t * Seff * Dq * D
            # mlp
            total_mlp = total_mlp + B_t * Seff * 2 * D * hidden_soft

        flops_like = (
            self.alpha_qkv * total_qkv
            + self.alpha_scores * total_scores
            + self.alpha_out * total_out
            + self.alpha_mlp * total_mlp
        )

        ms = flops_like * _as_const_like(anchor, self.scale_ms)
        if return_terms:
            return ms, {
                "qkv": float((self.alpha_qkv * total_qkv).detach().cpu()),
                "scores": float((self.alpha_scores * total_scores).detach().cpu()),
                "out": float((self.alpha_out * total_out).detach().cpu()),
                "mlp": float((self.alpha_mlp * total_mlp).detach().cpu()),
            }
        return ms

    # ---------- per-layer debug ----------
    @torch.no_grad()
    def debug_layer_view(
        self,
        model: nn.Module,
        *,
        B: int,
        S: int,
        T: int,
        policy: Optional[object] = None,
        step: Optional[int] = None,
    ) -> list:
        anchor = _find_gate_param_or_fallback(model)
        cfg = getattr(model, "config", None)
        D = _as_const_like(anchor, int(getattr(cfg, "hidden_size", 0)))
        Hq = _as_const_like(anchor, int(getattr(cfg, "num_attention_heads", 0)))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hq))))
        Dh = D // Hq

        warm = False
        if policy is not None and step is not None:
            warm = (int(step) < int(getattr(policy, "warmup_steps", 0)))

        rows = []
        layers = getattr(getattr(model, "model", model), "layers", None) or []
        for i, blk in enumerate(layers):
            heads_soft = Hq if warm else (self._soft_heads_from_block_llm(blk) or Hq)
            Dq = heads_soft * Dh
            Dkv = (heads_soft * Dh) if self.gate_kv_in_proxy else (Hkv * Dh)
            hidden_soft = self._soft_hidden_from_block_llm(
                blk, _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D)))), anchor, warm=warm
            )
            rows.append({
                "layer": i,
                "heads_soft": float(heads_soft.detach().cpu()),
                "Dq≈heads*Dh": float(Dq.detach().cpu()),
                "Dkv_used": float(Dkv.detach().cpu()),
                "ffn_hidden_soft": float(hidden_soft.detach().cpu()),
            })
        return rows

|
| 716 |
+
# Calibration helpers for LLM
|
| 717 |
+
# ------------------------------------------------------------
|
| 718 |
+
@torch.inference_mode()
|
| 719 |
+
def calibrate_proxy_llm(
|
| 720 |
+
proxy: LatencyProxyLLM,
|
| 721 |
+
model: nn.Module,
|
| 722 |
+
*,
|
| 723 |
+
B: int,
|
| 724 |
+
S: int,
|
| 725 |
+
T: int,
|
| 726 |
+
export_keepall_fn,
|
| 727 |
+
device: str = "cuda",
|
| 728 |
+
warmup: int = 10,
|
| 729 |
+
iters: int = 30,
|
| 730 |
+
) -> float:
|
| 731 |
+
"""
|
| 732 |
+
Calibrate proxy.scale_ms so proxy.predict(...) matches real keep-all latency for (B,S,T).
|
| 733 |
+
Returns the measured real mean latency in ms.
|
| 734 |
+
"""
|
| 735 |
+
keepall = export_keepall_fn(model).to(device).eval()
|
| 736 |
+
|
| 737 |
+
# Measure real latency (prefill + decode)
|
| 738 |
+
from core.measure import measure_latency_text_ms as _measure # adjust if your path differs
|
| 739 |
+
real_ms, _ = _measure(keepall, B=B, S=S, T=T, warmup=warmup, iters=iters, device=device)
|
| 740 |
+
|
| 741 |
+
# Soft/proxy latency on *gated* model
|
| 742 |
+
ms_like = proxy.predict(model, B=B, S=S, T=T)
|
| 743 |
+
soft_ms = float(ms_like.detach().item()) if torch.is_tensor(ms_like) else float(ms_like)
|
| 744 |
+
|
| 745 |
+
proxy.scale_ms = float(real_ms / max(soft_ms, 1e-9))
|
| 746 |
+
return real_ms
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
@torch.inference_mode()
|
| 750 |
+
def calibrate_proxy_llm_from_batch(
|
| 751 |
+
proxy: LatencyProxyLLM,
|
| 752 |
+
model: nn.Module,
|
| 753 |
+
batch: Dict[str, torch.Tensor],
|
| 754 |
+
*,
|
| 755 |
+
T: int,
|
| 756 |
+
export_keepall_fn,
|
| 757 |
+
device: str = "cuda",
|
| 758 |
+
warmup: int = 10,
|
| 759 |
+
iters: int = 30,
|
| 760 |
+
) -> Tuple[int, int, int, float]:
|
| 761 |
+
"""
|
| 762 |
+
Infers (B,S) from a batch like {'input_ids': [B,S], ...},
|
| 763 |
+
calibrates for (B,S,T), and returns (B,S,T, real_ms).
|
| 764 |
+
"""
|
| 765 |
+
input_ids = batch["input_ids"]
|
| 766 |
+
B, S = int(input_ids.size(0)), int(input_ids.size(1))
|
| 767 |
+
ms = calibrate_proxy_llm(
|
| 768 |
+
proxy, model, B=B, S=S, T=T, export_keepall_fn=export_keepall_fn,
|
| 769 |
+
device=device, warmup=warmup, iters=iters
|
| 770 |
+
)
|
| 771 |
+
return B, S, T, ms
|
core/.ipynb_checkpoints/train-checkpoint.py
ADDED
@@ -0,0 +1,327 @@
"""Generic Lagrangian trainer (family-agnostic).

This module provides a light framework to optimize *gated* students against
teachers with a latency target enforced via a proxy + optional real probes.

It does not assume ViT/ResNet/LLM specifics; adapters provide tiny callables.

Key ingredients:
- Two-phase update per step: (A) weights w.r.t. KD/task, (B) gates w.r.t. KD +
  sparsity + latency penalty with a dual variable λ.
- Optional periodic export + real-latency probe to correct λ.
- Constraint projection for gates after each step.

Adapters must provide:
- get_student_logits(model, x) -> Tensor
- get_teacher_logits(model, x) -> Tensor
- export_keepall(model) -> nn.Module (clean copy without gates)
- export_pruned(model, policy, step) -> nn.Module (transient copy for profiling)

Core modules used:
- `distill.KDConfig`, `distill.kd_loss`
- `gates.combined_penalty`, `gates.PenaltyWeights`, `gates.project_gates_into_constraints`
- `proxy_cost.LatencyProxy`
- `profiler.measure_latency_ms`
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Optional
import gc

import torch
import torch.nn as nn

from .distill import KDConfig, kd_loss, mse_reg
from .gates import PenaltyWeights, Constraints, combined_penalty, project_gates_into_constraints, collect_param_groups
from .proxy_cost import LatencyProxy
from .profiler import measure_latency_ms

# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------

@dataclass
class DualConfig:
    lr: float = 0.05       # step size for the λ update
    ema_beta: float = 0.5  # blend between real-probe λ and proxy-driven λ
    clip: float = 10.0


@dataclass
class TrainerConfig:
    # default_factory keeps each TrainerConfig's sub-configs independent (and is
    # required for unhashable dataclass defaults on Python >= 3.11)
    kd: KDConfig = field(default_factory=KDConfig)
    penalties: PenaltyWeights = field(default_factory=lambda: PenaltyWeights(l0=0.0, keep_floor_ratio=0.0, bimodality=0.0))
    constraints: Constraints = field(default_factory=lambda: Constraints(min_keep_ratio=0.0, min_groups=1, max_groups_drop=None))

    latency_target_ms: float = 30.0
    real_probe_every: int = 0  # steps; 0 disables real probes
    probe_batch_override: Optional[int] = None
    gate_warmup_steps: int = 0  # freeze gates for early steps
    mse_weight: float = 0.0

    early_stopping_patience: int = 0
    early_stopping_lambda: float = 1e-4

    amp: bool = True
    device: str = "cuda"

    # Optimizers
    lr_gate: float = 1e-2
    lr_linear: float = 1e-4
    lr_affine: float = 3e-4
    wd_linear: float = 1e-4

    # Mixed precision scaler
    use_grad_scaler: bool = True

    # Dual update
    dual: DualConfig = field(default_factory=DualConfig)


# -----------------------------------------------------------------------------
# Trainer
# -----------------------------------------------------------------------------

class LagrangeTrainer:
    def __init__(
        self,
        student: nn.Module,
        teacher: nn.Module,
        proxy: LatencyProxy,
        *,
        adapter_get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
        adapter_get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
        adapter_export_keepall: Callable[[nn.Module], nn.Module],
        adapter_export_pruned: Callable[[nn.Module, object, int], nn.Module],
        export_policy: object,
        cfg: TrainerConfig,
    ) -> None:
        self.student = student
        self.teacher = teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad_(False)
        self.proxy = proxy
        self.get_s = adapter_get_student_logits
        self.get_t = adapter_get_teacher_logits
        self.export_keepall = adapter_export_keepall
        self.export_pruned = adapter_export_pruned
        self.export_policy = export_policy
        self.cfg = cfg

        # Build optimizers (grouped)
        param_groups = collect_param_groups(
            student,
            lr_gate=cfg.lr_gate,
            lr_linear=cfg.lr_linear,
            lr_affine=cfg.lr_affine,
            wd_linear=cfg.wd_linear,
        )
        # gates-only optimizer uses the first group
        self.opt_g = torch.optim.Adam([param_groups[0]], lr=param_groups[0]["lr"])  # type: ignore[arg-type]
        # weights optimizer for the rest
        self.opt_w = torch.optim.Adam(param_groups[1:])

        self.scaler = torch.amp.GradScaler('cuda', enabled=(cfg.amp and cfg.use_grad_scaler))
        self.lambda_: float = 0.0
        self.mse_weight = cfg.mse_weight

    # ---- internal helpers -----------------------------------------------------
    def _zero_grads(self, params):
        for p in params:
            if p.grad is not None:
                p.grad = None

    def _has_grad(self, params) -> bool:
        for p in params:
            if p.grad is not None:
                return True
        return False

    # ---- training -------------------------------------------------------------
    def train_epoch(self, loader, *, real_policy=None, verbose_every: int = 50):
        device = self.cfg.device
        self.student.train().to(device)
        self.teacher.to(device).eval()

        running = 0.0
        seen = 0
        lam_real = self.lambda_
        # Real-probe stats default to NaN so logging works before the first probe
        mean_ms = p95_ms = float("nan")

        total_steps = len(loader)

        for step, batch in enumerate(loader, 1):
            # Move batch to device in a type-safe way
            batch = _move_batch_to_device(batch, device)

            with torch.no_grad():
                t_logits = self.get_t(self.teacher, batch)  # [B,1,V]
                # match AMP compute dtype to avoid upcasting later
                if self.cfg.amp:
                    # infer autocast dtype from student params (bf16 or fp16)
                    sparam = next(self.student.parameters())
                    t_logits = t_logits.to(dtype=sparam.dtype, non_blocking=True)

            # -------- Pass A: WEIGHTS (KD only) --------
            self.opt_w.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=self.cfg.amp):
                # Adapters receive the batch object (dict/tuple/tensor)
                s_logits = self.get_s(self.student, batch)
                mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
                loss_w = kd_loss(s_logits, t_logits, self.cfg.kd) + mse

            self.scaler.scale(loss_w).backward()
            # Prevent gate params from changing in pass A
            gate_params = self.opt_g.param_groups[0]["params"]
            self._zero_grads(gate_params)

            if any(p.grad is not None for pg in self.opt_w.param_groups for p in pg["params"]):
                self.scaler.step(self.opt_w)
                self.scaler.update()
            else:
                self.opt_w.zero_grad(set_to_none=True)

            del s_logits
            gc.collect()
            torch.cuda.empty_cache()

            if step > int(self.cfg.gate_warmup_steps):
                # -------- Pass B: GATES (KD + sparsity + λ * gap) --------
                self.opt_g.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=self.cfg.amp):
                    s_logits = self.get_s(self.student, batch)
                    kd_g = kd_loss(s_logits, t_logits, self.cfg.kd)

                    # Proxy gets the batch object too; family-specific proxy can read (B,S) etc.
                    o1_ms = self.proxy.predict(self.student, batch)
                    gap = torch.relu(o1_ms - float(self.cfg.latency_target_ms))
                    reg = combined_penalty(self.student, self.cfg.penalties)
                    mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
                    loss_g = kd_g + _to_tensor(self.lambda_, o1_ms) * gap + reg + mse

                self.scaler.scale(loss_g).backward()
                # Prevent non-gate params from changing in pass B
                for pg in self.opt_w.param_groups:
                    self._zero_grads(pg["params"])

                if self._has_grad(self.opt_g.param_groups[0]["params"]):
                    self.scaler.step(self.opt_g)
                    self.scaler.update()
                else:
                    self.opt_g.zero_grad(set_to_none=True)
            else:
                o1_ms = self.proxy.predict(self.student, batch)
                s_logits = loss_g = kd_g = reg = torch.tensor(0.0, device=device)

            # -------- Dual (λ) update using proxy --------
            with torch.no_grad():
                lam_proxy = max(0.0, self.lambda_ + self.cfg.dual.lr * (float(o1_ms.detach()) - self.cfg.latency_target_ms))
                # Blend real-probe and proxy-driven λ (the default ema_beta=0.5
                # reproduces an even blend), then honor the configured clip.
                beta = self.cfg.dual.ema_beta
                self.lambda_ = min(beta * lam_real + (1.0 - beta) * lam_proxy, self.cfg.dual.clip)

            # -------- Constraint projection, optional real probe --------
            project_gates_into_constraints(self.student, self.cfg.constraints)

            if self.cfg.real_probe_every and (step % int(self.cfg.real_probe_every) == 0):
                # Build a probe shape for the latency function if needed
                try:
                    from core.measure import measure_latency_text_ms  # text-friendly
                    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
                        B, S = int(batch["input_ids"].size(0)), int(batch["input_ids"].size(1))
                    else:
                        # Fallback: try tensor-like batch
                        x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
                        B = int(x0.size(0)); S = int(x0.size(1))
                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
                    mean_ms, p95_ms = measure_latency_text_ms(slim, B=B, S=S, T=128, device=device)
                except Exception:
                    # If the project has a different profiler, retain compatibility:
                    from .profiler import measure_latency_ms
                    x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
                    shape = (int(x0.size(0)), *list(x0.shape[1:]))
                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
                    mean_ms, p95_ms = measure_latency_ms(slim, shape, device=device)

                with torch.no_grad():
                    lam_real = max(0.0, self.lambda_ + self.cfg.dual.lr * (mean_ms - self.cfg.latency_target_ms))

                # scale_correction = mean_ms / max(1e-9, o1_ms.detach())
                # self.proxy.cfg.scale_ms = 0.9 * self.proxy.cfg.scale_ms + 0.1 * scale_correction * self.proxy.cfg.scale_ms

            if (step % verbose_every) == 0:
                print(
                    f"Step {step}/{total_steps} | KL={float(loss_w.item()):.6f} | MSE={float(mse.item()):.6f} | "
                    f"Gate={float(loss_g.item()):.6f} | "
                    f"proxy={float(o1_ms.detach()):.3f}ms | real_mean={mean_ms:.3f}ms p95={p95_ms:.3f}ms | λ={self.lambda_:.6f}"
                )

            running += float(loss_g.detach())
            seen += _batch_size(batch)

            del s_logits, t_logits, o1_ms, kd_g, reg, loss_g, loss_w
            torch.cuda.empty_cache()
            gc.collect()

        print(f"Epoch loss {running / max(1, seen):.6f}")
        return self.lambda_


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _to_tensor(val: float, like: torch.Tensor) -> torch.Tensor:
    return torch.as_tensor(val, device=like.device, dtype=like.dtype)

def _move_batch_to_device(batch, device: str):
    """
    Supports:
      - dict with keys 'input_ids' and optional 'attention_mask'
      - (x,) or (x, y) tuples/lists -> move each tensor-like to device
      - single Tensor
    Converts attention_mask to bool (preferred by HF SDPA).
    """
    if isinstance(batch, dict):
        out = {}
        for k, v in batch.items():
            if torch.is_tensor(v):
                v = v.to(device, non_blocking=True)
                if k == "attention_mask" and v.dtype != torch.bool:
                    v = v.to(torch.bool)
            out[k] = v
        return out

    if isinstance(batch, (tuple, list)):
        moved = []
        for v in batch:
            if torch.is_tensor(v):
                v = v.to(device, non_blocking=True)
            moved.append(v)
        return type(batch)(moved)

    if torch.is_tensor(batch):
        return batch.to(device, non_blocking=True)

    # Unknown type: return as-is (adapters/proxy should handle it)
    return batch


def _batch_size(batch) -> int:
    """Best-effort batch size for logging/averages."""
    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
        return int(batch["input_ids"].size(0))
    if torch.is_tensor(batch):
        return int(batch.size(0))
    if isinstance(batch, (tuple, list)) and len(batch) and torch.is_tensor(batch[0]):
        return int(batch[0].size(0))
    return 1
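A usage sketch of the trainer. Only `LagrangeTrainer` and `TrainerConfig` come from this module; `student`, `teacher`, `proxy`, the four adapter callables, `export_policy`, and `train_loader` are placeholders for adapter-provided objects:

```python
# Hypothetical setup: adapters supply the callables described in the module
# docstring; `proxy` is a calibrated LatencyProxy instance.
cfg = TrainerConfig(latency_target_ms=25.0, real_probe_every=200, gate_warmup_steps=100)
trainer = LagrangeTrainer(
    student, teacher, proxy,
    adapter_get_student_logits=get_student_logits,
    adapter_get_teacher_logits=get_teacher_logits,
    adapter_export_keepall=export_keepall,
    adapter_export_pruned=export_pruned,
    export_policy=export_policy,
    cfg=cfg,
)
for epoch in range(3):
    lam = trainer.train_epoch(train_loader, verbose_every=50)
    print(f"epoch {epoch}: λ={lam:.4f}")
```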
core/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,190 @@
"""Shared utilities used across core and adapters.

Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
parameter grouping, model copying, etc.). Keep this file dependency-light.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional, Sequence, Tuple

import copy
import random

import numpy as np
import torch
import torch.nn as nn


# -----------------------------------------------------------------------------
# Device / dtype helpers
# -----------------------------------------------------------------------------

def as_like(x: torch.Tensor, val) -> torch.Tensor:
    """Create a scalar/tensor constant on the same device/dtype as `x`."""
    return torch.as_tensor(val, device=x.device, dtype=x.dtype)


def first_param(module: nn.Module) -> torch.Tensor:
    for p in module.parameters(recurse=True):
        return p
    return torch.tensor(0.0)


def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    return x.to(device=ref.device, dtype=ref.dtype)


# -----------------------------------------------------------------------------
# Seeding & determinism
# -----------------------------------------------------------------------------

def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# -----------------------------------------------------------------------------
# Model parameter helpers
# -----------------------------------------------------------------------------

def freeze(module: nn.Module) -> None:
    for p in module.parameters():
        p.requires_grad_(False)


def unfreeze(module: nn.Module) -> None:
    for p in module.parameters():
        p.requires_grad_(True)


def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
    if trainable_only:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    return sum(p.numel() for p in module.parameters())


# -----------------------------------------------------------------------------
# Shape/signature helpers
# -----------------------------------------------------------------------------

def input_spec_vision(sample) -> Tuple[int, int, int]:
    """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
    if isinstance(sample, torch.Tensor):
        B, C, H, W = sample.shape
        return int(B), int(H), int(W)
    if isinstance(sample, (tuple, list)) and len(sample) == 4:
        B, C, H, W = sample
        return int(B), int(H), int(W)
    raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")


# -----------------------------------------------------------------------------
# Rounding / multiples
# -----------------------------------------------------------------------------

def round_down_multiple(n: int, m: int) -> int:
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)


def clamp_int(v: int, lo: int, hi: int) -> int:
    return max(lo, min(int(v), hi))


# -----------------------------------------------------------------------------
# Slicing helpers
# -----------------------------------------------------------------------------

@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None
    if keep_out is not None:
        idx_out = torch.as_tensor(keep_out, device=W.device)
        W = W.index_select(0, idx_out)
        if b is not None:
            b = b.index_select(0, idx_out)
    if keep_in is not None:
        idx_in = torch.as_tensor(keep_in, device=W.device)
        W = W.index_select(1, idx_in)
    out_f, in_f = W.shape
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new


# -----------------------------------------------------------------------------
# Copying & detaching models
# -----------------------------------------------------------------------------

def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
    m = copy.deepcopy(module).cpu().eval()
    return m


# -----------------------------------------------------------------------------
# Gradient utilities
# -----------------------------------------------------------------------------

def zero_if_any(params: Iterable[torch.Tensor]) -> None:
    for p in params:
        if p.grad is not None:
            p.grad = None


def any_grad(params: Iterable[torch.Tensor]) -> bool:
    for p in params:
        if p.grad is not None:
            return True
    return False


# -----------------------------------------------------------------------------
# For fine-tuning
# -----------------------------------------------------------------------------

def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
    """
    Rebuild all parameters as fresh nn.Parameter tensors (detach+clone),
    which drops any 'inference tensor' tag and re-enables autograd.
    """
    for mod in module.modules():
        for name, p in list(mod._parameters.items()):
            if p is None:
                continue
            new_p = nn.Parameter(p.detach().clone(), requires_grad=requires_grad)
            setattr(mod, name, new_p)
    return module


# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------

@dataclass
class ExportRounding:
    head_floor_post: int = 1
    head_multiple_post: int = 1
    ffn_min_keep_ratio_post: float = 0.0
    ffn_snap_groups_post: int = 1


def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
    B, C, H, W = sample_shape
    patch = getattr(cfg, "patch_size", 16)
    return (
        "ViT",
        sample_shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        tuple(patch) if isinstance(patch, (tuple, list)) else int(patch),
    )
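Two self-contained examples of the helpers above (values chosen arbitrarily); note that `round_down_multiple` floors to a multiple of `m` but never returns less than `m`:

```python
import torch.nn as nn

# Floor-to-multiple behavior:
assert round_down_multiple(13, 4) == 12  # 13 snaps down to 12
assert round_down_multiple(3, 4) == 4    # floored at m itself
assert round_down_multiple(13, 1) == 13  # m <= 1 is a no-op (min 1)

# slice_linear keeps rows/columns by index: here an 8->16 layer becomes 4->6.
lin = nn.Linear(8, 16)
small = slice_linear(lin, keep_in=list(range(4)), keep_out=list(range(6)))
assert small.weight.shape == (6, 4)
```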
core/__init__.py
ADDED
File without changes
core/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (127 Bytes)
core/__pycache__/distill.cpython-310.pyc
ADDED
Binary file (6.94 kB)
core/__pycache__/export.cpython-310.pyc
ADDED
Binary file (7.31 kB)
core/__pycache__/finetune.cpython-310.pyc
ADDED
Binary file (7.35 kB)
core/__pycache__/gates.cpython-310.pyc
ADDED
Binary file (13.6 kB)
core/__pycache__/profiler.cpython-310.pyc
ADDED
Binary file (7.68 kB)
core/__pycache__/proxy_cost.cpython-310.pyc
ADDED
Binary file (22.8 kB)
core/__pycache__/search_export.cpython-310.pyc
ADDED
Binary file (2.95 kB)
core/__pycache__/train.cpython-310.pyc
ADDED
Binary file (9.12 kB)
core/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (5.98 kB)
core/distill.py
ADDED
@@ -0,0 +1,183 @@
"""Knowledge-distillation utilities (model-family agnostic).

This module provides:
- Losses: KL distillation, soft cross-entropy, cosine feature loss
- Helper to obtain logits from models with/without built-in heads
- Lightweight classification head for backbone models (e.g., ViTModel)
- Simple evaluators (agreement %, KL) and diagnostics

Adapters may override `adapter_get_logits(model, x)` if a family needs a
custom extraction (e.g., language models with past_key_values).
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Protocol, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------

@dataclass
class KDConfig:
    temperature: float = 2.0
    alpha: float = 1.0  # multiplier for the KL term; task loss is handled outside


# -----------------------------------------------------------------------------
# Losses
# -----------------------------------------------------------------------------

def kl_divergence(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    """Batchmean KL(teacher_T || student_T) at temperature T, scaled by T^2 (Hinton-style)."""
    p_s = F.log_softmax(student_logits / T, dim=-1)
    p_t = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)

def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, cfg: KDConfig) -> torch.Tensor:
    return cfg.alpha * kl_divergence(student_logits, teacher_logits, T=cfg.temperature)

def mse_reg(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    mse = F.mse_loss(student_logits, teacher_logits, reduction="mean")
    return mse * (T * T)

def soft_ce(student_logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
    """Soft cross-entropy: expects `soft_targets` already normalized."""
    logp = F.log_softmax(student_logits, dim=-1)
    return -(soft_targets * logp).sum(dim=-1).mean()

def cosine_feature_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
    """1 - cosine similarity averaged over batch and time/patch dims."""
    s = F.normalize(student_feats, dim=-1)
    t = F.normalize(teacher_feats, dim=-1)
    return (1.0 - (s * t).sum(dim=-1)).mean()


# -----------------------------------------------------------------------------
# Logit extraction
# -----------------------------------------------------------------------------

class LogitsProvider(Protocol):
    def __call__(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor: ...


class ClsHead(nn.Module):
    """Minimal classification head: LN + Linear.

    Useful when the backbone outputs hidden states (e.g., ViTModel) and you
    want logits comparable to a teacher with a classification head.
    """

    def __init__(self, hidden_size: int, num_classes: int = 1000, base_head: Optional[nn.Module] = None):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)
        if base_head is not None:
            # Try to load weights if shapes match (e.g., from an HF classifier)
            try:
                self.load_state_dict(base_head.state_dict(), strict=False)
            except Exception:
                pass

    def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
        return self.fc(self.norm(cls_token))


@torch.no_grad()
def infer_hidden_size(model: nn.Module, sample: torch.Tensor) -> int:
    # Run a tiny forward to inspect hidden size when unknown
    model.eval()
    out = model(pixel_values=sample)
    if hasattr(out, "last_hidden_state"):
        return int(out.last_hidden_state.shape[-1])
    if hasattr(out, "logits"):
        return int(out.logits.shape[-1])
    raise RuntimeError("Cannot infer hidden size; provide explicitly.")


def default_get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
    """Family-agnostic logits extractor.

    - If model output has `.logits`, return it.
    - Else expects `.last_hidden_state` and uses [CLS] via provided `head`.
    """
    out = model(pixel_values=x)
    if hasattr(out, "logits"):
        return out.logits
    if hasattr(out, "last_hidden_state"):
        if head is None:
            raise ValueError("Backbone returned hidden states; supply a classification head.")
        cls_tok = out.last_hidden_state[:, 0, :]
        return head(cls_tok)
    raise ValueError("Model output lacks logits and last_hidden_state.")


# -----------------------------------------------------------------------------
# Evaluators & diagnostics
# -----------------------------------------------------------------------------

@torch.inference_mode()
def logits_std(model: nn.Module, loader, *, get_logits: LogitsProvider, batches: int = 10, device: str = "cuda") -> Tuple[float, int]:
    s = 0.0
    k = 0
    for x in loader:
        if k >= batches:
            break
        x = x.to(device)
        y = get_logits(model, x)
        s += y.std().item()
        k += 1
    return (s / max(1, k), k)


@torch.inference_mode()
def agreement_metrics(
    student: nn.Module,
    teacher: nn.Module,
    loader,
    *,
    get_student_logits: LogitsProvider,
    get_teacher_logits: LogitsProvider,
    batches: int = 20,
    T: float = 1.0,
    device: str = "cuda",
) -> dict:
    kl_sum = 0.0
    n = 0
    top1 = 0
    tot = 0
    for i, x in enumerate(loader):
        if i >= batches:
            break
        x = x.to(device)
        t = get_teacher_logits(teacher, x)
        s = get_student_logits(student, x)
        p_s = F.log_softmax(s / T, dim=-1)
        p_t = F.softmax(t / T, dim=-1)
        kl_sum += (F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)).item()
        top1 += (s.argmax(-1) == t.argmax(-1)).sum().item()
        tot += x.size(0)
        n += 1
    return {"kl_TT": kl_sum / max(1, n), "top1_agreement": top1 / max(1, tot)}


# -----------------------------------------------------------------------------
# Small trainer helpers
# -----------------------------------------------------------------------------

class DualEMA:
    """Simple exponential moving average for a scalar (e.g., lambda or latency)."""

    def __init__(self, beta: float = 0.9, value: float = 0.0):
        self.beta = float(beta)
        self.value = float(value)

    def update(self, x: float) -> float:
        self.value = self.beta * self.value + (1 - self.beta) * float(x)
        return self.value
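A quick sanity check of the losses on random logits (shapes are arbitrary):

```python
import torch

s = torch.randn(4, 1000)          # student logits
t = torch.randn(4, 1000)          # teacher logits
cfg = KDConfig(temperature=2.0, alpha=1.0)

print(float(kd_loss(s, t, cfg)))  # KL at T=2, scaled by T^2
print(float(mse_reg(s, t, T=2.0)))
print(float(kd_loss(s, s, cfg)))  # ~0 when student matches teacher
```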
core/export.py
ADDED
@@ -0,0 +1,220 @@
"""Core export utilities for hard-pruning and kernel-aligned rounding.

This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should:
1) decide which gates map to which structural dims (heads, hidden groups, channels),
2) obtain KEEP indices using helpers in this file, and
3) rebuild family-specific modules with the sliced weights.

Provided here:
- Rounding policies and helpers (floors, multiples, warmup keep-all)
- KEEP index selection from a `Gate` (or gate-like) object
- Generic weight slicers for Linear / Conv2d / Embedding
- Small safeguards for dtype/device and shape checks

The library avoids touching family internals here. Exporters in adapters should
use these primitives to assemble a clean pruned model.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from .gates import Gate, expand_group_indices

# -----------------------------------------------------------------------------
# Policies & rounding
# -----------------------------------------------------------------------------

@dataclass
class Rounding:
    """Rounding policy for a single gated axis.

    Attributes
    ----------
    floor_groups : int
        Minimum number of groups to keep after rounding.
    multiple_groups : int
        Snap the number of groups kept down to a multiple of this (>=1).
    min_keep_ratio : float
        Optional fractional lower bound on expected keep; applied before rounding.
    """

    floor_groups: int = 1
    multiple_groups: int = 1
    min_keep_ratio: float = 0.0


@dataclass
class ExportPolicy:
    """Export-time policy shared by families.

    - `warmup_steps`: if current `step < warmup_steps`, keep-all.
    - `rounding`: default rounding used unless adapter overrides per-axis.
    """

    warmup_steps: int = 0
    # default_factory keeps each policy's Rounding independent (and is required
    # for unhashable dataclass defaults on Python >= 3.11)
    rounding: Rounding = field(default_factory=Rounding)


def _round_down_mult(n: int, m: int) -> int:
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)


def _compute_keep_k(
    expected_kept: float,
    total_groups: int,
    *,
    rounding: Rounding,
) -> int:
    # Start from nearest-integer expectation
    k = int(round(expected_kept))
    # Apply ratio floor, then absolute floor, then multiple snapping
    k = max(k, int(rounding.min_keep_ratio * total_groups))
    k = max(k, int(rounding.floor_groups))
    k = min(k, total_groups)
    k = _round_down_mult(k, int(rounding.multiple_groups))
    return max(1, min(k, total_groups))


# -----------------------------------------------------------------------------
# KEEP index selection from a gate
# -----------------------------------------------------------------------------

@torch.no_grad()
def keep_group_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Return sorted indices of groups to KEEP based on `gate` and policy.

    If `step < warmup_steps`, returns all indices (keep-all). Otherwise, the
    number of groups to keep is computed from the *expected keep* under the
    current logits and snapped according to the rounding policy.
    """
    G = int(gate.num_groups)
    if step is not None and step < int(policy.warmup_steps):
        return torch.arange(G, device=gate.logits.device)

    rounding = custom_rounding or policy.rounding
    p = torch.sigmoid(gate.logits.detach().float() / float(gate.tau))
    k = _compute_keep_k(expected_kept=float(p.sum()), total_groups=G, rounding=rounding)
    idx = torch.topk(p, k, largest=True).indices.sort().values
    return idx.to(torch.long)


@torch.no_grad()
def keep_element_indices_from_gate(
    gate: Gate,
    *,
    policy: ExportPolicy,
    step: Optional[int] = None,
    custom_rounding: Optional[Rounding] = None,
) -> torch.Tensor:
    """Expand kept *group* indices into element indices using `group_size`."""
    grp_idx = keep_group_indices_from_gate(gate, policy=policy, step=step, custom_rounding=custom_rounding)
    return expand_group_indices(grp_idx, gate.group_size)


# -----------------------------------------------------------------------------
# Generic slicers
# -----------------------------------------------------------------------------

@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    """Create a new Linear with selected input/output features preserved.

    - `keep_out` selects rows (output features)
    - `keep_in` selects columns (input features)
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None

    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))

    out_f, in_f = W.shape
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new


@torch.no_grad()
def slice_conv2d(conv: nn.Conv2d, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Conv2d:
    """Create a new Conv2d with selected in/out channels preserved.

    Only supports standard conv2d (no groups/depthwise changes). For grouped
    convs, the adapter should handle group alignment before calling this.
    """
    W = conv.weight.detach()
    b = conv.bias.detach() if conv.bias is not None else None

    if keep_out is not None:
        W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
        if b is not None:
            b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
    if keep_in is not None:
        W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))

    out_c, in_c = W.shape[:2]
    new = nn.Conv2d(
        in_c,
        out_c,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=1,
        bias=(b is not None),
        padding_mode=conv.padding_mode,
    ).to(W.device)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new


@torch.no_grad()
def slice_embedding(emb: nn.Embedding, keep_rows: Optional[Sequence[int]] = None, keep_dim: Optional[Sequence[int]] = None) -> nn.Embedding:
    """Create a new Embedding with selected rows (vocab) and/or dims kept."""
    W = emb.weight.detach()
    if keep_rows is not None:
        W = W.index_select(0, torch.as_tensor(keep_rows, device=W.device))
    if keep_dim is not None:
        W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))
    num, dim = W.shape
    new = nn.Embedding(
        num,
        dim,
        padding_idx=emb.padding_idx,
        max_norm=emb.max_norm,
        norm_type=emb.norm_type,
        scale_grad_by_freq=emb.scale_grad_by_freq,
        sparse=emb.sparse,
        device=W.device,
        dtype=W.dtype,
    )
    new.weight.copy_(W)
    return new


# -----------------------------------------------------------------------------
# Small helpers for adapters
# -----------------------------------------------------------------------------

@torch.no_grad()
def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
    """Given [(start, end_exclusive), ...], return concatenated 1D indices."""
    parts = [torch.arange(a, b, dtype=torch.long) for a, b in ranges if b > a]
    return torch.cat(parts, dim=0) if parts else torch.empty(0, dtype=torch.long)


@torch.no_grad()
def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
    """Convert sorted group ids to expanded feature indices."""
    groups = torch.as_tensor(groups, dtype=torch.long)
    return expand_group_indices(groups, int(group_size))
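To illustrate the rounding arithmetic in `_compute_keep_k` (numbers are arbitrary): with 12 groups, an expected keep of 7.6, `min_keep_ratio=0.5`, `floor_groups=2`, and `multiple_groups=4`, the expectation rounds to 8, both floors leave it at 8, and snapping down to a multiple of 4 keeps 8 groups. A minimal sketch:

```python
k = _compute_keep_k(
    expected_kept=7.6,
    total_groups=12,
    rounding=Rounding(floor_groups=2, multiple_groups=4, min_keep_ratio=0.5),
)
assert k == 8

# With a tiny expectation the ratio floor dominates: round(1.2)=1 -> max(1, 6)=6,
# then snapping down to a multiple of 4 keeps 4 groups.
k = _compute_keep_k(
    expected_kept=1.2,
    total_groups=12,
    rounding=Rounding(floor_groups=2, multiple_groups=4, min_keep_ratio=0.5),
)
assert k == 4
```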
core/finetune.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# core/finetune.py
|
| 2 |
+
"""Post-pruning fine-tuning utilities (distillation)."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Callable, Optional, Tuple, Iterable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from core.distill import KDConfig, kd_loss, mse_reg
|
| 12 |
+
from core.utils import ensure_trainable_parameters
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class FinetuneConfig:
|
| 19 |
+
epochs: int = 5
|
| 20 |
+
lr: float = 3e-4
|
| 21 |
+
wd: float = 0.0
|
| 22 |
+
kd: KDConfig = KDConfig(temperature=2.0, alpha=1.0)
|
| 23 |
+
amp: bool = True
|
| 24 |
+
# "auto" -> bf16 if supported else fp16; "bf16" | "fp16" | "off" also allowed
|
| 25 |
+
amp_dtype: str = "auto"
|
| 26 |
+
device: str = "cuda"
|
| 27 |
+
log_every: int = 200
|
| 28 |
+
# diagnostics
|
| 29 |
+
grad_check_every: int = 50
|
| 30 |
+
grad_warn_if_zero_steps: int = 2 # consecutive checks with zero grad -> warn
|
| 31 |
+
mse_weight: float = 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _autocast_and_scaler(amp: bool, amp_dtype: str) -> Tuple[torch.autocast, Optional[torch.amp.GradScaler], bool, str]:
|
| 35 |
+
"""
|
| 36 |
+
Returns (autocast_ctx, scaler_or_None, use_scaler_bool, amp_mode_str)
|
| 37 |
+
- BF16 -> autocast(bfloat16), NO GradScaler
|
| 38 |
+
- FP16 -> autocast(float16), GradScaler ENABLED
|
| 39 |
+
- OFF -> disabled autocast, NO GradScaler
|
| 40 |
+
"""
|
| 41 |
+
if not amp or amp_dtype == "off":
|
| 42 |
+
ctx = torch.amp.autocast(device_type="cuda", enabled=False)
|
| 43 |
+
return ctx, None, False, "OFF"
|
| 44 |
+
|
| 45 |
+
if amp_dtype == "auto":
|
| 46 |
+
use_bf16 = torch.cuda.is_bf16_supported()
|
| 47 |
+
elif amp_dtype == "bf16":
|
| 48 |
+
use_bf16 = True
|
| 49 |
+
elif amp_dtype == "fp16":
|
| 50 |
+
use_bf16 = False
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")
|
| 53 |
+
|
| 54 |
+
if use_bf16:
|
| 55 |
+
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
|
| 56 |
+
return ctx, None, False, "BF16"
|
| 57 |
+
else:
|
| 58 |
+
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True)
|
| 59 |
+
try:
|
| 60 |
+
scaler = torch.amp.GradScaler("cuda", enabled=True)
|
| 61 |
+
except TypeError:
|
| 62 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
| 63 |
+
return ctx, scaler, True, "FP16"
|
| 64 |
+
|
| 65 |
+
|
```python
def _images_from_batch(batch):
    if isinstance(batch, dict):
        # `x or y` is ambiguous on tensors; use an explicit None check
        x = batch.get("pixel_values")
        return x if x is not None else batch.get("input")
    if isinstance(batch, (tuple, list)):
        return batch[0]
    return batch


def _param_iter_trainable(model: nn.Module) -> Iterable[torch.nn.Parameter]:
    for p in model.parameters():
        if p.requires_grad:
            yield p


def _grad_norm_and_nonzero(params: Iterable[torch.nn.Parameter]) -> Tuple[float, int]:
    total_sq, nonzero = 0.0, 0
    for p in params:
        g = p.grad
        if g is None:
            continue
        if g.is_sparse:
            g = g.coalesce().values()
        gn = float(g.detach().norm().cpu())
        if gn > 0.0:
            nonzero += 1
        total_sq += gn * gn
    return (total_sq ** 0.5), nonzero


@torch.no_grad()
def recalibrate_bn_stats(model, loader, max_batches=200, device="cuda"):
    model.train()  # use training mode to update running stats
    seen = 0
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        if not torch.is_tensor(x):
            continue
        x = x.to(device, non_blocking=True)
        model(x)
        seen += x.size(0)
    return seen


def finetune_student(
    student: nn.Module,
    teacher: nn.Module,
    train_loader,
    *,
    get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
    get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
    cfg: FinetuneConfig = FinetuneConfig(),
    val_loader=None,
    on_step: Optional[Callable[[int, float], None]] = None,
    save_best=False,
) -> nn.Module:
    """Fine-tune a pruned student against a frozen teacher using KD."""
    dev = cfg.device
    student = student.to(dev)
    teacher = teacher.to(dev).eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    for p in student.parameters():
        p.requires_grad_(True)

    # Make sure we can actually train
    ensure_trainable_parameters(student, requires_grad=True)
    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    if trainable == 0:
        raise RuntimeError("No trainable parameters in student — cannot finetune.")

    opt = torch.optim.AdamW(
        _param_iter_trainable(student),
        lr=cfg.lr,
        weight_decay=cfg.wd,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * len(train_loader), eta_min=3e-5)

    autocast_ctx, scaler, use_scaler, amp_mode = _autocast_and_scaler(cfg.amp, cfg.amp_dtype)
    print(f"[AMP] Mode={amp_mode} | GradScaler={'ON' if use_scaler else 'OFF'} | "
          f"KD: T={cfg.kd.temperature} alpha={cfg.kd.alpha} | LR={cfg.lr} WD={cfg.wd} | Trainable params={trainable:,}")

    zero_grad_streak = 0
    global_step = 0

    T_max = cfg.kd.temperature
    T_min = 2.0
    kd_conf = cfg.kd

    best_state = None
    best_val = float("inf")

    for ep in range(cfg.epochs):
        student.train()
        running, seen = 0.0, 0

        for i, batch in enumerate(train_loader):
            step = ep * len(train_loader) + i  # global step for T scheduling
            max_steps = cfg.epochs * len(train_loader)
            # Anneal KD temperature linearly from T_max down to T_min
            kd_conf.temperature = T_max - (step / max_steps) * (T_max - T_min)

            x = _images_from_batch(batch)
            if not torch.is_tensor(x):
                raise ValueError("Train loader must yield tensors or (tensor, target) tuples.")
            x = x.to(dev, non_blocking=True)

            with torch.no_grad():
                t = get_teacher_logits(teacher, x)
                # Force numerically stable dtype for the loss
                t = t.float()

            # ---- forward student under autocast
            with autocast_ctx:
                s = get_student_logits(student, x)

            # ---- compute KD loss in FP32 (outside autocast) for stability
            s32 = s.float()
            mse = cfg.mse_weight * mse_reg(s32, t, kd_conf.temperature)
            loss = kd_loss(s32, t, kd_conf) + mse

            opt.zero_grad(set_to_none=True)
            if use_scaler:
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()
            # Step per optimizer update: T_max above is specified in steps, not epochs
            scheduler.step()

            # ---- diagnostics
            bs = x.size(0)
            running += float(loss.detach()) * bs
            seen += bs
            global_step += 1

            if cfg.grad_check_every and (global_step % cfg.grad_check_every == 0):
                gnorm, n_nonzero = _grad_norm_and_nonzero(_param_iter_trainable(student))
                if n_nonzero == 0 or gnorm == 0.0:
                    zero_grad_streak += 1
                    if zero_grad_streak >= cfg.grad_warn_if_zero_steps:
                        print(f"[WARN] Step {global_step}: zero gradients detected "
                              f"(nonzero={n_nonzero}, grad_norm={gnorm:.3e}). "
                              f"Check get_student_logits, requires_grad, AMP settings, and data pipeline.")
                else:
                    zero_grad_streak = 0

            if cfg.log_every and (i + 1) % cfg.log_every == 0:
                print(f"Step {i+1}/{len(train_loader)} (ep {ep+1}/{cfg.epochs}): "
                      f"running loss = {running / max(1, seen):.4f}")

            if on_step is not None:
                on_step(global_step, float(loss.detach()))

            # free ASAP
            del s, s32, t, loss

        # ---- validation
        if val_loader is not None:
            _ = recalibrate_bn_stats(student, train_loader, max_batches=1000, device=cfg.device)
            student.eval()
            val_loss, vseen = 0.0, 0
            with torch.no_grad():
                for vbatch in val_loader:
                    vx = _images_from_batch(vbatch)
                    if not torch.is_tensor(vx):
                        raise ValueError("Val loader must yield tensors or (tensor, target) tuples.")
                    vx = vx.to(dev, non_blocking=True)

                    vt = get_teacher_logits(teacher, vx).float()
                    with autocast_ctx:
                        vs = get_student_logits(student, vx)

                    vs32 = vs.float()
                    vmse = cfg.mse_weight * mse_reg(vs32, vt, kd_conf.temperature)
                    vloss = kd_loss(vs32, vt, kd_conf) + vmse
                    val_loss += float(vloss.detach()) * vx.size(0)
                    vseen += vx.size(0)

            mean_val = val_loss / max(1, vseen)
            print("\n------------------------------------------------")
            print(f"Epoch {ep+1}/{cfg.epochs}: T={kd_conf.temperature:.2f}, train={running / max(1, seen):.6f}, "
                  f"val={mean_val:.6f}")

            if save_best and (mean_val < best_val):
                best_val = mean_val
                best_state = copy.deepcopy(student.state_dict())

            print("------------------------------------------------\n")
        else:
            print(f"Epoch {ep+1}/{cfg.epochs}: train={running / max(1, seen):.6f}")

    if save_best and val_loader is not None and best_state is not None:
        student.load_state_dict(best_state)

    student.eval()
    return student
```
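A minimal usage sketch for `finetune_student`. `KDConfig`, `kd_loss`, `mse_reg`, and `ensure_trainable_parameters` come from elsewhere in this package (presumably core/distill.py and core/utils.py in this repo layout); the loaders and the logits extractor below are assumptions for illustration:

```python
# Hypothetical wiring: a HF ViT classifier; adapt the extractor per model family.
def vit_logits(m, x):
    return m(pixel_values=x).logits

student = finetune_student(
    student,
    teacher,
    train_loader,
    get_student_logits=vit_logits,
    get_teacher_logits=vit_logits,
    cfg=FinetuneConfig(epochs=3, lr=1e-4, amp_dtype="auto", mse_weight=0.1),
    val_loader=val_loader,
    save_best=True,
)
```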
core/gates.py
ADDED
@@ -0,0 +1,389 @@
```python
"""Core gating primitives for hardware-aware model optimization.

This module defines:
- Base `Gate` interface (nn.Module) with a small, consistent API
- Concrete gates: HeadGate, GroupGate, LayerGate
- Straight-Through (ST) relaxed Bernoulli with Gumbel noise
- Penalties/regularizers commonly used during training
- Constraint projection helpers

Design goals:
- TorchScript-friendly where possible
- Minimal assumptions about model family (ViT, ResNet, LLM)
- Gates operate on *groups* of units; group_size controls expansion
- No direct knowledge of attention/FFN/etc. — adapters wire masks

Typical usage (adapter side):
    >>> gate = GroupGate(num_groups=H, group_size=Dh, tau=1.5, init_logit=3.0)
    >>> m = gate.mask(training=self.training)   # [H * Dh]
    >>> tensor = tensor * m.view(1, H, 1, Dh)   # example broadcast

Penalties scan the module tree for objects exposing `.logits` and `.tau`.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------

def _as_like(x: torch.Tensor, val) -> torch.Tensor:
    return torch.as_tensor(val, device=x.device, dtype=x.dtype)


def _gumbel_like(x: torch.Tensor) -> torch.Tensor:
    # Uniform(0,1) clamped for numerical stability
    u = torch.rand_like(x).clamp_(1e-6, 1 - 1e-6)
    return u.log() - (1 - u).log()  # logistic noise: log(u) - log(1-u)


# -----------------------------------------------------------------------------
# Base Gate
# -----------------------------------------------------------------------------

class Gate(nn.Module):
    """Abstract gate over *groups*.

    A gate controls `num_groups` binary decisions, typically expanded by
    `group_size` when applied to tensors. For example, gating ViT MLP hidden
    units in groups of 16: `num_groups = hidden // 16`, `group_size = 16`.

    Subclasses may override `sample_mask` for custom relaxations.
    """

    def __init__(
        self,
        num_groups: int,
        *,
        group_size: int = 1,
        tau: float = 1.5,
        init_logit: float = 3.0,
        hard_during_eval: bool = True,
    ) -> None:
        super().__init__()
        assert num_groups > 0 and group_size > 0
        self.num_groups = int(num_groups)
        self.group_size = int(group_size)
        self.tau = float(tau)
        self.hard_during_eval = bool(hard_during_eval)
        self.logits = nn.Parameter(torch.full((self.num_groups,), float(init_logit)))

    # ----- probabilities & stats ------------------------------------------------
    def probs(self) -> torch.Tensor:
        """Return per-group keep probabilities (sigmoid(logit / tau))."""
        # Using /tau here makes `tau` affect both train and eval statistics
        return torch.sigmoid(self.logits / self.tau)

    def expected_kept(self) -> torch.Tensor:
        """Expected *elements* kept (groups × group_size)."""
        return self.probs().sum() * _as_like(self.logits, self.group_size)

    # ----- masks ----------------------------------------------------------------
    def _hard_mask(self) -> torch.Tensor:
        m = (self.logits > 0).to(self.logits.dtype)
        return m.repeat_interleave(self.group_size)

    def _soft_st_mask(self) -> torch.Tensor:
        # Straight-through relaxed Bernoulli via Gumbel-sigmoid
        s = _gumbel_like(self.logits)
        y = torch.sigmoid((self.logits + s) / self.tau)
        y_hard = (y > 0.5).to(y.dtype)
        m = (y_hard - y).detach() + y
        return m.repeat_interleave(self.group_size)

    def mask(self, training: Optional[bool] = None) -> torch.Tensor:
        """Return a 1D mask of length `num_groups * group_size`.

        - Training: straight-through relaxed mask
        - Eval: hard (thresholded) mask if `hard_during_eval` else probs expanded
        """
        if training is None:
            training = self.training
        if training:
            return self._soft_st_mask()
        if self.hard_during_eval:
            return self._hard_mask()
        p = self.probs()
        return p.repeat_interleave(self.group_size)

    # ----- export helpers -------------------------------------------------------
    @torch.no_grad()
    def topk_indices(self, k: int) -> torch.Tensor:
        k = int(max(1, min(k, self.num_groups)))
        return torch.topk(self.logits, k, largest=True).indices.sort().values

    @torch.no_grad()
    def threshold_count(self) -> int:
        # Rounds to the nearest integer expectation, then clamps
        p = self.probs()
        k = int(torch.round(p.sum()).item())
        return max(1, min(k, self.num_groups))


# -----------------------------------------------------------------------------
# Concrete gates
# -----------------------------------------------------------------------------

class HeadGate(Gate):
    """Per-head gate. Often used with attention where group_size=head_dim."""

    def __init__(self, num_heads: int, *, head_dim: int = 1, **kw):
        super().__init__(num_groups=num_heads, group_size=head_dim, **kw)


class GroupGate(Gate):
    """Generic group gate (e.g., MLP hidden grouped by `group_size`)."""

    pass


class LayerGate(Gate):
    """One bit per layer (group_size=1)."""

    def __init__(self, num_layers: int, **kw):
        super().__init__(num_groups=num_layers, group_size=1, **kw)


# -----------------------------------------------------------------------------
# Penalties / Regularizers
# -----------------------------------------------------------------------------

@dataclass
class PenaltyWeights:
    """Scalars to blend regularization terms.

    Attributes
    ----------
    l0 : float
        Weight for the L0-like sparsity term (sum of keep probs).
    keep_floor_ratio : float
        Soft constraint: expected kept groups >= floor_ratio * groups.
    bimodality : float
        Encourages probabilities away from 0.5.
    """

    l0: float = 0.0
    keep_floor_ratio: float = 0.0
    bimodality: float = 0.0


def iter_gates(module: nn.Module) -> Iterable[Gate]:
    for m in module.modules():
        if isinstance(m, Gate):
            yield m
        else:
            # Duck-typing compatibility: any module with `.logits` and `.tau`
            if hasattr(m, "logits") and hasattr(m, "tau"):
                logits = getattr(m, "logits")
                if isinstance(logits, torch.Tensor) and logits.dim() == 1:
                    # Wrap view: expose basic API via adapter shim
                    g = _TensorBackedGateShim(m)
                    yield g


class _TensorBackedGateShim:
    """Lightweight adapter exposing .logits, .tau, .group_size, .num_groups.

    It is intentionally NOT an nn.Module and NOT a Gate subclass to avoid
    ctor/signature constraints and registration side-effects. It's only used
    by projection/regularization utilities that read/update .logits.
    """
    __slots__ = ("host", "logits", "tau", "group_size", "num_groups")

    def __init__(self, host):
        self.host = host
        # logits must be a Tensor/Parameter on the host
        self.logits = getattr(host, "logits")
        # default tau=1.5 if not present
        self.tau = float(getattr(host, "tau", 1.5))
        # support either group_size or group attribute names
        self.group_size = int(getattr(host, "group_size", getattr(host, "group", 1)))
        self.num_groups = int(self.logits.numel())

    def probs(self) -> torch.Tensor:
        # Mirror Gate.probs() so the penalty helpers below work on shims too
        return torch.sigmoid(self.logits / self.tau)

    def forward(self, *args, **kwargs):  # pragma: no cover - shim is not used as a layer
        raise RuntimeError("Gate shim is not a callable layer")


def l0_like_sparsity(module: nn.Module) -> torch.Tensor:
    """Sum of keep probabilities across all gates (acts like L0/L1)."""
    val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
    out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
    for g in iter_gates(module):
        out = out + g.probs().sum()
    return out


def keep_floor(module: nn.Module, floor_ratio: float) -> torch.Tensor:
    """Soft penalty if expected-kept falls below a fraction per gate.

    For each gate with G groups, penalize relu(floor*G - sum(p)).
    """
    if floor_ratio <= 0:
        return torch.tensor(0.0, device=next(module.parameters(), torch.tensor(0.0)).device)
    floor_ratio = float(floor_ratio)
    val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
    out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
    for g in iter_gates(module):
        floor_groups = _as_like(val, max(1.0, floor_ratio * float(g.num_groups)))
        out = out + F.relu(floor_groups - g.probs().sum())
    return out


def bimodality(module: nn.Module) -> torch.Tensor:
    """Sum over p*(1-p) to push probs away from 0.5 (minimum at 0 or 1)."""
    val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
    out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
    for g in iter_gates(module):
        p = g.probs()
        out = out + (p * (1.0 - p)).sum()
    return out


def combined_penalty(
    module: nn.Module,
    weights: PenaltyWeights,
) -> torch.Tensor:
    out = torch.tensor(0.0, device=next(module.parameters(), torch.tensor(0.0)).device)
    if weights.l0:
        out = out + weights.l0 * l0_like_sparsity(module)
    if weights.keep_floor_ratio:
        out = out + keep_floor(module, weights.keep_floor_ratio)
    if weights.bimodality:
        out = out + weights.bimodality * bimodality(module)
    return out


# -----------------------------------------------------------------------------
# Constraint projection
# -----------------------------------------------------------------------------

@dataclass
class Constraints:
    """High-level feasibility constraints.

    * min_keep_ratio: per-gate minimum fraction of groups to keep (soft cap via
      projection onto [min_k, G]).
    * min_groups: absolute lower bound per gate (after rounding).
    * max_groups_drop: optional ceiling on groups dropped per gate.
    """

    min_keep_ratio: float = 0.0
    min_groups: int = 1
    max_groups_drop: Optional[int] = None


@torch.no_grad()
def project_gates_into_constraints(module: nn.Module, cons: Constraints) -> None:
    """Project gate logits so that expected kept groups respect constraints.

    We rescale logits by an additive bias to achieve a target sum of probs when
    violating the lower/upper bounds. This is a light-touch projection that
    keeps relative ordering intact.
    """
    for g in iter_gates(module):
        p = torch.sigmoid(g.logits / g.tau)
        G = p.numel()
        # Lower bound
        min_keep = max(cons.min_groups, int(cons.min_keep_ratio * G))
        if p.sum().item() < min_keep:
            # Additive bias to increase sum(p)
            bias = torch.tensor(2.0, device=p.device, dtype=p.dtype)
            # Increase iteratively but cheaply
            for _ in range(6):
                p = torch.sigmoid((g.logits + bias) / g.tau)
                if p.sum().item() >= min_keep:
                    break
                bias = bias * 2
            g.logits.add_(bias)
        # Optional upper bound on drops
        if cons.max_groups_drop is not None:
            max_drop = int(cons.max_groups_drop)
            max_keep = max(1, G - max_drop)
            if p.sum().item() > max_keep:
                bias = torch.tensor(-2.0, device=p.device, dtype=p.dtype)
                for _ in range(6):
                    p = torch.sigmoid((g.logits + bias) / g.tau)
                    if p.sum().item() <= max_keep:
                        break
                    bias = bias * 2
                g.logits.add_(bias)


# -----------------------------------------------------------------------------
# Export helpers (indices from gates)
# -----------------------------------------------------------------------------

@torch.no_grad()
def topk_group_indices(g: Gate, keep_k: Optional[int] = None) -> torch.Tensor:
    """Return sorted group indices to KEEP based on logits/probs.

    If `keep_k` is None, use nearest-integer of expected kept.
    """
    if keep_k is None:
        keep_k = g.threshold_count()
    idx = torch.topk(g.logits, int(keep_k), largest=True).indices
    return idx.sort().values


@torch.no_grad()
def expand_group_indices(idx: torch.Tensor, group_size: int) -> torch.Tensor:
    """Expand group indices into element indices by `group_size` blocks."""
    if group_size == 1:
        return idx.clone()
    starts = idx * group_size
    parts = [torch.arange(s, s + group_size, device=idx.device) for s in starts]
    return torch.cat(parts, dim=0).long()


# -----------------------------------------------------------------------------
# Parameter utilities
# -----------------------------------------------------------------------------

def collect_gate_params(module: nn.Module) -> List[nn.Parameter]:
    return [g.logits for g in iter_gates(module) if isinstance(g.logits, torch.Tensor)]


def collect_param_groups(
    module: nn.Module,
    *,
    lr_gate: float = 1e-2,
    lr_linear: float = 1e-4,
    lr_affine: float = 3e-4,
    wd_linear: float = 1e-4,
) -> List[dict]:
    """Convenience grouping matching common training setups.

    Group 0: gate logits (no weight decay)
    Group 1: linear weights (with weight decay)
    Group 2: linear biases (no decay)
    Group 3: norm affine params (no decay)
    """
    gates, ln_affine, linear_w, linear_b = [], [], [], []
    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue
        if n.endswith((".logits", ".head_gate", ".channel_gate")):
            gates.append(p)
            continue
        is_linear_path = (".weight" in n or ".bias" in n) and (
            ".dense" in n or ".query" in n or ".key" in n or ".value" in n or ".proj" in n
        )
        if n.endswith(".weight") and is_linear_path:
            linear_w.append(p)
        elif n.endswith(".bias") and is_linear_path:
            linear_b.append(p)
        elif "layernorm" in n.lower() or "layer_norm" in n.lower():
            ln_affine.append(p)
    return [
        {"params": gates, "lr": lr_gate, "weight_decay": 0.0},
        {"params": linear_w, "lr": lr_linear, "weight_decay": wd_linear},
        {"params": linear_b, "lr": lr_linear, "weight_decay": 0.0},
        {"params": ln_affine, "lr": lr_affine, "weight_decay": 0.0},
    ]
```
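A minimal end-to-end sketch of the gate API above: attach a `GroupGate` to a toy MLP, blend the regularizers into a loss, then derive keep indices for export. The `GatedMLP` wrapper is hypothetical, not a module of this repo:

```python
import torch
import torch.nn as nn

class GatedMLP(nn.Module):                       # hypothetical host module
    def __init__(self, d=64, hidden=256, group=16):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden)
        self.fc2 = nn.Linear(hidden, d)
        self.gate = GroupGate(num_groups=hidden // group, group_size=group)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        # mask() broadcasts over the hidden dimension [hidden]
        return self.fc2(h * self.gate.mask(training=self.training))

mlp = GatedMLP()
out = mlp(torch.randn(2, 64))
loss = out.square().mean() + combined_penalty(
    mlp, PenaltyWeights(l0=1e-3, keep_floor_ratio=0.25, bimodality=1e-4)
)
loss.backward()

# Export: expected-keep group indices -> per-unit indices for slicing fc1/fc2
keep_groups = topk_group_indices(mlp.gate)
keep_units = expand_group_indices(keep_groups, mlp.gate.group_size)
```

`project_gates_into_constraints(mlp, Constraints(min_keep_ratio=0.25))` can additionally be called periodically during training to keep the expected keep-count feasible.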
core/profiler.py
ADDED
@@ -0,0 +1,236 @@
```python
"""Simple, robust latency measurement utilities.

This module provides GPU-friendly profilers with warmup, multiple repeats,
median/percentile reporting, and optional outlier rejection via MAD.

Design goals:
- Family-agnostic: take a callable `forward(model, x)` or rely on HF `.forward`
- Deterministic when desired; avoids autograd by default
- Works with CUDA or CPU; uses `torch.cuda.Event` for accurate GPU timing

Key APIs:
- measure_latency_ms(model, input_shape | input_tensor, ...)
- profile(model, sample, settings) -> {mean, p50, p90, p95, p99}
- LatencyProfiler(settings).measure(...)
- profile_many_shapes(model, shapes, settings)
"""
from __future__ import annotations

from dataclasses import dataclass
from statistics import median
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple

import contextlib
import math
import time

import torch
import torch.nn as nn


# -----------------------------------------------------------------------------
# Settings
# -----------------------------------------------------------------------------

@dataclass
class ProfileSettings:
    warmup: int = 10
    iters: int = 50
    percentile: Sequence[int] = (50, 90, 95, 99)
    sync_each_iter: bool = True
    use_inference_mode: bool = True
    cuda_graph: bool = False  # advanced users can enable with static shapes
    reject_outliers_mad: float = 0.0  # e.g., 3.5 to drop extreme spikes
    cudnn_benchmark: bool = True
    deterministic: bool = False  # sets cudnn.deterministic


# -----------------------------------------------------------------------------
# Context helpers
# -----------------------------------------------------------------------------

@contextlib.contextmanager
def _torch_backend_ctx(settings: ProfileSettings):
    prev_bench = torch.backends.cudnn.benchmark
    prev_det = torch.backends.cudnn.deterministic
    try:
        torch.backends.cudnn.benchmark = bool(settings.cudnn_benchmark)
        torch.backends.cudnn.deterministic = bool(settings.deterministic)
        yield
    finally:
        torch.backends.cudnn.benchmark = prev_bench
        torch.backends.cudnn.deterministic = prev_det


def _percentiles(sorted_vals: Sequence[float], qs: Sequence[int]) -> Dict[int, float]:
    n = len(sorted_vals)
    if n == 0:
        return {q: float("nan") for q in qs}
    out = {}
    for q in qs:
        if n == 1:
            out[q] = sorted_vals[0]
            continue
        # Linear interpolation between the two nearest order statistics
        k = (q / 100.0) * (n - 1)
        f = math.floor(k)
        c = min(n - 1, f + 1)
        if f == c:
            out[q] = sorted_vals[int(k)]
        else:
            d0 = sorted_vals[f] * (c - k)
            d1 = sorted_vals[c] * (k - f)
            out[q] = d0 + d1
    return out


def _apply_mad_filter(vals: Sequence[float], thresh: float) -> Sequence[float]:
    if thresh <= 0 or len(vals) < 5:
        return vals
    med = median(vals)
    dev = [abs(v - med) for v in vals]
    mad = median(dev) or 1e-12
    keep = [v for v, d in zip(vals, dev) if (d / mad) <= thresh]
    return keep if keep else vals


# -----------------------------------------------------------------------------
# Core measurement
# -----------------------------------------------------------------------------

@torch.inference_mode()
def measure_latency_ms(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Tuple[float, float]:
    """Return (mean_ms, p95_ms) over `iters` measurements.

    If `sample` is a shape tuple, a random tensor is created on-device.
    The default forward calls `model(pixel_values=x)` if available, else `model(x)`.
    """
    cfg = settings or ProfileSettings()

    with _torch_backend_ctx(cfg):
        m = model.to(device).eval()
        if isinstance(sample, torch.Tensor):
            x = sample.to(device)
        else:
            x = torch.randn(*sample, device=device)

        # Default forward: try HF-style keyword first, fall back to positional
        def _fwd(mod, inp):
            try:
                return mod(pixel_values=inp)
            except TypeError:
                return mod(inp)

        fn = forward_fn or _fwd

        on_cuda = torch.cuda.is_available() and device.startswith("cuda")

        # Warmup (identical for CPU/GPU; synchronize only when timing CUDA)
        for _ in range(cfg.warmup):
            _ = fn(m, x)
        if on_cuda:
            torch.cuda.synchronize()

        times: list[float] = []
        if on_cuda:
            for _ in range(cfg.iters):
                t0 = torch.cuda.Event(enable_timing=True)
                t1 = torch.cuda.Event(enable_timing=True)
                t0.record()
                _ = fn(m, x)
                t1.record()
                if cfg.sync_each_iter:
                    torch.cuda.synchronize()
                times.append(t0.elapsed_time(t1))  # milliseconds
        else:
            for _ in range(cfg.iters):
                t0 = time.perf_counter()
                _ = fn(m, x)
                if cfg.sync_each_iter and torch.cuda.is_available():
                    torch.cuda.synchronize()
                t1 = time.perf_counter()
                times.append((t1 - t0) * 1000.0)

    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
    mean_ms = sum(times) / max(1, len(times))
    p = _percentiles(times, cfg.percentile)
    p95 = p.get(95, times[int(0.95 * (len(times) - 1))] if times else float("nan"))
    return mean_ms, p95


# Higher level wrapper returning multiple percentiles
@torch.inference_mode()
def profile(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[str, float]:
    cfg = settings or ProfileSettings()
    # First pass doubles as warmup; percentiles are re-measured below
    # on the same settings for consistency.
    mean_ms, _ = measure_latency_ms(model, sample, settings=cfg, device=device, forward_fn=forward_fn)
    m = model.to(device).eval()
    if isinstance(sample, torch.Tensor):
        x = sample.to(device)
    else:
        x = torch.randn(*sample, device=device)

    def _fwd(mod, inp):
        try:
            return mod(pixel_values=inp)
        except TypeError:
            return mod(inp)

    fn = forward_fn or _fwd

    times = []
    if torch.cuda.is_available() and device.startswith("cuda"):
        for _ in range(cfg.iters):
            t0 = torch.cuda.Event(enable_timing=True)
            t1 = torch.cuda.Event(enable_timing=True)
            t0.record()
            _ = fn(m, x)
            t1.record()
            if cfg.sync_each_iter:
                torch.cuda.synchronize()
            times.append(t0.elapsed_time(t1))
    else:
        for _ in range(cfg.iters):
            t0 = time.perf_counter()
            _ = fn(m, x)
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)

    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
    percs = _percentiles(times, cfg.percentile)
    out = {"mean": sum(times) / max(1, len(times))}
    out.update({f"p{q}": v for q, v in percs.items()})
    return out


class LatencyProfiler:
    """Reusable profiler with fixed settings."""

    def __init__(self, settings: Optional[ProfileSettings] = None, device: str = "cuda"):
        self.settings = settings or ProfileSettings()
        self.device = device

    def measure(
        self,
        model: nn.Module,
        sample: torch.Tensor | Tuple[int, ...],
        *,
        forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
    ) -> Tuple[float, float]:
        return measure_latency_ms(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)

    def profile(
        self,
        model: nn.Module,
        sample: torch.Tensor | Tuple[int, ...],
        *,
        forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
    ) -> Dict[str, float]:
        return profile(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)


@torch.inference_mode()
def profile_many_shapes(
    model: nn.Module,
    shapes: Iterable[Tuple[int, ...]],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[Tuple[int, ...], Dict[str, float]]:
    out: Dict[Tuple[int, ...], Dict[str, float]] = {}
    for shp in shapes:
        out[tuple(shp)] = profile(model, shp, settings=settings, device=device, forward_fn=forward_fn)
    return out
```
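A minimal measurement sketch; the torchvision ResNet is illustrative, and `forward_fn` is only needed for models whose forward is neither `model(x)` nor `model(pixel_values=x)`:

```python
import torchvision

model = torchvision.models.resnet50()
settings = ProfileSettings(warmup=20, iters=100, reject_outliers_mad=3.5)

mean_ms, p95_ms = measure_latency_ms(model, (1, 3, 224, 224), settings=settings, device="cuda")
stats = profile(model, (1, 3, 224, 224), settings=settings, device="cuda")
# stats looks like {'mean': ..., 'p50': ..., 'p90': ..., 'p95': ..., 'p99': ...}

# Sweep batch sizes in one call:
sweep = profile_many_shapes(model, [(b, 3, 224, 224) for b in (1, 8, 32)], settings=settings)
```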
core/proxy_cost.py
ADDED
@@ -0,0 +1,771 @@
| 1 |
+
# core/proxy_cost.py
|
| 2 |
+
"""Latency proxy models and a tiny LUT for hardware correction.
|
| 3 |
+
|
| 4 |
+
This file defines a family-agnostic interface plus concrete proxies (ViT, ResNet, LLM)
|
| 5 |
+
that estimate latency from *soft structure* (gates) and input size. All proxies accept
|
| 6 |
+
the trainer's `(model, batch) -> ms` call signature directly (batches may be dict/tuple/tensor).
|
| 7 |
+
A small, in-memory LUT can be populated from real measurements during training to correct
|
| 8 |
+
analytic estimates.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Any, Dict, Optional, Tuple, Union, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
from .gates import iter_gates, _as_like # _as_like is used by ViT proxy
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# -----------------------------------------------------------------------------
|
| 22 |
+
# Small batch helpers (shared)
|
| 23 |
+
# -----------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
TensorOrBatch = Union[torch.Tensor, Tuple, List, Dict[str, Any]]
|
| 26 |
+
|
| 27 |
+
def _first_tensor(batch: TensorOrBatch) -> torch.Tensor:
|
| 28 |
+
"""Find the first tensor inside a batch-like structure."""
|
| 29 |
+
if torch.is_tensor(batch):
|
| 30 |
+
return batch
|
| 31 |
+
if isinstance(batch, dict):
|
| 32 |
+
# Common keys across tasks
|
| 33 |
+
for k in ("input_ids", "pixel_values", "images", "x"):
|
| 34 |
+
v = batch.get(k, None)
|
| 35 |
+
if torch.is_tensor(v):
|
| 36 |
+
return v
|
| 37 |
+
# fallback: first tensor value
|
| 38 |
+
for v in batch.values():
|
| 39 |
+
if torch.is_tensor(v):
|
| 40 |
+
return v
|
| 41 |
+
raise ValueError("Batch dict has no tensor field I recognize.")
|
| 42 |
+
if isinstance(batch, (list, tuple)):
|
| 43 |
+
for v in batch:
|
| 44 |
+
if torch.is_tensor(v):
|
| 45 |
+
return v
|
| 46 |
+
# torchvision pattern: ([aug1, aug2], label)
|
| 47 |
+
if len(batch) and isinstance(batch[0], (list, tuple)):
|
| 48 |
+
for v in batch[0]:
|
| 49 |
+
if torch.is_tensor(v):
|
| 50 |
+
return v
|
| 51 |
+
raise ValueError("Cannot find a tensor in the provided batch.")
|
| 52 |
+
|
| 53 |
+
def _ids_from_batch(batch: TensorOrBatch) -> torch.Tensor:
|
| 54 |
+
"""Return a 2D [B,S] tensor representing token ids for LLMs."""
|
| 55 |
+
if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
|
| 56 |
+
return batch["input_ids"]
|
| 57 |
+
t = _first_tensor(batch)
|
| 58 |
+
if t.dim() >= 2:
|
| 59 |
+
return t
|
| 60 |
+
raise ValueError("Cannot infer [B,S] from batch; need 'input_ids' or a 2D tensor.")
|
| 61 |
+
|
| 62 |
+
def _nchw_from_batch(batch: TensorOrBatch) -> Tuple[int, int, int, int]:
|
| 63 |
+
"""Return NCHW shape from a batch or an explicit (N,C,H,W) tuple/list/tensor."""
|
| 64 |
+
if isinstance(batch, (tuple, list)) and len(batch) == 4 and all(isinstance(x, int) for x in batch):
|
| 65 |
+
return tuple(batch) # type: ignore[return-value]
|
| 66 |
+
x = _first_tensor(batch)
|
| 67 |
+
if x.dim() != 4:
|
| 68 |
+
raise ValueError(f"Expected NCHW tensor for CNN proxy; got tensor with shape {tuple(x.shape)}")
|
| 69 |
+
N, C, H, W = map(int, x.shape)
|
| 70 |
+
return (N, C, H, W)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# -----------------------------------------------------------------------------
|
| 74 |
+
# Base proxy + LUT
|
| 75 |
+
# -----------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
class LatencyProxy(nn.Module):
|
| 78 |
+
"""Abstract proxy producing a scalar latency-like value (ms).
|
| 79 |
+
|
| 80 |
+
Subclasses implement `_predict_raw` and may define `_signature` keys used by
|
| 81 |
+
a LUT to refine estimates with real measurements. Proxies accept either a
|
| 82 |
+
batch-like object (dict/tuple/tensor) or an explicit shape tuple.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(self):
|
| 86 |
+
super().__init__()
|
| 87 |
+
|
| 88 |
+
def predict(
|
| 89 |
+
self,
|
| 90 |
+
model: nn.Module,
|
| 91 |
+
sample: TensorOrBatch,
|
| 92 |
+
*,
|
| 93 |
+
policy=None,
|
| 94 |
+
step: Optional[int] = None,
|
| 95 |
+
**kwargs,
|
| 96 |
+
) -> torch.Tensor:
|
| 97 |
+
"""Batch-friendly entry point. `sample` may be a batch or explicit shape."""
|
| 98 |
+
return self._predict_raw(model, sample, policy=policy, step=step, **kwargs)
|
| 99 |
+
|
| 100 |
+
def _predict_raw(
|
| 101 |
+
self,
|
| 102 |
+
model: nn.Module,
|
| 103 |
+
sample: TensorOrBatch,
|
| 104 |
+
*,
|
| 105 |
+
policy=None,
|
| 106 |
+
step: Optional[int] = None,
|
| 107 |
+
**kwargs,
|
| 108 |
+
) -> torch.Tensor: # pragma: no cover - abstract
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
def signature(
|
| 112 |
+
self,
|
| 113 |
+
model: nn.Module,
|
| 114 |
+
sample: TensorOrBatch,
|
| 115 |
+
*,
|
| 116 |
+
policy=None,
|
| 117 |
+
step: Optional[int] = None
|
| 118 |
+
) -> Tuple:
|
| 119 |
+
"""Return a hashable signature describing the workload shape."""
|
| 120 |
+
if torch.is_tensor(sample):
|
| 121 |
+
shp = tuple(sample.shape)
|
| 122 |
+
elif isinstance(sample, (tuple, list)):
|
| 123 |
+
shp = tuple(sample)
|
| 124 |
+
elif isinstance(sample, dict):
|
| 125 |
+
# summarize the shapes of any tensors in dict
|
| 126 |
+
shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
|
| 127 |
+
else:
|
| 128 |
+
shp = (str(type(sample)),)
|
| 129 |
+
return (type(self).__name__, shp)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class LatencyLUT:
|
| 133 |
+
"""Tiny LUT mapping `(signature) -> measured_ms`."""
|
| 134 |
+
|
| 135 |
+
def __init__(self):
|
| 136 |
+
self._table: Dict[Tuple[Any, ...], float] = {}
|
| 137 |
+
|
| 138 |
+
def update(self, signature: Tuple[Any, ...], measured_ms: float) -> None:
|
| 139 |
+
self._table[signature] = float(measured_ms)
|
| 140 |
+
|
| 141 |
+
def get(self, signature: Tuple[Any, ...]) -> Optional[float]:
|
| 142 |
+
return self._table.get(signature)
|
| 143 |
+
|
| 144 |
+
def blend(self, raw_estimate: torch.Tensor, signature: Tuple[Any, ...]) -> torch.Tensor:
|
| 145 |
+
val = self.get(signature)
|
| 146 |
+
if val is None:
|
| 147 |
+
return raw_estimate
|
| 148 |
+
# Put on same device/dtype as raw_estimate
|
| 149 |
+
return _as_like(raw_estimate, val)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# -----------------------------------------------------------------------------
|
| 153 |
+
# ViT proxy (analytic + gates), with scale and per-term weights
|
| 154 |
+
# -----------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
@dataclass
|
| 157 |
+
class ViTProxyConfig:
|
| 158 |
+
scale_ms: float = 1.0
|
| 159 |
+
alpha_qkv: float = 1.0
|
| 160 |
+
alpha_scores: float = 1.0
|
| 161 |
+
alpha_out: float = 1.0
|
| 162 |
+
alpha_mlp: float = 1.0
|
| 163 |
+
|
| 164 |
+
def _vit_layers(m):
|
| 165 |
+
enc = getattr(m, "encoder", None)
|
| 166 |
+
if enc is not None and hasattr(enc, "layer"):
|
| 167 |
+
return enc.layer
|
| 168 |
+
vit = getattr(m, "vit", None)
|
| 169 |
+
if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
|
| 170 |
+
return vit.encoder.layer
|
| 171 |
+
raise TypeError("Expected a HF ViT with *.encoder.layer (ViTModel or ViTForImageClassification).")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ViTLatencyProxy(LatencyProxy):
|
| 175 |
+
"""Latency proxy for ViT models. Accepts batches or (N,C,H,W) tuples."""
|
| 176 |
+
|
| 177 |
+
def __init__(self, cfg: Optional[ViTProxyConfig] = None, lut: Optional[LatencyLUT] = None):
|
| 178 |
+
super().__init__()
|
| 179 |
+
self.cfg = cfg or ViTProxyConfig()
|
| 180 |
+
self.lut = lut or LatencyLUT()
|
| 181 |
+
|
| 182 |
+
# ---- helpers -------------------------------------------------------------
|
| 183 |
+
@staticmethod
|
| 184 |
+
def _input_spec(sample: TensorOrBatch) -> Tuple[int, int, int]:
|
| 185 |
+
if isinstance(sample, (tuple, list)) and len(sample) == 4 and all(isinstance(x, int) for x in sample):
|
| 186 |
+
B, C, H, W = sample
|
| 187 |
+
return int(B), int(H), int(W)
|
| 188 |
+
x = _first_tensor(sample)
|
| 189 |
+
if x.dim() != 4:
|
| 190 |
+
raise ValueError("ViTLatencyProxy expects a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
|
| 191 |
+
B, C, H, W = x.shape
|
| 192 |
+
return int(B), int(H), int(W)
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def _patch_hw(cfg) -> Tuple[int, int]:
|
| 196 |
+
patch = getattr(cfg, "patch_size", 16)
|
| 197 |
+
if isinstance(patch, (tuple, list)):
|
| 198 |
+
return int(patch[0]), int(patch[1])
|
| 199 |
+
return int(patch), int(patch)
|
| 200 |
+
|
| 201 |
+
@staticmethod
|
| 202 |
+
def _soft_heads_from_block(blk) -> Optional[torch.Tensor]:
|
| 203 |
+
# Prefer a nested attention with kept_heads_soft()
|
| 204 |
+
attn = getattr(getattr(blk, "attention", None), "attention", None)
|
| 205 |
+
if attn is not None and hasattr(attn, "kept_heads_soft"):
|
| 206 |
+
return attn.kept_heads_soft()
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
def _find_ffn_gate(blk):
|
| 211 |
+
inter = getattr(blk, "intermediate", None)
|
| 212 |
+
if inter is None:
|
| 213 |
+
return None
|
| 214 |
+
# Common attribute names
|
| 215 |
+
for nm in ("neuron_gate", "gate", "ffn_gate"):
|
| 216 |
+
g = getattr(inter, nm, None)
|
| 217 |
+
if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
|
| 218 |
+
return g
|
| 219 |
+
# Last resort: scan children
|
| 220 |
+
for m in blk.modules():
|
| 221 |
+
if hasattr(m, "logits") and hasattr(m, "tau"):
|
| 222 |
+
return m
|
| 223 |
+
return None
|
| 224 |
+
|
| 225 |
+
# ---- proxy ---------------------------------------------------------------
|
| 226 |
+
def _predict_raw(
|
| 227 |
+
self,
|
| 228 |
+
model: nn.Module,
|
| 229 |
+
sample: TensorOrBatch,
|
| 230 |
+
*,
|
| 231 |
+
policy=None,
|
| 232 |
+
step: Optional[int] = None
|
| 233 |
+
) -> torch.Tensor:
|
| 234 |
+
anchor = next((p for p in model.parameters()), torch.tensor(0.0))
|
| 235 |
+
|
| 236 |
+
B, H_img, W_img = self._input_spec(sample)
|
| 237 |
+
cfg = getattr(model, "config", None)
|
| 238 |
+
if cfg is None:
|
| 239 |
+
raise ValueError("Model must expose a HuggingFace-like .config for ViT proxy")
|
| 240 |
+
ph, pw = self._patch_hw(cfg)
|
| 241 |
+
|
| 242 |
+
S = _as_like(anchor, 1 + (H_img // ph) * (W_img // pw))
|
| 243 |
+
D = _as_like(anchor, int(getattr(cfg, "hidden_size", 768)))
|
| 244 |
+
Hh = _as_like(anchor, int(getattr(cfg, "num_attention_heads", 12)))
|
| 245 |
+
Dh = D // Hh
|
| 246 |
+
|
| 247 |
+
warm = False
|
| 248 |
+
if policy is not None and step is not None:
|
| 249 |
+
warm = (step < int(getattr(policy, "warmup_steps", 0)))
|
| 250 |
+
|
| 251 |
+
total_qkv = _as_like(anchor, 0.0)
|
| 252 |
+
total_scores = _as_like(anchor, 0.0)
|
| 253 |
+
total_out = _as_like(anchor, 0.0)
|
| 254 |
+
total_mlp = _as_like(anchor, 0.0)
|
| 255 |
+
|
| 256 |
+
default_hidden = _as_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))
|
| 257 |
+
|
| 258 |
+
layers = _vit_layers(model)
|
| 259 |
+
for blk in layers:
|
| 260 |
+
heads_soft = Hh if warm else (self._soft_heads_from_block(blk) or Hh)
|
| 261 |
+
|
| 262 |
+
# FFN hidden expectation
|
| 263 |
+
if warm:
|
| 264 |
+
hidden_soft = default_hidden
|
| 265 |
+
else:
|
| 266 |
+
g = self._find_ffn_gate(blk)
|
| 267 |
+
if g is None:
|
| 268 |
+
hidden_soft = default_hidden
|
| 269 |
+
else:
|
| 270 |
+
probs = torch.sigmoid(g.logits / g.tau)
|
| 271 |
+
group = int(getattr(g, "group", getattr(g, "group_size", 16)))
|
| 272 |
+
hidden_soft = probs.sum() * _as_like(anchor, group)
|
| 273 |
+
|
| 274 |
+
D_kept = heads_soft * Dh
|
| 275 |
+
|
| 276 |
+
total_qkv += 3 * S * D * D_kept
|
| 277 |
+
total_scores += (S * S) * heads_soft * Dh
|
| 278 |
+
total_out += S * D_kept * D
|
| 279 |
+
total_mlp += 2 * S * D * hidden_soft
|
| 280 |
+
|
| 281 |
+
raw = (
|
| 282 |
+
self.cfg.alpha_qkv * total_qkv
|
| 283 |
+
+ self.cfg.alpha_scores * total_scores
|
| 284 |
+
+ self.cfg.alpha_out * total_out
|
| 285 |
+
+ self.cfg.alpha_mlp * total_mlp
|
| 286 |
+
)
|
| 287 |
+
raw_ms = raw * _as_like(anchor, float(self.cfg.scale_ms))
|
| 288 |
+
|
| 289 |
+
# optional LUT correction
|
| 290 |
+
sig = self.signature(model, sample, policy=policy, step=step)
|
| 291 |
+
return self.lut.blend(raw_ms, sig)
|
| 292 |
+
|
| 293 |
+
# A reasonable default signature for ViT workloads
|
| 294 |
+
def signature(self, model: nn.Module, sample, *, policy=None, step: Optional[int] = None) -> Tuple:
|
| 295 |
+
if torch.is_tensor(sample):
|
| 296 |
+
shp = tuple(sample.shape)
|
| 297 |
+
elif isinstance(sample, (tuple, list)):
|
| 298 |
+
shp = tuple(sample)
|
| 299 |
+
elif isinstance(sample, dict):
|
| 300 |
+
shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
|
| 301 |
+
else:
|
| 302 |
+
shp = (str(type(sample)),)
|
| 303 |
+
cfg = getattr(model, "config", None)
|
| 304 |
+
heads = int(getattr(cfg, "num_attention_heads", 12))
|
| 305 |
+
hidden = int(getattr(cfg, "hidden_size", 768))
|
| 306 |
+
inter = int(getattr(cfg, "intermediate_size", 3072))
|
| 307 |
+
return ("ViT", shp, heads, hidden, inter)
|
| 308 |
+
|
| 309 |
+
@torch.no_grad()
|
| 310 |
+
def calibrate(self, model: nn.Module, shape: tuple, measure_fn, *, device: str = "cuda") -> float:
|
| 311 |
+
"""Set proxy scale so that keep-all student matches measured ms.
|
| 312 |
+
|
| 313 |
+
`measure_fn(model, shape_or_tensor)` should return `(mean_ms, p95_ms)`.
|
| 314 |
+
"""
|
| 315 |
+
|
| 316 |
+
sample_t = torch.randn(shape, device=device)
|
| 317 |
+
|
| 318 |
+
sample_t = sample_t.to(device)
|
| 319 |
+
model = model.to(device).eval()
|
| 320 |
+
mean_ms, _ = measure_fn(model, shape, device=device)
|
| 321 |
+
soft_ms = self.predict(model, sample_t).item()
|
| 322 |
+
self.cfg.scale_ms = float(mean_ms / max(soft_ms, 1e-9))
|
| 323 |
+
return self.cfg.scale_ms
|
| 324 |
+
|
| 325 |
+
# ------------------------------ ResNet Proxy ------------------------------

@dataclass
class ResNetProxyConfig:
    scale_ms: float = 1.0
    alpha_conv: float = 1.0  # weight for conv FLOPs term


def _as_const_like_resnet(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


def _find_anchor_param(model: nn.Module) -> torch.Tensor:
    # Prefer any gate-like parameter; otherwise any parameter; else a CPU scalar
    for m in model.modules():
        for nm in ("logits", "head_gate"):
            t = getattr(m, nm, None)
            if isinstance(t, torch.Tensor):
                return t
    for p in model.parameters():
        return p
    return torch.tensor(0.0)


def _kept_from_gate(module, anchor: torch.Tensor) -> Optional[torch.Tensor]:
    """Return expected kept channels for a BN gate: probs.sum() * group_size.

    If no gate is found, return None.
    """
    g = None
    for nm in ("gate", "neuron_gate", "channel_gate", "bn_gate"):
        if hasattr(module, nm):
            g = getattr(module, nm)
            break
    if g is None and hasattr(module, "logits") and hasattr(module, "tau"):
        g = module

    if g is None or not hasattr(g, "logits"):
        return None
    logits = g.logits
    tau = float(getattr(g, "tau", 1.5))
    group = int(getattr(g, "group", getattr(g, "group_size", 1)))
    if group <= 0:
        group = 1
    probs = torch.sigmoid(logits / tau)
    return probs.sum() * _as_const_like_resnet(anchor, group)

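# Minimal sketch of the expectation computed by `_kept_from_gate`: a gate with
# 4 groups of 16 channels and strongly open logits keeps most channels. The
# SimpleNamespace gate here is a stand-in for the real BN gate module.
def _example_kept_from_gate() -> float:
    from types import SimpleNamespace
    gate = SimpleNamespace(logits=torch.full((4,), 6.0), tau=1.5, group=16)
    kept = _kept_from_gate(gate, torch.tensor(0.0))
    return float(kept)  # = 4 * sigmoid(6.0 / 1.5) * 16 ≈ 62.9
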
class ResNetLatencyProxy(LatencyProxy):
    """Latency proxy for ResNet-like backbones with BN gates.

    Approximates latency with a FLOPs-style sum over convs, using the *expected*
    kept channels after each BN gate (probs.sum() * group_size). Falls back to
    the full channel count when a gate is not found.

    Accepts a batch or an explicit (N, C, H, W) shape.
    """

    def __init__(self, cfg: Optional[ResNetProxyConfig] = None):
        super().__init__()
        self.cfg = cfg or ResNetProxyConfig()

    def _add_cost(self, cost_like: torch.Tensor, oc, ic, k, stride, H, W):
        alpha = _as_const_like_resnet(cost_like, self.cfg.alpha_conv)
        # update spatial dims with conv stride (roughly, ignoring padding effects)
        H = (H + stride - 1) // stride
        W = (W + stride - 1) // stride
        flops = (
            _as_const_like_resnet(cost_like, oc)
            * _as_const_like_resnet(cost_like, ic)
            * (k * k)
            * _as_const_like_resnet(cost_like, H)
            * _as_const_like_resnet(cost_like, W)
        )
        return cost_like + alpha * flops, H, W

    def _predict_raw(self, model: nn.Module, sample: TensorOrBatch, **_) -> torch.Tensor:
        N, C_in, H0, W0 = _nchw_from_batch(sample)
        anchor = _find_anchor_param(model)
        cost = _as_const_like_resnet(anchor, 0.0)
        H = _as_const_like_resnet(anchor, int(H0))
        W = _as_const_like_resnet(anchor, int(W0))

        # Stem
        conv1 = model.conv1
        bn1 = getattr(model, "bn1", None)
        k = conv1.kernel_size[0]
        s = conv1.stride[0]
        kept_out = None
        if bn1 is not None:
            kept = _kept_from_gate(bn1, anchor)
            if kept is not None:
                kept_out = kept
        oc_eff = kept_out if kept_out is not None else _as_const_like_resnet(anchor, conv1.out_channels)
        cost, H, W = self._add_cost(cost, oc_eff, _as_const_like_resnet(anchor, C_in), k, s, H, W)
        in_ch = oc_eff

        def _block_cost(block, in_ch, H, W, cost):
            # conv1 -> bn1
            c1 = block.conv1
            b1 = block.bn1 if hasattr(block, "bn1") else None
            k1, s1 = c1.kernel_size[0], c1.stride[0]
            oc1_eff = _kept_from_gate(b1, anchor) or _as_const_like_resnet(anchor, c1.out_channels)
            cost, H, W = self._add_cost(cost, oc1_eff, in_ch, k1, s1, H, W)

            # conv2 -> bn2
            c2 = block.conv2
            b2 = block.bn2 if hasattr(block, "bn2") else None
            k2, s2 = c2.kernel_size[0], c2.stride[0]
            oc2_eff = _kept_from_gate(b2, anchor) or _as_const_like_resnet(anchor, c2.out_channels)
            cost, H, W = self._add_cost(cost, oc2_eff, oc1_eff, k2, s2, H, W)

            return oc2_eff, H, W, cost

        # Layers
        for lname in ("layer1", "layer2", "layer3", "layer4"):
            layer = getattr(model, lname, None)
            if layer is None:
                continue
            for blk in layer:
                in_ch, H, W, cost = _block_cost(blk, in_ch, H, W, cost)

        scale = _as_const_like_resnet(anchor, self.cfg.scale_ms)
        return cost * scale

    @torch.no_grad()
    def calibrate(self, model: nn.Module, keepall_export_fn, profiler_fn, sample: TensorOrBatch, device: str = "cuda") -> float:
        """Calibrate `scale_ms` so proxy(model_keepall) ~= real latency in ms."""
        keep = keepall_export_fn(model)
        sample_shape = _nchw_from_batch(sample)
        mean_ms, _ = profiler_fn(keep, sample_shape, device=device)
        soft = float(self.predict(model, sample).detach().cpu())
        # Multiplicative update so repeated calibration stays consistent.
        self.cfg.scale_ms = float(self.cfg.scale_ms * mean_ms / max(soft, 1e-9))
        return mean_ms

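# Usage sketch (assumptions: a gated torchvision-style ResNet on CUDA, an
# adapter-provided keep-all exporter, and a `measure_latency_ms`-style
# profiler; those callables are illustrative, not defined in this file).
def _example_resnet_proxy(student: nn.Module, export_keepall, profiler_fn) -> torch.Tensor:
    proxy = ResNetLatencyProxy(ResNetProxyConfig(scale_ms=1.0))
    sample = torch.randn(8, 3, 224, 224, device="cuda")
    proxy.calibrate(student, export_keepall, profiler_fn, sample, device="cuda")
    return proxy.predict(student, sample)  # scalar tensor; gradients reach gate logits
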
# -----------------------------------------------------------------------------
# LLM proxy
# -----------------------------------------------------------------------------

"""
LatencyProxyLLM
---------------
A lightweight latency proxy for decoder-only HF LLMs (LLaMA/Mistral style).

- Estimates end-to-end latency (an ms-like scalar) for a given (B, S, T):
    * Prefill on S tokens (build KV cache)
    * Cached decode for T steps
- Uses soft gate expectations:
    * Attention heads (HeadGate on GatedSelfAttentionLLM)
    * FFN hidden (SwiGLUWidthGate via .mlp.neuron_gate)
- Calibrate .scale_ms so the proxy ≈ real latency of a keep-all model.

Public API
----------
- LatencyProxyLLM(...).predict(model, batch_or_shape)   # trainer entry
- LatencyProxyLLM(...).predict(model, B=?, S=?, T=?)    # explicit entry
- LatencyProxyLLM(...).debug_layer_view(...)
- calibrate_proxy_llm(...), calibrate_proxy_llm_from_batch(...)
"""

# ------------------------------------------------------------
# Shared tiny utils (device/dtype-safe constants)
# ------------------------------------------------------------
def _find_gate_param_or_fallback(model: nn.Module) -> torch.Tensor:
    """
    Return a tensor to anchor device/dtype for proxy constants.
    Prefer gate logits; else any parameter; else a CPU fp32 scalar.
    """
    for m in model.modules():
        if hasattr(m, "head_gate") and hasattr(getattr(m, "head_gate"), "logits"):
            return m.head_gate.logits
        if hasattr(m, "neuron_gate") and hasattr(m.neuron_gate, "logits"):
            return m.neuron_gate.logits
        if hasattr(m, "logits") and isinstance(getattr(m, "logits"), torch.Tensor):
            return m.logits
    for p in model.parameters():
        return p
    return torch.tensor(0.0)


def _as_const_like(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


# ------------------------------------------------------------
# Proxy
# ------------------------------------------------------------
@dataclass
class _WarmupOnlyPolicy:
    """Tiny policy shim so you can pass warmup_steps to .predict()."""
    warmup_steps: int = 0


class LatencyProxyLLM(LatencyProxy):
    """
    LLM latency proxy (ms ~ weighted FLOPs/bandwidth terms) for prefill + cached decode.
    Accepts either a batch or explicit B, S, T.
    """

    def __init__(
        self,
        *,
        scale_ms: float = 1.0,
        alpha_qkv: float = 1.0,
        alpha_scores: float = 1.0,
        alpha_out: float = 1.0,
        alpha_mlp: float = 1.0,
        gate_kv_in_proxy: bool = False,
        default_T: int = 128,
    ):
        super().__init__()
        self.scale_ms = float(scale_ms)
        self.alpha_qkv = float(alpha_qkv)
        self.alpha_scores = float(alpha_scores)
        self.alpha_out = float(alpha_out)
        self.alpha_mlp = float(alpha_mlp)
        self.gate_kv_in_proxy = bool(gate_kv_in_proxy)
        self.default_T = int(default_T)

    # ---------- gate discovery ----------
    @staticmethod
    def _soft_heads_from_block_llm(blk) -> Optional[torch.Tensor]:
        attn = getattr(blk, "self_attn", None)
        if attn is None:
            return None
        if hasattr(attn, "kept_heads_soft") and callable(attn.kept_heads_soft):
            return attn.kept_heads_soft()
        logits, tau = None, None
        if hasattr(attn, "head_gate") and hasattr(attn.head_gate, "logits"):
            logits = attn.head_gate.logits
            tau = float(getattr(attn.head_gate, "tau", getattr(attn, "tau", 1.5)))
        elif hasattr(attn, "logits"):
            logits = attn.logits
            tau = float(getattr(attn, "tau", 1.5))
        if logits is None:
            return None
        return torch.sigmoid(logits / tau).sum()

    @staticmethod
    def _find_ffn_gate_llm(blk):
        mlp = getattr(blk, "mlp", None)
        g = getattr(mlp, "neuron_gate", None) if mlp is not None else None
        if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
            return g
        return None

    def _soft_hidden_from_block_llm(self, blk, default_hidden, anchor, warm=False):
        if warm:
            return default_hidden
        g = self._find_ffn_gate_llm(blk)
        if g is None:
            return default_hidden
        probs = torch.sigmoid(g.logits / float(g.tau))  # [#groups]
        group = int(getattr(g, "group", getattr(g, "group_size", 128)))
        kept_hidden = probs.sum() * _as_const_like(anchor, group)
        return kept_hidden

    # ---------- main ----------
    def predict(  # trainer entry and explicit-shape entry, unified
        self,
        model: nn.Module,
        sample: Optional[TensorOrBatch] = None,
        *,
        B: Optional[int] = None,
        S: Optional[int] = None,
        T: Optional[int] = None,
        policy: Optional[object] = None,
        step: Optional[int] = None,
        return_terms: bool = False,
    ):
        # Allow the explicit B, S, (T) path
        if B is not None and S is not None:
            ids_B, ids_S = int(B), int(S)
            ids_T = int(T) if T is not None else int(self.default_T)
        else:
            if sample is None:
                raise ValueError("LatencyProxyLLM.predict needs either a batch sample or explicit B,S.")
            if isinstance(sample, (tuple, list)) and len(sample) in (2, 3) and all(isinstance(x, int) for x in sample):
                # explicit (B,S) or (B,S,T)
                ids_B, ids_S = int(sample[0]), int(sample[1])
                ids_T = int(sample[2]) if len(sample) == 3 else int(self.default_T)
            else:
                ids = _ids_from_batch(sample)
                ids_B, ids_S = int(ids.size(0)), int(ids.size(1))
                ids_T = int(self.default_T) if T is None else int(T)

        anchor = _find_gate_param_or_fallback(model)

        # scalar tensors (same device/dtype)
        B_t = _as_const_like(anchor, ids_B)
        S_t = _as_const_like(anchor, ids_S)
        T_t = _as_const_like(anchor, ids_T)

        cfg = model.config
        D = _as_const_like(anchor, int(cfg.hidden_size))
        Hh = _as_const_like(anchor, int(cfg.num_attention_heads))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hh))))
        Dh = D // Hh

        warmup_steps = int(getattr(policy, "warmup_steps", 0)) if policy is not None else 0
        warm = bool(step is not None and step < warmup_steps)

        total_qkv = anchor.new_zeros(())
        total_scores = anchor.new_zeros(())
        total_out = anchor.new_zeros(())
        total_mlp = anchor.new_zeros(())

        default_hidden = _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        layers = getattr(getattr(model, "model", model), "layers", [])
        for blk in layers:
            heads_soft = Hh if warm else (self._soft_heads_from_block_llm(blk) or Hh)
            Dq = heads_soft * Dh
            # K/V effective width
            if self.gate_kv_in_proxy:
                Dkv = heads_soft * Dh
            else:
                Dkv = Hkv * Dh
            hidden_soft = self._soft_hidden_from_block_llm(blk, default_hidden, anchor, warm=warm)

            # Prefill + decode (simplified aggregation)
            Seff = S_t + T_t

            # q/k/v linear FLOP-like terms
            total_qkv = total_qkv + (
                # q
                B_t * Seff * D * Dq +
                # k + v
                2 * B_t * Seff * D * Dkv
            )
            # attention scores (prefill SxS + decode triangular)
            total_scores = total_scores + (
                B_t * (S_t * S_t) * heads_soft * Dh +
                B_t * heads_soft * Dh * (T_t * S_t + (T_t * (T_t + 1)) // 2)
            )
            # out proj
            total_out = total_out + B_t * Seff * Dq * D
            # mlp
            total_mlp = total_mlp + B_t * Seff * 2 * D * hidden_soft

        flops_like = (
            self.alpha_qkv * total_qkv
            + self.alpha_scores * total_scores
            + self.alpha_out * total_out
            + self.alpha_mlp * total_mlp
        )

        ms = flops_like * _as_const_like(anchor, self.scale_ms)
        if return_terms:
            return ms, {
                "qkv": float((self.alpha_qkv * total_qkv).detach().cpu()),
                "scores": float((self.alpha_scores * total_scores).detach().cpu()),
                "out": float((self.alpha_out * total_out).detach().cpu()),
                "mlp": float((self.alpha_mlp * total_mlp).detach().cpu()),
            }
        return ms

    # ---------- per-layer debug ----------
    @torch.no_grad()
    def debug_layer_view(
        self,
        model: nn.Module,
        *,
        B: int,
        S: int,
        T: int,
        policy: Optional[object] = None,
        step: Optional[int] = None,
    ) -> list:
        anchor = _find_gate_param_or_fallback(model)
        cfg = getattr(model, "config", None)
        D = _as_const_like(anchor, int(getattr(cfg, "hidden_size", 0)))
        Hq = _as_const_like(anchor, int(getattr(cfg, "num_attention_heads", 0)))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hq))))
        Dh = D // Hq

        warm = False
        if policy is not None and step is not None:
            warm = (int(step) < int(getattr(policy, "warmup_steps", 0)))

        rows = []
        layers = getattr(getattr(model, "model", model), "layers", None) or []
        for i, blk in enumerate(layers):
            heads_soft = Hq if warm else (self._soft_heads_from_block_llm(blk) or Hq)
            Dq = heads_soft * Dh
            Dkv = (heads_soft * Dh) if self.gate_kv_in_proxy else (Hkv * Dh)
            hidden_soft = self._soft_hidden_from_block_llm(
                blk, _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D)))), anchor, warm=warm
            )
            rows.append({
                "layer": i,
                "heads_soft": float(heads_soft.detach().cpu()),
                "Dq≈heads*Dh": float(Dq.detach().cpu()),
                "Dkv_used": float(Dkv.detach().cpu()),
                "ffn_hidden_soft": float(hidden_soft.detach().cpu()),
            })
        return rows

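# Usage sketch: explicit-shape prediction plus a per-term breakdown. Assumes a
# gated HF causal LM (e.g. one whose attention is wrapped with
# GatedSelfAttentionLLM); batch dimensions here are illustrative.
def _example_llm_proxy_breakdown(model: nn.Module) -> dict:
    proxy = LatencyProxyLLM(default_T=128)
    ms, terms = proxy.predict(model, B=4, S=512, T=128, return_terms=True)
    print(f"proxy≈{float(ms.detach()):.3f} ms-like | terms={terms}")
    return terms
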
# ------------------------------------------------------------
# Calibration helpers for LLM
# ------------------------------------------------------------
@torch.inference_mode()
def calibrate_proxy_llm(
    proxy: LatencyProxyLLM,
    model: nn.Module,
    *,
    B: int,
    S: int,
    T: int,
    export_keepall_fn,
    device: str = "cuda",
    warmup: int = 10,
    iters: int = 30,
) -> float:
    """
    Calibrate proxy.scale_ms so proxy.predict(...) matches the real keep-all latency for (B,S,T).
    Returns the measured real mean latency in ms.
    """
    keepall = export_keepall_fn(model).to(device).eval()

    # Measure real latency (prefill + decode)
    from core.measure import measure_latency_text_ms as _measure  # adjust if your path differs
    real_ms, _ = _measure(keepall, B=B, S=S, T=T, warmup=warmup, iters=iters, device=device)

    # Soft/proxy latency on the *gated* model
    ms_like = proxy.predict(model, B=B, S=S, T=T)
    soft_ms = float(ms_like.detach().item()) if torch.is_tensor(ms_like) else float(ms_like)

    # Multiplicative update: valid even when scale_ms was already calibrated once.
    proxy.scale_ms = float(proxy.scale_ms * real_ms / max(soft_ms, 1e-9))
    return real_ms


@torch.inference_mode()
def calibrate_proxy_llm_from_batch(
    proxy: LatencyProxyLLM,
    model: nn.Module,
    batch: Dict[str, torch.Tensor],
    *,
    T: int,
    export_keepall_fn,
    device: str = "cuda",
    warmup: int = 10,
    iters: int = 30,
) -> Tuple[int, int, int, float]:
    """
    Infers (B,S) from a batch like {'input_ids': [B,S], ...},
    calibrates for (B,S,T), and returns (B, S, T, real_ms).
    """
    input_ids = batch["input_ids"]
    B, S = int(input_ids.size(0)), int(input_ids.size(1))
    ms = calibrate_proxy_llm(
        proxy, model, B=B, S=S, T=T, export_keepall_fn=export_keepall_fn,
        device=device, warmup=warmup, iters=iters
    )
    return B, S, T, ms
core/search_export.py
ADDED
@@ -0,0 +1,76 @@
"""Export-parameter search (hardware-aware).
|
| 2 |
+
|
| 3 |
+
This module performs a small grid search over export rounding/multiple knobs and
|
| 4 |
+
picks the configuration that minimizes *measured* latency for the target batch
|
| 5 |
+
shape. It is family-agnostic; adapters provide the export function.
|
| 6 |
+
|
| 7 |
+
For ViT, see `vit_search_best_export` which scans per-head multiples and FFN
|
| 8 |
+
snap group sizes, mirroring kernel-friendly widths.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Callable, Iterable, List, Optional, Sequence, Tuple
|
| 14 |
+
|
| 15 |
+
import copy
|
| 16 |
+
import itertools
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
|
| 21 |
+
from .export import ExportPolicy as CoreExportPolicy, Rounding as CoreRounding
|
| 22 |
+
from .profiler import measure_latency_ms, ProfileSettings
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Type alias: adapter export function
|
| 26 |
+
ExportFn = Callable[[nn.Module, object, int], nn.Module]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class SearchResult:
|
| 31 |
+
best_model: nn.Module
|
| 32 |
+
best_params: dict
|
| 33 |
+
trials: List[dict]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def grid_search_latency(
|
| 37 |
+
model_with_gates: nn.Module,
|
| 38 |
+
export_fn: ExportFn,
|
| 39 |
+
*,
|
| 40 |
+
head_multiples: Sequence[int],
|
| 41 |
+
ffn_snaps: Sequence[int],
|
| 42 |
+
step: int,
|
| 43 |
+
batch_shape: Tuple[int, int, int, int], # (B,C,H,W)
|
| 44 |
+
measure_settings: Optional[ProfileSettings] = None,
|
| 45 |
+
device: str = "cuda",
|
| 46 |
+
make_policy: Optional[Callable[[int, int], object]] = None,
|
| 47 |
+
) -> SearchResult:
|
| 48 |
+
"""Generic grid search over (head_multiple, ffn_snap_groups).
|
| 49 |
+
|
| 50 |
+
- `make_policy(h_mult, ffn_snap)` must return an adapter-acceptable export policy.
|
| 51 |
+
If not provided, falls back to a single-rounding `CoreExportPolicy` using
|
| 52 |
+
`multiple_groups=head_multiple` for both heads and FFN.
|
| 53 |
+
"""
|
| 54 |
+
trials: List[dict] = []
|
| 55 |
+
best = None
|
| 56 |
+
|
| 57 |
+
to_try = itertools.product(head_multiples, ffn_snaps)
|
| 58 |
+
for i, (hm, fs) in enumerate(to_try):
|
| 59 |
+
policy = make_policy(hm, fs) if make_policy is not None else CoreExportPolicy(
|
| 60 |
+
warmup_steps=0,
|
| 61 |
+
rounding=CoreRounding(floor_groups=1, multiple_groups=int(hm), min_keep_ratio=0.0),
|
| 62 |
+
)
|
| 63 |
+
slim = export_fn(model_with_gates, policy, step)
|
| 64 |
+
mean_ms, p95_ms = measure_latency_ms(slim, batch_shape, settings=measure_settings, device=device)
|
| 65 |
+
rec = {"head_multiple": int(hm), "ffn_snap": int(fs), "mean_ms": float(mean_ms), "p95_ms": float(p95_ms)}
|
| 66 |
+
print(f"[{i}/{len(list(to_try))}] head_multiple {int(hm)} | ffn_snap {int(fs)} | mean_ms = {float(mean_ms)}")
|
| 67 |
+
trials.append(rec)
|
| 68 |
+
if best is None or mean_ms < best[0]:
|
| 69 |
+
best = (mean_ms, hm, fs, slim)
|
| 70 |
+
|
| 71 |
+
assert best is not None
|
| 72 |
+
_, hm_best, fs_best, slim_best = best
|
| 73 |
+
return SearchResult(best_model=slim_best, best_params={"head_multiple": int(hm_best), "ffn_snap": int(fs_best)}, trials=trials)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
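

# Usage sketch: search kernel-friendly widths for a gated student. The export
# function and policy factory come from the family adapter; a 3x3 grid over
# head multiples and FFN snap sizes is an illustrative choice, not a default.
def _example_grid_search(student: nn.Module, export_fn: ExportFn, make_policy) -> SearchResult:
    result = grid_search_latency(
        student,
        export_fn,
        head_multiples=(1, 2, 4),
        ffn_snaps=(32, 64, 128),
        step=10_000,
        batch_shape=(8, 3, 224, 224),
        make_policy=make_policy,
    )
    print(result.best_params, min(t["mean_ms"] for t in result.trials))
    return result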
core/train.py
ADDED
@@ -0,0 +1,327 @@
"""Generic Lagrangian trainer (family-agnostic).
|
| 2 |
+
|
| 3 |
+
This module provides a light framework to optimize *gated* students against
|
| 4 |
+
teachers with a latency target enforced via a proxy + optional real probes.
|
| 5 |
+
|
| 6 |
+
It does not assume ViT/ResNet/LLM specifics; adapters provide tiny callables.
|
| 7 |
+
|
| 8 |
+
Key ingredients:
|
| 9 |
+
- Two-phase update per step: (A) weights w.r.t. KD/task, (B) gates w.r.t. KD +
|
| 10 |
+
sparsity + latency penalty with a dual variable λ.
|
| 11 |
+
- Optional periodic export + real-latency probe to correct λ.
|
| 12 |
+
- Constraint projection for gates after each step.
|
| 13 |
+
|
| 14 |
+
Adapters must provide:
|
| 15 |
+
- get_student_logits(model, x) -> Tensor
|
| 16 |
+
- get_teacher_logits(model, x) -> Tensor
|
| 17 |
+
- export_keepall(model) -> nn.Module (clean copy without gates)
|
| 18 |
+
- export_pruned(model, policy, step) -> nn.Module (transient copy for profiling)
|
| 19 |
+
|
| 20 |
+
Core modules used:
|
| 21 |
+
- `distill.KDConfig`, `distill.kd_loss`
|
| 22 |
+
- `gates.combined_penalty`, `gates.PenaltyWeights`, `gates.project_gates_into_constraints`
|
| 23 |
+
- `proxy_cost.LatencyProxy`
|
| 24 |
+
- `profiler.measure_latency_ms`
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from typing import Callable, Optional
|
| 30 |
+
import gc
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
|
| 35 |
+
from .distill import KDConfig, kd_loss, mse_reg
|
| 36 |
+
from .gates import PenaltyWeights, Constraints, combined_penalty, project_gates_into_constraints, collect_param_groups
|
| 37 |
+
from .proxy_cost import LatencyProxy
|
| 38 |
+
from .profiler import measure_latency_ms
|
| 39 |
+
|
| 40 |
+
# -----------------------------------------------------------------------------
|
| 41 |
+
# Config
|
| 42 |
+
# -----------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
@dataclass
|
| 45 |
+
class DualConfig:
|
| 46 |
+
lr: float = 0.05 # step for λ update
|
| 47 |
+
ema_beta: float = 0.5 # blend proxy-driven λ and real probe λ
|
| 48 |
+
clip: float = 10.0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class TrainerConfig:
|
| 53 |
+
kd: KDConfig = KDConfig()
|
| 54 |
+
penalties: PenaltyWeights = PenaltyWeights(l0=0.0, keep_floor_ratio=0.0, bimodality=0.0)
|
| 55 |
+
constraints: Constraints = Constraints(min_keep_ratio=0.0, min_groups=1, max_groups_drop=None)
|
| 56 |
+
|
| 57 |
+
latency_target_ms: float = 30.0
|
| 58 |
+
real_probe_every: int = 0 # steps; 0 disables real probes
|
| 59 |
+
probe_batch_override: Optional[int] = None
|
| 60 |
+
gate_warmup_steps: int = 0 # Freeze gates for early steps
|
| 61 |
+
mse_weight: float = 0.0
|
| 62 |
+
|
| 63 |
+
early_stopping_patience: int = 0
|
| 64 |
+
early_stopping_lambda: float = 1e-4
|
| 65 |
+
|
| 66 |
+
amp: bool = True
|
| 67 |
+
device: str = "cuda"
|
| 68 |
+
|
| 69 |
+
# Optimizers
|
| 70 |
+
lr_gate: float = 1e-2
|
| 71 |
+
lr_linear: float = 1e-4
|
| 72 |
+
lr_affine: float = 3e-4
|
| 73 |
+
wd_linear: float = 1e-4
|
| 74 |
+
|
| 75 |
+
# Mixed precision scaler
|
| 76 |
+
use_grad_scaler: bool = True
|
| 77 |
+
|
| 78 |
+
# Dual update
|
| 79 |
+
dual: DualConfig = DualConfig()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
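# Example configuration (illustrative values, not defaults): a 25 ms target,
# a real-latency probe every 200 steps, gates frozen for the first 100 steps.
def _example_trainer_config() -> TrainerConfig:
    return TrainerConfig(
        latency_target_ms=25.0,
        real_probe_every=200,
        gate_warmup_steps=100,
        mse_weight=0.1,
    )
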
# -----------------------------------------------------------------------------
# Trainer
# -----------------------------------------------------------------------------

class LagrangeTrainer:
    def __init__(
        self,
        student: nn.Module,
        teacher: nn.Module,
        proxy: LatencyProxy,
        *,
        adapter_get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
        adapter_get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
        adapter_export_keepall: Callable[[nn.Module], nn.Module],
        adapter_export_pruned: Callable[[nn.Module, object, int], nn.Module],
        export_policy: object,
        cfg: TrainerConfig,
    ) -> None:
        self.student = student
        self.teacher = teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad_(False)
        self.proxy = proxy
        self.get_s = adapter_get_student_logits
        self.get_t = adapter_get_teacher_logits
        self.export_keepall = adapter_export_keepall
        self.export_pruned = adapter_export_pruned
        self.export_policy = export_policy
        self.cfg = cfg

        # Build optimizers (grouped)
        param_groups = collect_param_groups(
            student,
            lr_gate=cfg.lr_gate,
            lr_linear=cfg.lr_linear,
            lr_affine=cfg.lr_affine,
            wd_linear=cfg.wd_linear,
        )
        # gates-only optimizer uses the first group
        self.opt_g = torch.optim.Adam([param_groups[0]], lr=param_groups[0]["lr"])  # type: ignore[arg-type]
        # weights optimizer for the rest
        self.opt_w = torch.optim.Adam(param_groups[1:])

        self.scaler = torch.amp.GradScaler('cuda', enabled=(cfg.amp and cfg.use_grad_scaler))
        self.lambda_: float = 0.0
        self.mse_weight = cfg.mse_weight

    # ---- internal helpers -----------------------------------------------------
    def _zero_grads(self, params):
        for p in params:
            if p.grad is not None:
                p.grad = None

    def _has_grad(self, params) -> bool:
        for p in params:
            if p.grad is not None:
                return True
        return False

    # ---- training -------------------------------------------------------------
    def train_epoch(self, loader, *, real_policy=None, verbose_every: int = 50):
        device = self.cfg.device
        self.student.train().to(device)
        self.teacher.to(device).eval()

        running = 0.0
        seen = 0
        lam_real = self.lambda_
        mean_ms = p95_ms = float("nan")  # only set once a real probe has run

        total_steps = len(loader)

        for step, batch in enumerate(loader, 1):
            # Move batch to device in a type-safe way
            batch = _move_batch_to_device(batch, device)

            with torch.no_grad():
                t_logits = self.get_t(self.teacher, batch)  # [B,1,V]
                # match AMP compute dtype to avoid upcasting later
                if self.cfg.amp:
                    # infer autocast dtype from student params (bf16 or fp16)
                    sparam = next(self.student.parameters())
                    t_logits = t_logits.to(dtype=sparam.dtype, non_blocking=True)

            # -------- Pass A: WEIGHTS (KD only) --------
            self.opt_w.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=self.cfg.amp):
                # Adapters receive the batch object (dict/tuple/tensor)
                s_logits = self.get_s(self.student, batch)
                mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
                loss_w = kd_loss(s_logits, t_logits, self.cfg.kd) + mse

            self.scaler.scale(loss_w).backward()
            # Prevent gate params from changing in pass A
            gate_params = self.opt_g.param_groups[0]["params"]
            self._zero_grads(gate_params)

            if any(p.grad is not None for pg in self.opt_w.param_groups for p in pg["params"]):
                self.scaler.step(self.opt_w)
                self.scaler.update()
            else:
                self.opt_w.zero_grad(set_to_none=True)

            del s_logits
            gc.collect()
            torch.cuda.empty_cache()

            if step > int(self.cfg.gate_warmup_steps):
                # -------- Pass B: GATES (KD + sparsity + λ * gap) --------
                self.opt_g.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=self.cfg.amp):
                    s_logits = self.get_s(self.student, batch)
                    kd_g = kd_loss(s_logits, t_logits, self.cfg.kd)

                    # The proxy also gets the batch object; a family-specific proxy can read (B,S) etc.
                    o1_ms = self.proxy.predict(self.student, batch)
                    gap = torch.relu(o1_ms - float(self.cfg.latency_target_ms))
                    reg = combined_penalty(self.student, self.cfg.penalties)
                    mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
                    loss_g = kd_g + _to_tensor(self.lambda_, o1_ms) * gap + reg + mse

                self.scaler.scale(loss_g).backward()
                # Prevent non-gate params from changing in pass B
                for pg in self.opt_w.param_groups:
                    self._zero_grads(pg["params"])

                if self._has_grad(self.opt_g.param_groups[0]["params"]):
                    self.scaler.step(self.opt_g)
                    self.scaler.update()
                else:
                    self.opt_g.zero_grad(set_to_none=True)
            else:
                o1_ms = self.proxy.predict(self.student, batch)
                s_logits = loss_g = kd_g = reg = torch.tensor(0.0, device=device)

            # -------- Dual (λ) update using the proxy --------
            with torch.no_grad():
                lam_proxy = max(0.0, self.lambda_ + self.cfg.dual.lr * (float(o1_ms.detach()) - self.cfg.latency_target_ms))
                beta = float(self.cfg.dual.ema_beta)
                self.lambda_ = beta * lam_real + (1.0 - beta) * lam_proxy

            # -------- Constraint projection, optional real probe --------
            project_gates_into_constraints(self.student, self.cfg.constraints)

            if self.cfg.real_probe_every and (step % int(self.cfg.real_probe_every) == 0):
                # Build a probe shape for the latency function if needed
                try:
                    from core.measure import measure_latency_text_ms  # text-friendly
                    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
                        B, S = int(batch["input_ids"].size(0)), int(batch["input_ids"].size(1))
                    else:
                        # Fallback: try a tensor-like batch
                        x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
                        B = int(x0.size(0)); S = int(x0.size(1))
                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
                    mean_ms, p95_ms = measure_latency_text_ms(slim, B=B, S=S, T=128, device=device)
                except Exception:
                    # If the project has a different profiler, retain compatibility:
                    from .profiler import measure_latency_ms
                    x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
                    shape = (int(x0.size(0)), *list(x0.shape[1:]))
                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
                    mean_ms, p95_ms = measure_latency_ms(slim, shape, device=device)

                with torch.no_grad():
                    lam_real = max(0.0, self.lambda_ + self.cfg.dual.lr * (mean_ms - self.cfg.latency_target_ms))

                # Optional: online proxy re-scaling toward the measured latency.
                # scale_correction = mean_ms / max(1e-9, o1_ms.detach())
                # self.proxy.cfg.scale_ms = 0.9 * self.proxy.cfg.scale_ms + 0.1 * scale_correction * self.proxy.cfg.scale_ms

            if (step % verbose_every) == 0:
                print(
                    f"Step {step}/{total_steps} | KL={float(loss_w.item()):.6f} | MSE={float(mse.item()):.6f} | "
                    f"Gate={float(loss_g.item()):.6f} | "
                    f"proxy={float(o1_ms.detach()):.3f}ms | real_mean={mean_ms:.3f}ms p95={p95_ms:.3f}ms | λ={self.lambda_:.6f}"
                )

            running += float(loss_g.detach())
            seen += _batch_size(batch)

            del s_logits, t_logits, o1_ms, kd_g, reg, loss_g, loss_w
            torch.cuda.empty_cache()
            gc.collect()

        print(f"Epoch loss {running / max(1, seen):.6f}")
        return self.lambda_

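# Wiring sketch: the trainer only sees four adapter callables, so the same loop
# serves ViT/ResNet/LLM students. `adapters` is a stand-in object exposing the
# four functions described in the module docstring; it is not defined in core.
def _example_train(student, teacher, proxy, adapters, export_policy, loader) -> float:
    trainer = LagrangeTrainer(
        student, teacher, proxy,
        adapter_get_student_logits=adapters.get_student_logits,
        adapter_get_teacher_logits=adapters.get_teacher_logits,
        adapter_export_keepall=adapters.export_keepall,
        adapter_export_pruned=adapters.export_pruned,
        export_policy=export_policy,
        cfg=_example_trainer_config(),
    )
    for _ in range(3):  # a few epochs; λ carries over between epochs
        trainer.train_epoch(loader)
    return trainer.lambda_
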
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _to_tensor(val: float, like: torch.Tensor) -> torch.Tensor:
    return torch.as_tensor(val, device=like.device, dtype=like.dtype)


def _move_batch_to_device(batch, device: str):
    """
    Supports:
      - dict with key 'input_ids' and optional 'attention_mask'
      - (x,) or (x, y) tuples/lists -> move each tensor-like to device
      - single Tensor
    Converts attention_mask to bool (preferred by HF SDPA).
    """
    if isinstance(batch, dict):
        out = {}
        for k, v in batch.items():
            if torch.is_tensor(v):
                v = v.to(device, non_blocking=True)
                if k == "attention_mask" and v.dtype != torch.bool:
                    v = v.to(torch.bool)
            out[k] = v
        return out

    if isinstance(batch, (tuple, list)):
        moved = []
        for v in batch:
            if torch.is_tensor(v):
                v = v.to(device, non_blocking=True)
            moved.append(v)
        return type(batch)(moved)

    if torch.is_tensor(batch):
        return batch.to(device, non_blocking=True)

    # Unknown type: return as-is (adapters/proxy should handle it)
    return batch


def _batch_size(batch) -> int:
    """Best-effort batch size for logging/averages."""
    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
        return int(batch["input_ids"].size(0))
    if torch.is_tensor(batch):
        return int(batch.size(0))
    if isinstance(batch, (tuple, list)) and len(batch) and torch.is_tensor(batch[0]):
        return int(batch[0].size(0))
    return 1
core/utils.py
ADDED
@@ -0,0 +1,190 @@
"""Shared utilities used across core and adapters.
|
| 2 |
+
|
| 3 |
+
Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
|
| 4 |
+
parameter grouping, model copying, etc.). Keep this file dependency-light.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# -----------------------------------------------------------------------------
|
| 20 |
+
# Device / dtype helpers
|
| 21 |
+
# -----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
def as_like(x: torch.Tensor, val) -> torch.Tensor:
|
| 24 |
+
"""Create a scalar/tensor constant on same device/dtype as `x`."""
|
| 25 |
+
return torch.as_tensor(val, device=x.device, dtype=x.dtype)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def first_param(module: nn.Module) -> torch.Tensor:
|
| 29 |
+
for p in module.parameters(recurse=True):
|
| 30 |
+
return p
|
| 31 |
+
return torch.tensor(0.0)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
return x.to(device=ref.device, dtype=ref.dtype)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# -----------------------------------------------------------------------------
|
| 39 |
+
# Seeding & determinism
|
| 40 |
+
# -----------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def set_seed(seed: int = 42, deterministic: bool = False) -> None:
|
| 43 |
+
random.seed(seed)
|
| 44 |
+
np.random.seed(seed)
|
| 45 |
+
torch.manual_seed(seed)
|
| 46 |
+
torch.cuda.manual_seed_all(seed)
|
| 47 |
+
if deterministic:
|
| 48 |
+
torch.backends.cudnn.deterministic = True
|
| 49 |
+
torch.backends.cudnn.benchmark = False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# -----------------------------------------------------------------------------
|
| 53 |
+
# Model parameter helpers
|
| 54 |
+
# -----------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
def freeze(module: nn.Module) -> None:
|
| 57 |
+
for p in module.parameters():
|
| 58 |
+
p.requires_grad_(False)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def unfreeze(module: nn.Module) -> None:
|
| 62 |
+
for p in module.parameters():
|
| 63 |
+
p.requires_grad_(True)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
|
| 67 |
+
if trainable_only:
|
| 68 |
+
return sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 69 |
+
return sum(p.numel() for p in module.parameters())
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# -----------------------------------------------------------------------------
|
| 73 |
+
# Shape/signature helpers
|
| 74 |
+
# -----------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def input_spec_vision(sample) -> Tuple[int, int, int]:
|
| 77 |
+
"""Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
|
| 78 |
+
if isinstance(sample, torch.Tensor):
|
| 79 |
+
B, C, H, W = sample.shape
|
| 80 |
+
return int(B), int(H), int(W)
|
| 81 |
+
if isinstance(sample, (tuple, list)) and len(sample) == 4:
|
| 82 |
+
B, C, H, W = sample
|
| 83 |
+
return int(B), int(H), int(W)
|
| 84 |
+
raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# -----------------------------------------------------------------------------
|
| 88 |
+
# Rounding / multiples
|
| 89 |
+
# -----------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
def round_down_multiple(n: int, m: int) -> int:
|
| 92 |
+
if m is None or m <= 1:
|
| 93 |
+
return max(1, int(n))
|
| 94 |
+
n = int(n)
|
| 95 |
+
return max(m, (n // m) * m)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def clamp_int(v: int, lo: int, hi: int) -> int:
|
| 99 |
+
return max(lo, min(int(v), hi))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# -----------------------------------------------------------------------------
|
| 103 |
+
# Slicing helpers
|
| 104 |
+
# -----------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
|
| 107 |
+
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
|
| 108 |
+
W = mat.weight.detach()
|
| 109 |
+
b = mat.bias.detach() if mat.bias is not None else None
|
| 110 |
+
if keep_out is not None:
|
| 111 |
+
idx_out = torch.as_tensor(keep_out, device=W.device)
|
| 112 |
+
W = W.index_select(0, idx_out)
|
| 113 |
+
if b is not None:
|
| 114 |
+
b = b.index_select(0, idx_out)
|
| 115 |
+
if keep_in is not None:
|
| 116 |
+
idx_in = torch.as_tensor(keep_in, device=W.device)
|
| 117 |
+
W = W.index_select(1, idx_in)
|
| 118 |
+
out_f, in_f = W.shape
|
| 119 |
+
new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device)
|
| 120 |
+
new.weight.copy_(W)
|
| 121 |
+
if b is not None:
|
| 122 |
+
new.bias.copy_(b)
|
| 123 |
+
return new
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
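# Minimal sketch: keep the first half of a layer's output features. The kept
# index sets normally come from the gate-driven helpers in core.export.
def _example_slice_linear() -> nn.Linear:
    lin = nn.Linear(64, 128)
    kept = list(range(64))                  # keep output features 0..63
    small = slice_linear(lin, keep_out=kept)
    assert small.weight.shape == (64, 64)   # [out_features, in_features]
    return small
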
# -----------------------------------------------------------------------------
# Copying & detaching models
# -----------------------------------------------------------------------------

def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
    m = copy.deepcopy(module).cpu().eval()
    return m


# -----------------------------------------------------------------------------
# Gradient utilities
# -----------------------------------------------------------------------------

def zero_if_any(params: Iterable[torch.Tensor]) -> None:
    for p in params:
        if p.grad is not None:
            p.grad = None


def any_grad(params: Iterable[torch.Tensor]) -> bool:
    for p in params:
        if p.grad is not None:
            return True
    return False


# -----------------------------------------------------------------------------
# For fine-tuning
# -----------------------------------------------------------------------------

def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
    """
    Rebuild all parameters as fresh nn.Parameter tensors (detach + clone),
    which drops any 'inference tensor' tag and re-enables autograd.
    """
    for mod in module.modules():
        for name, p in list(mod._parameters.items()):
            if p is None:
                continue
            new_p = nn.Parameter(p.detach().clone(), requires_grad=requires_grad)
            setattr(mod, name, new_p)
    return module


# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------

@dataclass
class ExportRounding:
    head_floor_post: int = 1
    head_multiple_post: int = 1
    ffn_min_keep_ratio_post: float = 0.0
    ffn_snap_groups_post: int = 1


def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
    B, C, H, W = sample_shape
    patch = getattr(cfg, "patch_size", 16)
    patch = tuple(patch) if isinstance(patch, (tuple, list)) else int(patch)
    return (
        "ViT",
        sample_shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        patch,
    )
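

# Usage sketch for `ensure_trainable_parameters`: models exported under
# torch.inference_mode() carry inference tensors that reject autograd, so
# rebuild them before fine-tuning. `slim` is an illustrative exported student.
def _example_reenable_autograd(slim: nn.Module) -> nn.Module:
    slim = ensure_trainable_parameters(slim, requires_grad=True)
    assert all(p.requires_grad for p in slim.parameters())
    return slim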
custom_code.py
ADDED
@@ -0,0 +1 @@
# Marker file so Hub shows 'custom code' banner.
huggingface/.ipynb_checkpoints/llama-checkpoint.py
ADDED
@@ -0,0 +1,607 @@
"""HuggingFace LLaMA/Mistral adapter
|
| 2 |
+
|
| 3 |
+
Bridges the family-agnostic core (gates/export/proxy/train) to HF causal LMs
|
| 4 |
+
(LlamaForCausalLM / MistralForCausalLM, etc.).
|
| 5 |
+
|
| 6 |
+
Responsibilities
|
| 7 |
+
----------------
|
| 8 |
+
- Attach gates to attention Q heads (and optional KV) + grouped MLP (SwiGLU)
|
| 9 |
+
- Provide a logits getter (student/teacher)
|
| 10 |
+
- Exporters:
|
| 11 |
+
* keep-all (unwrap gates, restore clean HF modules)
|
| 12 |
+
* pruned (slice q_proj/o_proj and SwiGLU up/gate/down; update HF metadata)
|
| 13 |
+
- Grid-search wrapper for post-export rounding/snap params
|
| 14 |
+
|
| 15 |
+
This adapter intentionally keeps the core unaware of LLaMA internals.
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
# Ensure repo root on sys.path for absolute imports (core, adapters, data)
|
| 20 |
+
import sys, pathlib
|
| 21 |
+
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Optional, Sequence, Callable, Tuple
|
| 25 |
+
|
| 26 |
+
import copy
|
| 27 |
+
import math
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
|
| 32 |
+
# Core (absolute imports so running `-m examples.run_llama_optimize` works)
|
| 33 |
+
from core.gates import HeadGate, GroupGate
|
| 34 |
+
from core.export import (
|
| 35 |
+
ExportPolicy as CoreExportPolicy,
|
| 36 |
+
Rounding as CoreRounding,
|
| 37 |
+
keep_group_indices_from_gate,
|
| 38 |
+
slice_linear,
|
| 39 |
+
)
|
| 40 |
+
from core.utils import deepcopy_eval_cpu
|
| 41 |
+
from core.search_export import grid_search_latency
|
| 42 |
+
|
| 43 |
+
# -------------------------------------------------------------------------
|
| 44 |
+
# Configs
|
| 45 |
+
# -------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class LlamaGatingConfig:
|
| 49 |
+
tau: float = 1.5
|
| 50 |
+
init_logit: float = 3.0
|
| 51 |
+
head_gating: bool = True
|
| 52 |
+
gate_kv: bool = False # optional: gate KV along with Q
|
| 53 |
+
ffn_group: int = 128 # SwiGLU groups
|
| 54 |
+
ffn_gating: bool = True
|
| 55 |
+
hard_eval: bool = True # use hard gates in eval forward
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# -------------------------------------------------------------------------
|
| 59 |
+
# Helpers (GQA, rotary, cache-safe)
|
| 60 |
+
# -------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _last_nonpad_index(attn_mask: Optional[torch.Tensor], seq_len: int, device) -> torch.Tensor:
|
| 64 |
+
if attn_mask is None:
|
| 65 |
+
return torch.full((1,), seq_len - 1, device=device, dtype=torch.long) # will be expanded per-batch later
|
| 66 |
+
# attn_mask: [B, S] in {0,1}; works for left/right padding
|
| 67 |
+
return (attn_mask.sum(dim=1) - 1).clamp(min=0).long()
|
| 68 |
+
|
| 69 |
+
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 70 |
+
if n_rep == 1:
|
| 71 |
+
return x
|
| 72 |
+
B, Hkv, T, Dh = x.shape
|
| 73 |
+
return x.unsqueeze(2).expand(B, Hkv, n_rep, T, Dh).reshape(B, Hkv * n_rep, T, Dh)
|
| 74 |
+
|
| 75 |
+
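# Shape sketch for GQA expansion: 2 KV heads repeated 4x line up with 8 Q heads.
def _example_repeat_kv() -> torch.Size:
    kv = torch.randn(1, 2, 16, 64)    # [B, Hkv, T, Dh]
    out = _repeat_kv(kv, n_rep=4)     # -> [B, Hkv * 4, T, Dh]
    assert out.shape == (1, 8, 16, 64)
    return out.shape
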
try:
|
| 76 |
+
from transformers.cache_utils import Cache
|
| 77 |
+
except Exception:
|
| 78 |
+
class Cache: # type: ignore
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
|

# -------------------------------------------------------------------------
# Gated attention wrapper (Llama/Mistral ready)
# -------------------------------------------------------------------------

class GatedSelfAttentionLLM(nn.Module):
    """
    Thin wrapper around an HF Llama/Mistral attention module.

    - Uses the base module's q_proj/k_proj/v_proj/o_proj
    - Applies per-Q-head gates (and optional KV gates)
    - Handles rotary embeddings and the KV cache (tuple or HF Cache)
    - Runs SDPA directly, then o_proj
    """
    def __init__(self, attn_container: nn.Module,
                 num_q_heads: int, num_kv_heads: int, head_dim: int,
                 cfg: LlamaGatingConfig, layer_idx: int):
        super().__init__()
        self.base_attn = attn_container
        self.q_proj = attn_container.q_proj
        self.k_proj = attn_container.k_proj
        self.v_proj = attn_container.v_proj
        self.o_proj = getattr(attn_container, "o_proj", getattr(attn_container, "out_proj", None))

        self.num_q_heads = int(num_q_heads)
        self.num_kv_heads = int(num_kv_heads)
        self.head_dim = int(head_dim)
        self.gate_kv = bool(cfg.gate_kv)
        self.drop_p = float(getattr(attn_container, "attention_dropout",
                            getattr(attn_container, "attn_dropout",
                            getattr(attn_container, "dropout", 0.0))))
        self.head_gate = HeadGate(num_heads=self.num_q_heads,
                                  head_dim=self.head_dim,
                                  tau=cfg.tau, init_logit=cfg.init_logit,
                                  hard_during_eval=cfg.hard_eval)

        # rotary helpers if present on the base module
        self.rotary_emb = getattr(attn_container, "rotary_emb", None)
        self.apply_rotary_pos_emb = getattr(attn_container, "apply_rotary_pos_emb", None)
        self.layer_idx = int(layer_idx)

    @property
    def logits(self) -> torch.Tensor:
        return self.head_gate.logits

    def kept_heads_soft(self) -> torch.Tensor:
        p = self.head_gate.probs().detach().float().view(-1)
        if p.numel() == self.num_q_heads * self.head_dim:
            p = p.view(self.num_q_heads, self.head_dim).mean(dim=1)
        return p.sum()

    def forward(
        self,
        hidden_states: torch.Tensor,                    # [B,T,D]
        attention_mask: Optional[torch.Tensor] = None,  # additive mask [B,1,Tq,Tk] or None
        position_ids: Optional[torch.Tensor] = None,
        past_key_value=None,                            # tuple, list, Cache or None
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        B, T, D = hidden_states.shape
        Hq, Hkv, Dh = self.num_q_heads, self.num_kv_heads, self.head_dim
        assert Hq * Dh == D, "hidden_size must equal num_heads * head_dim"
        n_rep = max(1, Hq // Hkv)

        # qkv projections
        q = self.q_proj(hidden_states).view(B, T, Hq, Dh).transpose(1, 2)   # [B,Hq,T,Dh]
        k = self.k_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2)  # [B,Hkv,T,Dh]
        v = self.v_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2)  # [B,Hkv,T,Dh]

        # rotary
        if (self.rotary_emb is not None) and (self.apply_rotary_pos_emb is not None):
            Tpast = 0
            if isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 2:
                Tpast = int(past_key_value[0].size(2))
            elif isinstance(past_key_value, Cache):
                Tpast = int(cache_position.max().item() if cache_position is not None else 0)
            seq_len = Tpast + T
            try:
                cos, sin = self.rotary_emb(v, seq_len=seq_len)
            except TypeError:
                cos, sin = self.rotary_emb(q, seq_len=seq_len)
            # try the rich signature first, then fall back for older transformers
            try:
                q, k = self.apply_rotary_pos_emb(
                    q, k, cos, sin,
                    position_ids=position_ids,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
            except TypeError:
                try:
                    q, k = self.apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids)
                except TypeError:
                    q, k = self.apply_rotary_pos_emb(q, k, cos, sin)

        # cache merge
        present = None
        if past_key_value is None or (isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 0):
            pass
        elif isinstance(past_key_value, (tuple, list)):
            pk, pv = past_key_value  # [B,Hkv,Tpast,Dh]
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            present = (k, v) if use_cache else None
        elif isinstance(past_key_value, Cache):
            k, v = past_key_value.update(k, v, self.layer_idx, cache_position)
            present = past_key_value

        # ---- gates (supports per-head OR per-channel HeadGate) ----
        m = self.head_gate.mask(self.training)  # 1D tensor
        m = m.detach() if not self.training else m
        if m.numel() == Hq:
            # per-head gating
            gH = m.view(1, Hq, 1, 1)  # [1,Hq,1,1]
            q = q * gH
            if self.gate_kv:
                if n_rep == 1:
                    k = k * gH; v = v * gH
                else:
                    g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
                    k = k * g_kv; v = v * g_kv
        elif m.numel() == Hq * Dh:
            # per-channel gating
            gHD = m.view(1, Hq, 1, Dh)  # [1,Hq,1,Dh]
            q = q * gHD
            if self.gate_kv:
                # collapse to per-head for KV, then map to Hkv via amax over replicas
                gH = gHD.amax(dim=-1, keepdim=True)  # [1,Hq,1,1]
                if n_rep == 1:
                    k = k * gH; v = v * gH
                else:
                    g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
                    k = k * g_kv; v = v * g_kv
        else:
            raise RuntimeError(
                f"HeadGate mask has {m.numel()} elems; expected {Hq} or {Hq*Dh}"
            )

        # GQA: replicate KV to the Q head count
        k = _repeat_kv(k, n_rep)
        v = _repeat_kv(v, n_rep)

        # SDPA forbids an explicit attn_mask together with is_causal=True, so
        # enable implicit causality only when no mask is supplied.
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.drop_p if self.training else 0.0,
            is_causal=attention_mask is None,
        )
        out = attn.transpose(1, 2).contiguous().view(B, T, Hq * Dh)
        out = self.o_proj(out)

        # HF expects (attn_output, attn_weights, present_key_value); we never
        # materialize attention weights, so that slot is always None.
        attn_weights = None
        return (out, attn_weights, present)
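
# --- Straight-through gating, in miniature --------------------------------
# HeadGate lives in core.gates; the sketch below mirrors the Gumbel-sigmoid
# straight-through estimator that the ViT wrapper later in this repo spells
# out inline: hard 0/1 masks in the forward pass, sigmoid gradients in the
# backward pass. Illustrative only.
def _example_straight_through_mask(logits: torch.Tensor, tau: float = 1.5) -> torch.Tensor:
    u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
    s = u.log() - (1 - u).log()            # Logistic(0,1) noise
    y = torch.sigmoid((logits + s) / tau)  # soft, differentiable gate
    # forward value is the hard threshold; gradients flow through y
    return ((y > 0.5).to(y.dtype) - y).detach() + y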

# -------------------------------------------------------------------------
# Adapter
# -------------------------------------------------------------------------

class LlamaAdapter:
    def __init__(self, model: nn.Module):
        self.model = model
        core = getattr(model, "model", model)
        if not hasattr(core, "layers"):
            raise ValueError("Provided model does not look like HF LLaMA/Mistral (missing .model.layers or .layers)")

    # ---------- Gating attachment ----------
    def attach_gates(self, cfg: LlamaGatingConfig) -> nn.Module:
        m = self.model
        core = getattr(m, "model", m)
        layers = core.layers

        Hq = int(core.config.num_attention_heads)
        Hkv = int(getattr(core.config, "num_key_value_heads", Hq))
        Dh = int(core.config.hidden_size // Hq)

        for li, layer in enumerate(layers):
            # Attention heads
            if cfg.head_gating:
                base = layer.self_attn
                if not isinstance(base, GatedSelfAttentionLLM):
                    gated = GatedSelfAttentionLLM(
                        attn_container=base,
                        num_q_heads=Hq,
                        num_kv_heads=Hkv,
                        head_dim=Dh,
                        cfg=cfg,
                        layer_idx=li,
                    )
                    layer.self_attn = gated  # route through our wrapper

            # MLP grouped gating (SwiGLU)
            if cfg.ffn_gating:
                mlp = layer.mlp
                I = int(mlp.up_proj.out_features)
                assert I % cfg.ffn_group == 0, f"SwiGLU size {I} not divisible by group {cfg.ffn_group}"
                if not hasattr(mlp, "neuron_gate"):
                    mlp.neuron_gate = GroupGate(
                        num_groups=I // cfg.ffn_group,
                        group_size=cfg.ffn_group,
                        tau=cfg.tau, init_logit=cfg.init_logit,
                        hard_during_eval=cfg.hard_eval,
                    )
                if not hasattr(mlp, "_orig_forward"):
                    mlp._orig_forward = mlp.forward

                def _gated_mlp_forward(this, x):
                    # HF LLaMA MLP: down(silu(gate(x)) * up(x)); the learned
                    # group mask m scales the hidden activation channel-wise.
                    # (silu is applied to the gate branch, matching HF's
                    # SwiGLU convention.)
                    g = this.gate_proj(x)
                    u = this.up_proj(x)
                    m = this.neuron_gate.mask(this.training).view(1, 1, -1)
                    z = F.silu(g) * (u * m)
                    return this.down_proj(z)

                mlp.forward = _gated_mlp_forward.__get__(mlp, mlp.__class__)
        return m
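
    # A shape sketch for the grouped FFN mask above: a GroupGate over an
    # intermediate size I with group size G yields I/G logits, broadcast to a
    # [1, 1, I] channel mask. Hypothetical numbers; illustrative only.
    @staticmethod
    def _example_grouped_ffn_mask():
        I, G = 512, 128                                 # intermediate size, group size
        group_mask = torch.tensor([1., 0., 1., 1.])     # one value per group (I // G = 4)
        channel_mask = group_mask.repeat_interleave(G)  # -> [I], constant within each group
        assert channel_mask.shape == (I,) and channel_mask[:G].eq(1).all()
        return channel_mask.view(1, 1, -1)              # broadcastable over [B, T, I]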

    # ---------- Logits helper ----------
    @staticmethod
    def _last_token_index(attn_mask: torch.Tensor) -> torch.Tensor:
        # attn_mask: [B, S] with 1 for tokens, 0 for padding (bool or int).
        # Returns [B] indices of the last non-pad position, or None if no mask.
        if attn_mask is None:
            # no mask -> the caller uses the last position S-1
            return None
        if attn_mask.dtype != torch.long:
            attn_mask = attn_mask.to(torch.long)
        # idx = lengths - 1
        return (attn_mask.sum(dim=-1) - 1).clamp_min(0)

    @staticmethod
    def get_logits(model: nn.Module,
                   input_ids: torch.Tensor,
                   attention_mask: Optional[torch.Tensor] = None,
                   last_only: bool = True,
                   **forward_kwargs) -> torch.Tensor:
        """
        Returns logits. If last_only=True, computes ONLY the last-token logits by:
          1) getting hidden states from the base decoder,
          2) selecting the last non-pad position per sample,
          3) projecting that single position through lm_head.
        This avoids allocating a full [B,S,V] tensor.
        """
        # (1) run the base decoder, not the full CausalLM head
        core = getattr(model, "model", None)
        if core is None:
            # fallback if the model is already a bare decoder (rare)
            core = model

        # We only need last_hidden_state; no cache; return_dict=False grabs a
        # tuple and avoids extra allocations.
        outputs = core(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=False,
            **forward_kwargs,
        )
        hidden = outputs[0]  # [B, S, D]

        if not last_only:
            # If someone explicitly wants all logits, fine:
            return model.lm_head(hidden)  # [B,S,V] (expensive!)

        # (2) select the last token per sample
        B, S, D = hidden.shape
        if attention_mask is None:
            idx = torch.full((B,), S - 1, device=hidden.device, dtype=torch.long)
        else:
            idx = LlamaAdapter._last_token_index(attention_mask)

        # gather the last hidden state: [B, D]
        last_h = hidden[torch.arange(B, device=hidden.device), idx]
        # (3) project that single position to logits
        return model.lm_head(last_h).unsqueeze(1)  # [B,1,V]
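
    # Usage sketch for the last-token logits path (model id is a placeholder;
    # LLaMA tokenizers usually need a pad token assigned before padding).
    @staticmethod
    def _example_get_logits():
        from transformers import AutoModelForCausalLM, AutoTokenizer
        tok = AutoTokenizer.from_pretrained("some-org/some-llama")  # hypothetical id
        tok.pad_token = tok.pad_token or tok.eos_token
        mdl = AutoModelForCausalLM.from_pretrained("some-org/some-llama").eval()
        batch = tok(["Hello", "Hello world"], padding=True, return_tensors="pt")
        with torch.no_grad():
            logits = LlamaAdapter.get_logits(mdl, batch["input_ids"],
                                             attention_mask=batch["attention_mask"])
        return logits.shape  # (B, 1, vocab_size)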

    # ---------- Exporters ----------
    @staticmethod
    @torch.no_grad()
    def export_keepall(model_with_gates: nn.Module) -> nn.Module:
        """
        Unwrap attention wrappers, restore the original MLP.forward, drop gates.
        """
        slim = deepcopy_eval_cpu(model_with_gates)
        core = getattr(slim, "model", slim)
        if not hasattr(core, "layers"):
            return slim

        for layer in core.layers:
            # attention
            attn = layer.self_attn
            if isinstance(attn, GatedSelfAttentionLLM):
                gat = attn
                new_attn = copy.deepcopy(gat.base_attn)
                # keep metadata consistent
                if hasattr(new_attn, "num_heads"):
                    new_attn.num_heads = int(gat.num_q_heads)
                if hasattr(new_attn, "num_key_value_heads"):
                    new_attn.num_key_value_heads = int(gat.num_kv_heads)
                if hasattr(new_attn, "head_dim"):
                    new_attn.head_dim = int(gat.head_dim)
                layer.self_attn = new_attn

            # mlp
            mlp = layer.mlp
            if hasattr(mlp, "_orig_forward"):
                mlp.forward = mlp._orig_forward
                delattr(mlp, "_orig_forward")
            if hasattr(mlp, "neuron_gate"):
                delattr(mlp, "neuron_gate")

        return slim
    @staticmethod
    @torch.no_grad()
    def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
        """
        Produce a clean CPU eval model:
          - Read gates to choose Q heads; slice q_proj rows and o_proj columns
          - Snap the kept head count to an LCM of (policy multiple, Hkv)
          - Slice SwiGLU up/gate/down by groups
          - Unwrap back to plain HF modules; update metadata
        """
        # Accept either CoreExportPolicy (single rounding) or the family policy
        if isinstance(policy, LlamaExportPolicy):
            head_rounding = policy.head_rounding
            ffn_rounding = policy.ffn_rounding
            warmup_steps = policy.warmup_steps
        else:
            head_rounding = getattr(policy, "rounding", None)
            ffn_rounding = getattr(policy, "rounding", None)
            warmup_steps = int(getattr(policy, "warmup_steps", 0))

        slim = deepcopy_eval_cpu(model_with_gates)
        core = getattr(slim, "model", slim)
        layers = getattr(core, "layers", None)
        if layers is None:
            return slim

        warm = (step < warmup_steps)

        def _lcm(a: int, b: int) -> int:
            return abs(a * b) // math.gcd(max(a, 1), max(b, 1)) if a > 0 and b > 0 else max(a, b, 1)

        for li, layer in enumerate(layers):
            # ---- Attention (Q heads) ----
            attn = layer.self_attn
            if isinstance(attn, GatedSelfAttentionLLM):
                gat = attn
                base = gat.base_attn

                Hq = int(gat.num_q_heads)
                Hkv = int(gat.num_kv_heads)
                Dh = int(gat.head_dim)

                if warm:
                    keep_idx = torch.arange(Hq)
                else:
                    # Build a per-head proxy gate if the base gate is per-channel.
                    base_logits = gat.head_gate.logits.detach().float().view(-1)
                    tau = float(getattr(gat.head_gate, "tau", 1.0))

                    if base_logits.numel() == Hq:
                        # Native per-head gate: use as-is
                        per_head_logits = base_logits
                        keep_idx = keep_group_indices_from_gate(
                            gat.head_gate, policy=policy, step=step, custom_rounding=head_rounding
                        )
                    elif base_logits.numel() == Hq * Dh:
                        # Collapse per-channel -> per-head (mean; use .amax for stricter)
                        per_head_logits = base_logits.view(Hq, Dh).mean(dim=1)

                        class _PerHeadProxyGate:
                            def __init__(self, logits, tau):
                                self.logits = logits
                                self.tau = tau
                                self.num_groups = logits.numel()
                                self.group_size = 1

                        proxy_gate = _PerHeadProxyGate(per_head_logits, tau)
                        keep_idx = keep_group_indices_from_gate(
                            proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
                        )
                    else:
                        raise RuntimeError(
                            f"Unexpected HeadGate logits len {base_logits.numel()} vs H={Hq} or H*Dh={Hq*Dh}"
                        )

                    # Enforce an LCM with GQA (Hkv): truncate to the floor multiple
                    pol_mult = getattr(head_rounding, "multiple_groups", 1)
                    snap = _lcm(int(pol_mult), max(1, Hkv))
                    if keep_idx.numel() % snap != 0:
                        k = (keep_idx.numel() // snap) * snap
                        k = max(snap, min(Hq, k))
                        # recompute top-k by per-head logits (same criterion as above)
                        keep_idx = torch.topk(per_head_logits, k=k, largest=True).indices.sort().values

                H_keep = int(keep_idx.numel())
                # channel indices for q/o slicing
                ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in keep_idx]).long()

                # slice the wrapper's linears
                gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
                gat.o_proj = slice_linear(gat.o_proj, keep_in=ch_idx)

                # transplant into a clean HF attention module
                new_attn = copy.deepcopy(base)
                if hasattr(new_attn, "q_proj"):
                    new_attn.q_proj = gat.q_proj
                if hasattr(new_attn, "o_proj"):
                    new_attn.o_proj = gat.o_proj
                elif hasattr(new_attn, "out_proj"):
                    new_attn.out_proj = gat.o_proj

                # update metadata; the residual width D is unchanged by head
                # pruning (o_proj still maps back to D), so config.hidden_size
                # is left alone and only per-module head counts shrink
                if hasattr(new_attn, "num_heads"):
                    new_attn.num_heads = int(H_keep)
                if hasattr(new_attn, "num_key_value_heads"):
                    new_attn.num_key_value_heads = int(Hkv)
                if hasattr(new_attn, "head_dim"):
                    new_attn.head_dim = int(Dh)

                layer.self_attn = new_attn  # unwrap

            # ---- MLP (SwiGLU grouped) ----
            mlp = layer.mlp
            g = getattr(mlp, "neuron_gate", None)
            if g is not None:
                grp_idx = keep_group_indices_from_gate(
                    g, policy=policy, step=step, custom_rounding=ffn_rounding,
                )
                group = int(g.group_size)  # GroupGate exposes group_size
                keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()

                mlp.up_proj = slice_linear(mlp.up_proj, keep_out=keep_exp)
                mlp.gate_proj = slice_linear(mlp.gate_proj, keep_out=keep_exp)
                mlp.down_proj = slice_linear(mlp.down_proj, keep_in=keep_exp)

                # restore the clean forward & drop the gate
                if hasattr(mlp, "_orig_forward"):
                    mlp.forward = mlp._orig_forward
                    delattr(mlp, "_orig_forward")
                if hasattr(mlp, "neuron_gate"):
                    delattr(mlp, "neuron_gate")

        return slim

# -------------------------------------------------------------------------
# Export policy (allows different rounding for heads vs FFN)
# -------------------------------------------------------------------------

@dataclass
class LlamaExportPolicy:
    warmup_steps: int = 0
    head_rounding: CoreRounding = CoreRounding()  # e.g., CoreRounding(floor=8, multiple=8)
    ffn_rounding: CoreRounding = CoreRounding()   # e.g., CoreRounding(min_keep_ratio=0.8, multiple=32)
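
# Usage sketch for the pruned exporter (illustrative; CoreRounding's fields
# are defined in core.export, so the kwargs hinted at in the comments above
# are whatever that dataclass actually accepts).
def _example_export_pruned(gated_model):
    policy = LlamaExportPolicy(
        warmup_steps=0,
        head_rounding=CoreRounding(),  # defaults; tune per core.export
        ffn_rounding=CoreRounding(),
    )
    slim = LlamaAdapter.export_pruned(gated_model, policy, step=10_000)
    return slim  # clean CPU eval model with sliced q/o and SwiGLU projections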

# -------------------------------------------------------------------------
# Grid-search convenience
# -------------------------------------------------------------------------

@dataclass
class LlamaGrid:
    head_multiple_grid: Optional[Sequence[int]] = (1, 2, 4, 8)
    ffn_snap_grid: Sequence[int] = (1, 32, 64, 128)


def llama_search_best_export(
    model_with_gates: nn.Module,
    *,
    export_fn: Callable[[nn.Module, CoreExportPolicy, int], nn.Module],
    num_q_heads: int,
    num_kv_heads: int,
    step: int,
    batch_shape: Tuple[int, int],  # (B, S) for text
    grid: Optional[LlamaGrid] = None,
    device: str = "cuda",
    measure_settings=None,
    make_policy: Optional[Callable[[int, int], object]] = None,
):
    """
    Convenience wrapper for LLaMA-style search.
    Uses the same `grid_search_latency` as ViT; we just feed head/FFN grids.
    """
    g = grid or LlamaGrid()
    head_grid = g.head_multiple_grid or [1]
    ffn_grid = list(g.ffn_snap_grid)

    return grid_search_latency(
        model_with_gates,
        export_fn,
        head_multiples=head_grid,
        ffn_snaps=ffn_grid,
        step=step,
        batch_shape=batch_shape,  # the adapter's runner interprets this as (B, S)
        measure_settings=measure_settings,
        device=device,
        make_policy=make_policy,
    )
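
# Usage sketch for the latency grid search (hedged: the exact settings/result
# types live in core.search_export and core.profiler).
def _example_llama_grid_search(gated_model):
    result = llama_search_best_export(
        gated_model,
        export_fn=LlamaAdapter.export_pruned,
        num_q_heads=32, num_kv_heads=8,  # e.g., a Llama-style GQA layout
        step=10_000,
        batch_shape=(1, 512),            # (B, S)
        grid=LlamaGrid(head_multiple_grid=(2, 4, 8), ffn_snap_grid=(32, 64, 128)),
        device="cuda",
    )
    return result  # best rounding/snap combo by measured latency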
huggingface/.ipynb_checkpoints/vit-checkpoint.py
ADDED
@@ -0,0 +1,383 @@
"""HuggingFace ViT adapter
|
| 2 |
+
|
| 3 |
+
Bridges the family-agnostic core (gates/export/proxy/train) to ViT-like models
|
| 4 |
+
from Hugging Face (`ViTModel`, `ViTForImageClassification`, DeiT, etc.).
|
| 5 |
+
|
| 6 |
+
Responsibilities
|
| 7 |
+
----------------
|
| 8 |
+
- Attach gates to attention heads and MLP hidden in groups
|
| 9 |
+
- Provide logits getters for student/teacher
|
| 10 |
+
- Export helpers: keep-all (remove gates), and pruned (slice weights + metadata)
|
| 11 |
+
|
| 12 |
+
This adapter intentionally keeps the core unaware of ViT internals.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
# Ensure repo root on sys.path for absolute imports (core, adapters, data)
|
| 17 |
+
import sys, pathlib
|
| 18 |
+
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
import copy
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
|
| 27 |
+
# NOTE: absolute imports so running `-m examples.run_vit_optimize` works without package install
|
| 28 |
+
from core.gates import HeadGate, GroupGate
|
| 29 |
+
from core.export import (
|
| 30 |
+
ExportPolicy as CoreExportPolicy,
|
| 31 |
+
Rounding as CoreRounding,
|
| 32 |
+
keep_group_indices_from_gate,
|
| 33 |
+
keep_element_indices_from_gate,
|
| 34 |
+
slice_linear,
|
| 35 |
+
Rounding as CoreRounding,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
from core.utils import deepcopy_eval_cpu
|
| 39 |
+
from core.search_export import grid_search_latency
|
| 40 |
+
|
| 41 |
+
# -----------------------------------------------------------------------------
|
| 42 |
+
# Config
|
| 43 |
+
# -----------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class ViTGatingConfig:
|
| 47 |
+
tau: float = 1.5
|
| 48 |
+
init_logit: float = 3.0
|
| 49 |
+
head_gating: bool = True
|
| 50 |
+
ffn_group: int = 16
|
| 51 |
+
ffn_gating: bool = True
|
| 52 |
+
hard_eval: bool = True # use hard masks in eval mode during forward
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+


def _encoder_layers(m: nn.Module):
    """
    Return the sequence of Transformer blocks for an HF ViT.
    Supports:
      - ViTModel: m.encoder.layer
      - ViTForImageClassification: m.vit.encoder.layer
    """
    # ViTModel path
    enc = getattr(m, "encoder", None)
    if enc is not None and hasattr(enc, "layer"):
        return enc.layer

    # ViTForImageClassification path
    vit = getattr(m, "vit", None)
    if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
        return vit.encoder.layer

    raise ValueError("Provided model does not look like a HF ViT (missing *.encoder.layer)")

# -----------------------------------------------------------------------------
# Gated attention wrapper
# -----------------------------------------------------------------------------

class GatedSelfAttentionHF(nn.Module):
    """A thin wrapper around HF ViT self-attention that multiplies in per-head gates.

    It keeps references to the underlying query/key/value `nn.Linear` layers and
    the output projection, while exposing a `HeadGate` as `head_gate`.
    """

    def __init__(self, attn_container: nn.Module, num_heads: int, head_dim: int, cfg: ViTGatingConfig):
        super().__init__()
        base_attn = attn_container.attention  # ViTSdpaSelfAttention or ViTSelfAttention
        out_proj = attn_container.output.dense

        self.base_attn = base_attn
        self.out_proj = out_proj

        self.q_proj = base_attn.query
        self.k_proj = base_attn.key
        self.v_proj = base_attn.value

        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.drop_p = getattr(base_attn, "dropout", nn.Dropout(0.0)).p

        self.head_gate = HeadGate(num_heads=self.num_heads, head_dim=self.head_dim,
                                  tau=cfg.tau, init_logit=cfg.init_logit,
                                  hard_during_eval=cfg.hard_eval)

    @property
    def logits(self) -> torch.Tensor:
        return self.head_gate.logits

    def kept_heads_soft(self) -> torch.Tensor:
        return self.head_gate.probs().sum()

    def forward(self, hidden_states, head_mask=None):
        B, N, _ = hidden_states.shape
        H, Dh = self.num_heads, self.head_dim

        wdev = self.q_proj.weight.device
        if hidden_states.device != wdev:
            hidden_states = hidden_states.to(wdev, non_blocking=True)

        q_lin = self.q_proj(hidden_states)
        k_lin = self.k_proj(hidden_states)
        v_lin = self.v_proj(hidden_states)

        q = q_lin.view(B, N, H, Dh).transpose(1, 2)
        k = k_lin.view(B, N, H, Dh).transpose(1, 2)
        v = v_lin.view(B, N, H, Dh).transpose(1, 2)

        # Gumbel-sigmoid straight-through gate: hard 0/1 in the forward pass,
        # sigmoid gradients in the backward pass.
        logits = self.head_gate.logits
        tau = float(self.head_gate.tau)
        if self.training:
            u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
            s = u.log() - (1 - u).log()
            y = torch.sigmoid((logits + s) / tau)
            g_head = ((y > 0.5).to(y.dtype) - y).detach() + y
        else:
            if getattr(self.head_gate, "hard_during_eval", True):
                g_head = (logits > 0).to(logits.dtype)
            else:
                g_head = torch.sigmoid(logits / tau)
        g = g_head.view(1, H, 1, 1)

        q = q * g; k = k * g; v = v * g

        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.drop_p if self.training else 0.0
        )  # [B, H, N, Dh]

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, H * Dh)
        attn_out = self.out_proj(attn_out)
        return attn_out, None

# -----------------------------------------------------------------------------
# Adapter
# -----------------------------------------------------------------------------

class ViTAdapter:
    def __init__(self, model: nn.Module):
        self.model = model
        _ = _encoder_layers(model)  # fail fast on non-ViT models

    # ---------- Gating attachment ----------
    def attach_gates(self, cfg: ViTGatingConfig) -> nn.Module:
        m = self.model
        H = int(getattr(m.config, "num_attention_heads", 12))
        D = int(getattr(m.config, "hidden_size", 768))
        Dh = D // H

        for layer in _encoder_layers(m):
            # Attention heads
            if cfg.head_gating:
                attn_container = layer.attention
                if not isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
                    gated = GatedSelfAttentionHF(attn_container, H, Dh, cfg)
                    attn_container.attention = gated

            # FFN hidden units (grouped)
            if cfg.ffn_gating:
                inter = layer.intermediate
                d_ff = int(inter.dense.out_features)
                assert d_ff % cfg.ffn_group == 0, f"FFN size {d_ff} not divisible by group {cfg.ffn_group}"
                if not hasattr(inter, "neuron_gate"):
                    inter.neuron_gate = GroupGate(num_groups=d_ff // cfg.ffn_group,
                                                  group_size=cfg.ffn_group,
                                                  tau=cfg.tau, init_logit=cfg.init_logit,
                                                  hard_during_eval=cfg.hard_eval)
                # Monkey-patch forward to apply the mask after the activation (keeps HF shapes)
                if not hasattr(inter, "_orig_forward"):
                    inter._orig_forward = inter.forward

                def _gated_forward(this, x):
                    h = this.dense(x)
                    h = this.intermediate_act_fn(h)
                    msk = this.neuron_gate.mask(this.training).view(1, 1, -1)
                    return h * msk

                inter.forward = _gated_forward.__get__(inter, inter.__class__)
        return m
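
    # Usage sketch (illustrative; `google/vit-base-patch16-224` matches the
    # base model of this repo) showing a gated forward pass on dummy pixels.
    @staticmethod
    def _example_attach_and_forward():
        from transformers import ViTForImageClassification
        model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
        gated = ViTAdapter(model).attach_gates(ViTGatingConfig(ffn_group=16))
        x = torch.randn(1, 3, 224, 224)  # dummy image batch
        with torch.no_grad():
            logits = gated(pixel_values=x).logits
        return logits.shape  # (1, num_labels)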

    # ---------- Logits helpers ----------
    @staticmethod
    def get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
        out = model(pixel_values=x)
        if hasattr(out, "logits"):
            return out.logits  # ViTForImageClassification path
        if hasattr(out, "last_hidden_state"):  # ViTModel path (needs an external head)
            if head is None:
                raise ValueError("Provide a classification head when using ViTModel without logits.")
            cls_tok = out.last_hidden_state[:, 0, :]
            if next(head.parameters(), torch.tensor([], device=cls_tok.device)).device != cls_tok.device:
                head = head.to(cls_tok.device)
            return head(cls_tok)
        raise ValueError("Model output lacks logits and last_hidden_state.")
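
    # Sketch of the external-head path above: a bare ViTModel returns hidden
    # states, so a separate linear head maps the CLS token to class logits.
    # The 1000-way head is hypothetical.
    @staticmethod
    def _example_external_head():
        from transformers import ViTModel
        backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
        head = nn.Linear(backbone.config.hidden_size, 1000)  # hypothetical head
        x = torch.randn(1, 3, 224, 224)
        with torch.no_grad():
            logits = ViTAdapter.get_logits(backbone, x, head=head)
        return logits.shape  # (1, 1000)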

    # ---------- Exporters ----------
    @staticmethod
    @torch.no_grad()
    def export_keepall(model_with_gates: nn.Module) -> nn.Module:
        slim = deepcopy_eval_cpu(model_with_gates)
        for layer in _encoder_layers(slim):
            # Attention: unwrap the gate
            attn_container = layer.attention
            if isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
                gat = attn_container.attention
                new_attn = copy.deepcopy(gat.base_attn)
                # restore HF metadata if present
                if hasattr(new_attn, "num_attention_heads"):
                    new_attn.num_attention_heads = int(gat.num_heads)
                if hasattr(new_attn, "attention_head_size"):
                    new_attn.attention_head_size = int(gat.head_dim)
                if hasattr(new_attn, "all_head_size"):
                    new_attn.all_head_size = int(gat.num_heads * gat.head_dim)
                attn_container.attention = new_attn
            # FFN: restore the original forward and drop the gate
            inter = layer.intermediate
            if hasattr(inter, "_orig_forward"):
                inter.forward = inter._orig_forward
                delattr(inter, "_orig_forward")
            if hasattr(inter, "neuron_gate"):
                delattr(inter, "neuron_gate")
        return slim
    @staticmethod
    @torch.no_grad()
    def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
        # Support both CoreExportPolicy (single rounding) and ViTExportPolicy (per-axis)
        if isinstance(policy, ViTExportPolicy):
            head_rounding = policy.head_rounding
            ffn_rounding = policy.ffn_rounding
            warmup_steps = policy.warmup_steps
        else:
            # fall back to a single rounding for both axes
            head_rounding = getattr(policy, "rounding", None)
            ffn_rounding = getattr(policy, "rounding", None)
            warmup_steps = int(getattr(policy, "warmup_steps", 0))

        slim = deepcopy_eval_cpu(model_with_gates)
        warm = (step < warmup_steps)

        for layer in _encoder_layers(slim):
            # --- Attention heads ---
            attn_container = layer.attention
            gat = getattr(attn_container, "attention", None)
            if isinstance(gat, GatedSelfAttentionHF):
                # decide head indices via the shared helper; warmup is honored via `step`
                grp_idx = keep_group_indices_from_gate(
                    gat.head_gate,
                    policy=policy,
                    step=step,
                    custom_rounding=head_rounding,
                )
                H_keep = int(grp_idx.numel())
                Dh = int(gat.head_dim)

                ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in grp_idx]).long()
                gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
                gat.k_proj = slice_linear(gat.k_proj, keep_out=ch_idx)
                gat.v_proj = slice_linear(gat.v_proj, keep_out=ch_idx)
                attn_container.output.dense = slice_linear(attn_container.output.dense, keep_in=ch_idx)

                new_attn = copy.deepcopy(gat.base_attn)
                new_attn.query = gat.q_proj
                new_attn.key = gat.k_proj
                new_attn.value = gat.v_proj
                if hasattr(new_attn, "num_attention_heads"):
                    new_attn.num_attention_heads = H_keep
                if hasattr(new_attn, "attention_head_size"):
                    new_attn.attention_head_size = Dh
                if hasattr(new_attn, "all_head_size"):
                    new_attn.all_head_size = H_keep * Dh
                attn_container.attention = new_attn

            # --- FFN groups ---
            inter, out = layer.intermediate, layer.output
            g = getattr(inter, "neuron_gate", None)
            if g is not None:
                grp_idx = keep_group_indices_from_gate(
                    g,
                    policy=policy,
                    step=step,
                    custom_rounding=ffn_rounding,
                )
                group = int(g.group_size)
                keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
                inter.dense = slice_linear(inter.dense, keep_out=keep_exp)
                out.dense = slice_linear(out.dense, keep_in=keep_exp)

            # Restore the clean class-level forward and drop gate bookkeeping
            inter.forward = inter.__class__.forward.__get__(inter, inter.__class__)
            if hasattr(inter, "neuron_gate"):
                delattr(inter, "neuron_gate")
            if hasattr(inter, "_orig_forward"):
                delattr(inter, "_orig_forward")

        return slim

# -----------------------------------------------------------------------------
# Export policy
# -----------------------------------------------------------------------------

@dataclass
class ViTExportPolicy:
    """ViT-specific export policy that allows different rounding for heads vs FFN."""
    warmup_steps: int = 0
    head_rounding: CoreRounding = CoreRounding()
    ffn_rounding: CoreRounding = CoreRounding()
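
# Usage sketch tying the pieces together (hedged; CoreRounding's fields come
# from core.export and may differ from the defaults shown here).
def _example_vit_export(gated_model):
    policy = ViTExportPolicy(
        warmup_steps=0,
        head_rounding=CoreRounding(),  # per-axis rounding for attention heads
        ffn_rounding=CoreRounding(),   # per-axis rounding for FFN groups
    )
    slim = ViTAdapter.export_pruned(gated_model, policy, step=10_000)
    return slim  # plain HF ViT with sliced q/k/v, output.dense and FFN linears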

@dataclass
class ViTGrid:
    head_multiple_grid: Optional[Sequence[int]] = (2, 4, 8)
    ffn_snap_grid: Sequence[int] = (1, 8)
    # head_multiple_grid: Optional[Sequence[int]] = None  # default -> 1..num_heads
    # ffn_snap_grid: Sequence[int] = (1, 2, 4, 8, 16)


def vit_search_best_export(
    model_with_gates: nn.Module,
    *,
    export_fn: ExportFn,
    num_heads: int,
    step: int,
    batch_shape: Tuple[int, int, int, int],
    grid: Optional[ViTGrid] = None,
    device: str = "cuda",
    measure_settings: Optional[ProfileSettings] = None,
    make_policy: Optional[Callable[[int, int], object]] = None,
) -> SearchResult:
    """Convenience wrapper for ViT-style search.

    If `make_policy` is not provided, the caller's adapter should accept a
    policy with separate head/FFN rounding; see `adapters.huggingface.vit.ViTExportPolicy`.
    """
    g = grid or ViTGrid()
    head_grid = g.head_multiple_grid or list(range(1, int(num_heads) + 1))
    ffn_grid = list(g.ffn_snap_grid)

    return grid_search_latency(
        model_with_gates,
        export_fn,
        head_multiples=head_grid,
        ffn_snaps=ffn_grid,
        step=step,
        batch_shape=batch_shape,
        measure_settings=measure_settings,
        device=device,
        make_policy=make_policy,
    )
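
# Usage sketch for the ViT latency search (hedged: the settings/result types
# are defined in core.search_export).
def _example_vit_grid_search(gated_model):
    result = vit_search_best_export(
        gated_model,
        export_fn=ViTAdapter.export_pruned,
        num_heads=12,
        step=10_000,
        batch_shape=(1, 3, 224, 224),  # (B, C, H, W)
        grid=ViTGrid(head_multiple_grid=(2, 4, 8), ffn_snap_grid=(1, 8)),
        device="cuda",
    )
    return result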
huggingface/__init__.py
ADDED
File without changes

huggingface/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes)

huggingface/__pycache__/vit.cpython-310.pyc
ADDED
Binary file (10.6 kB)
huggingface/llama.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HuggingFace LLaMA/Mistral adapter
|
| 2 |
+
|
| 3 |
+
Bridges the family-agnostic core (gates/export/proxy/train) to HF causal LMs
|
| 4 |
+
(LlamaForCausalLM / MistralForCausalLM, etc.).
|
| 5 |
+
|
| 6 |
+
Responsibilities
|
| 7 |
+
----------------
|
| 8 |
+
- Attach gates to attention Q heads (and optional KV) + grouped MLP (SwiGLU)
|
| 9 |
+
- Provide a logits getter (student/teacher)
|
| 10 |
+
- Exporters:
|
| 11 |
+
* keep-all (unwrap gates, restore clean HF modules)
|
| 12 |
+
* pruned (slice q_proj/o_proj and SwiGLU up/gate/down; update HF metadata)
|
| 13 |
+
- Grid-search wrapper for post-export rounding/snap params
|
| 14 |
+
|
| 15 |
+
This adapter intentionally keeps the core unaware of LLaMA internals.
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
# Ensure repo root on sys.path for absolute imports (core, adapters, data)
|
| 20 |
+
import sys, pathlib
|
| 21 |
+
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Optional, Sequence, Callable, Tuple
|
| 25 |
+
|
| 26 |
+
import copy
|
| 27 |
+
import math
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
|
| 32 |
+
# Core (absolute imports so running `-m examples.run_llama_optimize` works)
|
| 33 |
+
from core.gates import HeadGate, GroupGate
|
| 34 |
+
from core.export import (
|
| 35 |
+
ExportPolicy as CoreExportPolicy,
|
| 36 |
+
Rounding as CoreRounding,
|
| 37 |
+
keep_group_indices_from_gate,
|
| 38 |
+
slice_linear,
|
| 39 |
+
)
|
| 40 |
+
from core.utils import deepcopy_eval_cpu
|
| 41 |
+
from core.search_export import grid_search_latency
|
| 42 |
+
|
| 43 |
+
# -------------------------------------------------------------------------
|
| 44 |
+
# Configs
|
| 45 |
+
# -------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class LlamaGatingConfig:
|
| 49 |
+
tau: float = 1.5
|
| 50 |
+
init_logit: float = 3.0
|
| 51 |
+
head_gating: bool = True
|
| 52 |
+
gate_kv: bool = False # optional: gate KV along with Q
|
| 53 |
+
ffn_group: int = 128 # SwiGLU groups
|
| 54 |
+
ffn_gating: bool = True
|
| 55 |
+
hard_eval: bool = True # use hard gates in eval forward
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# -------------------------------------------------------------------------
|
| 59 |
+
# Helpers (GQA, rotary, cache-safe)
|
| 60 |
+
# -------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _last_nonpad_index(attn_mask: Optional[torch.Tensor], seq_len: int, device) -> torch.Tensor:
|
| 64 |
+
if attn_mask is None:
|
| 65 |
+
return torch.full((1,), seq_len - 1, device=device, dtype=torch.long) # will be expanded per-batch later
|
| 66 |
+
# attn_mask: [B, S] in {0,1}; works for left/right padding
|
| 67 |
+
return (attn_mask.sum(dim=1) - 1).clamp(min=0).long()
|
| 68 |
+
|
| 69 |
+
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 70 |
+
if n_rep == 1:
|
| 71 |
+
return x
|
| 72 |
+
B, Hkv, T, Dh = x.shape
|
| 73 |
+
return x.unsqueeze(2).expand(B, Hkv, n_rep, T, Dh).reshape(B, Hkv * n_rep, T, Dh)
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
from transformers.cache_utils import Cache
|
| 77 |
+
except Exception:
|
| 78 |
+
class Cache: # type: ignore
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# -------------------------------------------------------------------------
|
| 83 |
+
# Gated attention wrapper (Llama/Mistral ready)
|
| 84 |
+
# -------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
class GatedSelfAttentionLLM(nn.Module):
|
| 87 |
+
"""
|
| 88 |
+
Thin wrapper around HF Llama/Mistral attention module.
|
| 89 |
+
|
| 90 |
+
- Uses the base module's q_proj/k_proj/v_proj/o_proj
|
| 91 |
+
- Applies per-Q-head gates (and optional KV gates)
|
| 92 |
+
- Handles rotary and cache (tuple or HF Cache)
|
| 93 |
+
- Runs SDPA directly, then o_proj
|
| 94 |
+
"""
|
| 95 |
+
def __init__(self, attn_container: nn.Module,
|
| 96 |
+
num_q_heads: int, num_kv_heads: int, head_dim: int,
|
| 97 |
+
cfg: LlamaGatingConfig, layer_idx: int):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.base_attn = attn_container
|
| 100 |
+
self.q_proj = attn_container.q_proj
|
| 101 |
+
self.k_proj = attn_container.k_proj
|
| 102 |
+
self.v_proj = attn_container.v_proj
|
| 103 |
+
self.o_proj = getattr(attn_container, "o_proj", getattr(attn_container, "out_proj", None))
|
| 104 |
+
|
| 105 |
+
self.num_q_heads = int(num_q_heads)
|
| 106 |
+
self.num_kv_heads = int(num_kv_heads)
|
| 107 |
+
self.head_dim = int(head_dim)
|
| 108 |
+
self.gate_kv = bool(cfg.gate_kv)
|
| 109 |
+
self.drop_p = float(getattr(attn_container, "attention_dropout",
|
| 110 |
+
getattr(attn_container, "attn_dropout",
|
| 111 |
+
getattr(attn_container, "dropout", 0.0))))
|
| 112 |
+
self.head_gate = HeadGate(num_heads=self.num_q_heads,
|
| 113 |
+
head_dim=self.head_dim,
|
| 114 |
+
tau=cfg.tau, init_logit=cfg.init_logit,
|
| 115 |
+
hard_during_eval=cfg.hard_eval)
|
| 116 |
+
|
| 117 |
+
# rotary helpers if present on base
|
| 118 |
+
self.rotary_emb = getattr(attn_container, "rotary_emb", None)
|
| 119 |
+
self.apply_rotary_pos_emb = getattr(attn_container, "apply_rotary_pos_emb", None)
|
| 120 |
+
self.layer_idx = int(layer_idx)
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def logits(self) -> torch.Tensor:
|
| 124 |
+
return self.head_gate.logits
|
| 125 |
+
|
| 126 |
+
def kept_heads_soft(self) -> torch.Tensor:
|
| 127 |
+
p = self.head_gate.probs().detach().float().view(-1)
|
| 128 |
+
if p.numel() == self.num_q_heads * self.head_dim:
|
| 129 |
+
p = p.view(self.num_q_heads, self.head_dim).mean(dim=1)
|
| 130 |
+
return p.sum()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def forward(
|
| 134 |
+
self,
|
| 135 |
+
hidden_states: torch.Tensor, # [B,T,D]
|
| 136 |
+
attention_mask: Optional[torch.Tensor] = None, # additive mask [B,1,Tq,Tk] or None
|
| 137 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 138 |
+
past_key_value = None, # tuple, list, Cache or None
|
| 139 |
+
output_attentions: bool = False,
|
| 140 |
+
use_cache: bool = False,
|
| 141 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 142 |
+
position_embeddings: Optional[torch.Tensor] = None,
|
| 143 |
+
**kwargs,
|
| 144 |
+
):
|
| 145 |
+
B, T, D = hidden_states.shape
|
| 146 |
+
Hq, Hkv, Dh = self.num_q_heads, self.num_kv_heads, self.head_dim
|
| 147 |
+
assert Hq * Dh == D, "hidden_size must equal num_heads * head_dim"
|
| 148 |
+
n_rep = max(1, Hq // Hkv)
|
| 149 |
+
|
| 150 |
+
# qkv projections
|
| 151 |
+
q = self.q_proj(hidden_states).view(B, T, Hq, Dh).transpose(1, 2) # [B,Hq,T,Dh]
|
| 152 |
+
k = self.k_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
|
| 153 |
+
v = self.v_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
|
| 154 |
+
|
| 155 |
+
# rotary
|
| 156 |
+
if (self.rotary_emb is not None) and (self.apply_rotary_pos_emb is not None):
|
| 157 |
+
Tpast = 0
|
| 158 |
+
if isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 2:
|
| 159 |
+
Tpast = int(past_key_value[0].size(2))
|
| 160 |
+
elif isinstance(past_key_value, Cache):
|
| 161 |
+
Tpast = int(cache_position.max().item() if cache_position is not None else 0)
|
| 162 |
+
seq_len = Tpast + T
|
| 163 |
+
try:
|
| 164 |
+
cos, sin = self.rotary_emb(v, seq_len=seq_len)
|
| 165 |
+
except TypeError:
|
| 166 |
+
cos, sin = self.rotary_emb(q, seq_len=seq_len)
|
| 167 |
+
# try rich signature first
|
| 168 |
+
try:
|
| 169 |
+
q, k = self.apply_rotary_pos_emb(
|
| 170 |
+
q, k, cos, sin,
|
| 171 |
+
position_ids=position_ids,
|
| 172 |
+
cache_position=cache_position,
|
| 173 |
+
position_embeddings=position_embeddings
|
| 174 |
+
)
|
| 175 |
+
except TypeError:
|
| 176 |
+
try:
|
| 177 |
+
q, k = self.apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids)
|
| 178 |
+
except TypeError:
|
| 179 |
+
q, k = self.apply_rotary_pos_emb(q, k, cos, sin)
|
| 180 |
+
|
| 181 |
+
# cache merge
|
| 182 |
+
present = None
|
| 183 |
+
if past_key_value is None or (isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 0):
|
| 184 |
+
pass
|
| 185 |
+
elif isinstance(past_key_value, (tuple, list)):
|
| 186 |
+
pk, pv = past_key_value # [B,Hkv,Tpast,Dh]
|
| 187 |
+
k = torch.cat([pk, k], dim=2)
|
| 188 |
+
v = torch.cat([pv, v], dim=2)
|
| 189 |
+
present = (k, v) if use_cache else None
|
| 190 |
+
elif isinstance(past_key_value, Cache):
|
| 191 |
+
k, v = past_key_value.update(k, v, self.layer_idx, cache_position)
|
| 192 |
+
present = past_key_value
|
| 193 |
+
|
| 194 |
+
# gates
|
| 195 |
+
# g = self.head_gate.mask(self.training).view(1, Hq, 1, 1)
|
| 196 |
+
# ---- gates (supports per-head OR per-channel HeadGate) ----
|
| 197 |
+
m = self.head_gate.mask(self.training) # 1D tensor
|
| 198 |
+
m = m.detach() if not self.training else m
|
| 199 |
+
if m.numel() == Hq:
|
| 200 |
+
# per-head gating
|
| 201 |
+
gH = m.view(1, Hq, 1, 1) # [1,Hq,1,1]
|
| 202 |
+
q = q * gH
|
| 203 |
+
if self.gate_kv:
|
| 204 |
+
if n_rep == 1:
|
| 205 |
+
k = k * gH; v = v * gH
|
| 206 |
+
else:
|
| 207 |
+
g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
|
| 208 |
+
k = k * g_kv; v = v * g_kv
|
| 209 |
+
elif m.numel() == Hq * Dh:
|
| 210 |
+
# per-channel gating
|
| 211 |
+
gHD = m.view(1, Hq, 1, Dh) # [1,Hq,1,Dh]
|
| 212 |
+
q = q * gHD
|
| 213 |
+
if self.gate_kv:
|
| 214 |
+
# collapse to per-head for KV, then map to Hkv via amax over replicas
|
| 215 |
+
gH = gHD.amax(dim=-1, keepdim=True) # [1,Hq,1,1]
|
| 216 |
+
if n_rep == 1:
|
| 217 |
+
k = k * gH; v = v * gH
|
| 218 |
+
else:
|
| 219 |
+
g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
|
| 220 |
+
k = k * g_kv; v = v * g_kv
|
| 221 |
+
else:
|
| 222 |
+
raise RuntimeError(
|
| 223 |
+
f"HeadGate mask has {m.numel()} elems; expected {Hq} or {Hq*Dh}"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# GQA replicate KV to Q count
|
| 228 |
+
k = _repeat_kv(k, n_rep)
|
| 229 |
+
v = _repeat_kv(v, n_rep)
|
| 230 |
+
|
| 231 |
+
attn = F.scaled_dot_product_attention(
|
| 232 |
+
q, k, v,
|
| 233 |
+
attn_mask=attention_mask,
|
| 234 |
+
dropout_p=self.drop_p if self.training else 0.0,
|
| 235 |
+
is_causal=True
|
| 236 |
+
)
|
| 237 |
+
out = attn.transpose(1, 2).contiguous().view(B, T, Hq * Dh)
|
| 238 |
+
out = self.o_proj(out)
|
| 239 |
+
|
| 240 |
+
attn_weights = None
|
| 241 |
+
# HF expects (attn_output, attn_weights, present_key_value) always
|
| 242 |
+
if output_attentions:
|
| 243 |
+
return (out, attn_weights, present)
|
| 244 |
+
else:
|
| 245 |
+
return (out, None, present)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# -------------------------------------------------------------------------
|
| 250 |
+
# Adapter
|
| 251 |
+
# -------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
class LlamaAdapter:
|
| 254 |
+
def __init__(self, model: nn.Module):
|
| 255 |
+
self.model = model
|
| 256 |
+
core = getattr(model, "model", model)
|
| 257 |
+
if not hasattr(core, "layers"):
|
| 258 |
+
raise ValueError("Provided model does not look like HF LLaMA/Mistral (missing .model.layers or .layers)")
|
| 259 |
+
|
| 260 |
+
    # ---------- Gating attachment ----------
    def attach_gates(self, cfg: LlamaGatingConfig) -> nn.Module:
        m = self.model
        core = getattr(m, "model", m)
        layers = core.layers

        Hq = int(core.config.num_attention_heads)
        Hkv = int(getattr(core.config, "num_key_value_heads", Hq))
        Dh = int(core.config.hidden_size // Hq)

        for li, layer in enumerate(layers):
            # Attention heads
            if cfg.head_gating:
                base = layer.self_attn
                if not isinstance(base, GatedSelfAttentionLLM):
                    gated = GatedSelfAttentionLLM(
                        attn_container=base,
                        num_q_heads=Hq,
                        num_kv_heads=Hkv,
                        head_dim=Dh,
                        cfg=cfg,
                        layer_idx=li,
                    )
                    layer.self_attn = gated  # route via our wrapper

            # MLP grouped gating (SwiGLU)
            if cfg.ffn_gating:
                mlp = layer.mlp
                I = int(mlp.up_proj.out_features)
                assert I % cfg.ffn_group == 0, f"SwiGLU size {I} not divisible by group {cfg.ffn_group}"
                if not hasattr(mlp, "neuron_gate"):
                    mlp.neuron_gate = GroupGate(
                        num_groups=I // cfg.ffn_group,
                        group_size=cfg.ffn_group,
                        tau=cfg.tau, init_logit=cfg.init_logit,
                        hard_during_eval=cfg.hard_eval,
                    )
                if not hasattr(mlp, "_orig_forward"):
                    mlp._orig_forward = mlp.forward

                def _gated_mlp_forward(this, x):
                    # LLaMA SwiGLU (matches HF LlamaMLP): z = silu(gate(x)) * (up(x) * m); out = down(z)
                    u = this.up_proj(x)
                    g = this.gate_proj(x)
                    m = this.neuron_gate.mask(this.training).view(1, 1, -1)
                    z = torch.nn.functional.silu(g) * (u * m)
                    return this.down_proj(z)

                mlp.forward = _gated_mlp_forward.__get__(mlp, mlp.__class__)
        return m

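A hedged usage sketch for the attachment path (the checkpoint id is illustrative, and `LlamaGatingConfig` defaults are assumed to be sensible as defined earlier in this file):

```python
# Illustrative only: any HF LLaMA/Mistral-family checkpoint with .model.layers works.
from transformers import AutoModelForCausalLM

mdl = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
gated = LlamaAdapter(mdl).attach_gates(LlamaGatingConfig())
gated.train()  # gates sample stochastic straight-through masks in train mode
```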
    # ---------- Logits helper ----------
    @staticmethod
    def _last_token_index(attn_mask: torch.Tensor) -> torch.Tensor:
        # attn_mask: [B, S] with 1 for tokens, 0 for padding
        # returns [B] indices of the last non-pad token
        # works for both bool and int masks
        if attn_mask is None:
            # no mask → caller should use the last position S-1
            return None
        if attn_mask.dtype != torch.long:
            attn_mask = attn_mask.to(torch.long)
        # idx = lengths - 1
        return (attn_mask.sum(dim=-1) - 1).clamp_min(0)

    @staticmethod
    def get_logits(model: nn.Module,
                   input_ids: torch.Tensor,
                   attention_mask: Optional[torch.Tensor] = None,
                   last_only: bool = True,
                   **forward_kwargs) -> torch.Tensor:
        """
        Returns logits. If last_only=True, computes ONLY the last-token logits by:
          1) getting hidden states from the base decoder,
          2) selecting the last non-pad position per sample,
          3) projecting through lm_head at that single position.
        This avoids allocating a full [B, S, V] tensor.
        """
        # (1) run the base decoder, not the full CausalLM head
        core = getattr(model, "model", None)
        if core is None:
            # fallback if the model is already a bare decoder (rare)
            core = model

        # We only need last_hidden_state; no cache; avoid building logits for all S.
        # return_dict=False returns a tuple and avoids extra allocations.
        outputs = core(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=False,
            **forward_kwargs
        )
        hidden = outputs[0]  # [B, S, D]

        if not last_only:
            # If someone explicitly wants all logits, fine:
            return model.lm_head(hidden)  # [B, S, V] (expensive!)

        # (2) select the last token per sample
        B, S, D = hidden.shape
        if attention_mask is None:
            # simple "last index"
            idx = torch.full((B,), S - 1, device=hidden.device, dtype=torch.long)
        else:
            idx = LlamaAdapter._last_token_index(attention_mask)

        # gather the last hidden state: [B, D]
        last_h = hidden[torch.arange(B, device=hidden.device), idx]  # [B, D]
        # (3) project to logits for that single position
        last_logits = model.lm_head(last_h).unsqueeze(1)  # [B, 1, V]
        return last_logits

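A quick check of the last-only path (checkpoint id illustrative; any LLaMA-family model exposes `.model` and `.lm_head`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative small LLaMA-family checkpoint
tok = AutoTokenizer.from_pretrained(name)
tok.pad_token = tok.pad_token or tok.eos_token  # LLaMA tokenizers often lack a pad token
mdl = AutoModelForCausalLM.from_pretrained(name).eval()

batch = tok(["hello world", "hi"], return_tensors="pt", padding=True)
with torch.no_grad():
    logits = LlamaAdapter.get_logits(mdl, batch["input_ids"], batch["attention_mask"])
print(logits.shape)  # [2, 1, vocab_size]: one position per sample, no [B,S,V] allocation
```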
    # ---------- Exporters ----------
    @staticmethod
    @torch.no_grad()
    def export_keepall(model_with_gates: nn.Module) -> nn.Module:
        """
        Unwrap attention wrappers; restore the original MLP.forward; drop gates.
        """
        slim = deepcopy_eval_cpu(model_with_gates)
        core = getattr(slim, "model", slim)
        if not hasattr(core, "layers"):
            return slim

        for layer in core.layers:
            # attention
            attn = layer.self_attn
            if isinstance(attn, GatedSelfAttentionLLM):
                gat = attn
                new_attn = copy.deepcopy(gat.base_attn)
                # keep metadata consistent
                if hasattr(new_attn, "num_heads"):
                    new_attn.num_heads = int(gat.num_q_heads)
                if hasattr(new_attn, "num_key_value_heads"):
                    new_attn.num_key_value_heads = int(gat.num_kv_heads)
                if hasattr(new_attn, "head_dim"):
                    new_attn.head_dim = int(gat.head_dim)
                layer.self_attn = new_attn

            # mlp
            mlp = layer.mlp
            if hasattr(mlp, "_orig_forward"):
                mlp.forward = mlp._orig_forward
                delattr(mlp, "_orig_forward")
            if hasattr(mlp, "neuron_gate"):
                delattr(mlp, "neuron_gate")

        return slim

    @staticmethod
    @torch.no_grad()
    def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
        """
        Produce a clean CPU eval model:
          - Read gates to choose Q heads; slice q_proj rows and o_proj cols
          - Snap kept heads to an LCM of (policy multiple, Hkv)
          - Slice SwiGLU up/gate/down by groups
          - Unwrap back to plain HF modules; update metadata
        """
        # Accept either CoreExportPolicy with per-axis rounding or a family policy
        if isinstance(policy, LlamaExportPolicy):
            head_rounding = policy.head_rounding
            ffn_rounding = policy.ffn_rounding
            warmup_steps = policy.warmup_steps
        else:
            head_rounding = getattr(policy, "rounding", None)
            ffn_rounding = getattr(policy, "rounding", None)
            warmup_steps = int(getattr(policy, "warmup_steps", 0))

        slim = deepcopy_eval_cpu(model_with_gates)
        core = getattr(slim, "model", slim)
        layers = getattr(core, "layers", None)
        if layers is None:
            return slim

        warm = (step < warmup_steps)

        def _lcm(a: int, b: int) -> int:
            import math
            return abs(a * b) // math.gcd(max(a, 1), max(b, 1)) if a > 0 and b > 0 else max(a, b, 1)

        for li, layer in enumerate(layers):
            # ---- Attention (Q heads) ----
            attn = layer.self_attn
            if isinstance(attn, GatedSelfAttentionLLM):
                gat = attn
                base = gat.base_attn

                Hq = int(gat.num_q_heads)
                Hkv = int(gat.num_kv_heads)
                Dh = int(gat.head_dim)

                if warm:
                    keep_idx = torch.arange(Hq)
                else:
                    # Build a "per-head" proxy gate if the base gate is per-channel.
                    base_logits = gat.head_gate.logits.detach().float().view(-1)
                    tau = float(getattr(gat.head_gate, "tau", 1.0))

                    if base_logits.numel() == Hq:
                        # Native per-head gate: use as-is
                        proxy_gate = gat.head_gate
                        keep_idx = keep_group_indices_from_gate(
                            proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
                        )
                    elif base_logits.numel() == Hq * Dh:
                        # Collapse per-channel → per-head (mean; use .amax for a stricter criterion)
                        per_head_logits = base_logits.view(Hq, Dh).mean(dim=1)

                        class _PerHeadProxyGate:
                            def __init__(self, logits, tau):
                                self.logits = logits
                                self.tau = tau
                                self.num_groups = logits.numel()
                                self.group_size = 1

                        proxy_gate = _PerHeadProxyGate(per_head_logits, tau)
                        keep_idx = keep_group_indices_from_gate(
                            proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
                        )
                    else:
                        raise RuntimeError(
                            f"Unexpected HeadGate logits len {base_logits.numel()} vs H={Hq} or H*Dh={Hq*Dh}"
                        )

                    # Enforce an LCM with GQA (Hkv) by truncating to a floor multiple
                    pol_mult = getattr(head_rounding, "multiple_groups", 1)
                    snap = _lcm(int(pol_mult), max(1, Hkv))
                    if keep_idx.numel() % snap != 0:
                        k = (keep_idx.numel() // snap) * snap
                        k = max(snap, min(Hq, k))
                        # recompute top-k by per-head logits (same criterion as above)
                        if base_logits.numel() == Hq * Dh:
                            scores = per_head_logits
                        else:
                            scores = base_logits
                        keep_idx = torch.topk(scores, k=k, largest=True).indices.sort().values

                H_keep = int(keep_idx.numel())
                # channel indices for q/o slicing
                ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in keep_idx]).long()

                # slice the wrapper's linears
                gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
                gat.o_proj = slice_linear(gat.o_proj, keep_in=ch_idx)

                # transplant into a clean HF attention module
                new_attn = copy.deepcopy(base)
                if hasattr(new_attn, "q_proj"):
                    new_attn.q_proj = gat.q_proj
                if hasattr(new_attn, "o_proj"):
                    new_attn.o_proj = gat.o_proj
                elif hasattr(new_attn, "out_proj"):
                    new_attn.out_proj = gat.o_proj

                # update metadata
                if hasattr(new_attn, "num_heads"):
                    new_attn.num_heads = int(H_keep)
                if hasattr(new_attn, "num_key_value_heads"):
                    new_attn.num_key_value_heads = int(Hkv)
                if hasattr(new_attn, "head_dim"):
                    new_attn.head_dim = int(Dh)
                if hasattr(core.config, "hidden_size"):
                    core.config.hidden_size = int(H_keep * Dh)

                layer.self_attn = new_attn  # unwrap

            # ---- MLP (SwiGLU grouped) ----
            mlp = layer.mlp
            g = getattr(mlp, "neuron_gate", None)
            if g is not None:
                grp_idx = keep_group_indices_from_gate(
                    g, policy=policy, step=step, custom_rounding=ffn_rounding,
                )
                group = int(g.group_size)  # GroupGate exposes group_size
                keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()

                mlp.up_proj = slice_linear(mlp.up_proj, keep_out=keep_exp)
                mlp.gate_proj = slice_linear(mlp.gate_proj, keep_out=keep_exp)
                mlp.down_proj = slice_linear(mlp.down_proj, keep_in=keep_exp)

            # Restore the clean forward & drop the gate
            if hasattr(mlp, "_orig_forward"):
                mlp.forward = mlp._orig_forward
                delattr(mlp, "_orig_forward")
            if hasattr(mlp, "neuron_gate"):
                delattr(mlp, "neuron_gate")

        return slim


# -------------------------------------------------------------------------
# Export policy (allows different rounding for heads vs FFN)
# -------------------------------------------------------------------------

@dataclass
class LlamaExportPolicy:
    warmup_steps: int = 0
    head_rounding: CoreRounding = CoreRounding()  # e.g., CoreRounding(floor=8, multiple=8)
    ffn_rounding: CoreRounding = CoreRounding()   # e.g., CoreRounding(min_keep_ratio=0.8, multiple=32)
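A hedged sketch of a pruning export call; the `CoreRounding` keyword names shown in the comments above (`floor`, `multiple`, `min_keep_ratio`) are taken on faith from core/export.py:

```python
# Illustrative: `gated_model` comes from LlamaAdapter.attach_gates(...) after training.
policy = LlamaExportPolicy(
    warmup_steps=100,              # before this step, the export keeps all heads
    head_rounding=CoreRounding(),  # GQA LCM snapping is enforced inside export_pruned
    ffn_rounding=CoreRounding(),
)
slim = LlamaAdapter.export_pruned(gated_model, policy, step=1_000)
slim.save_pretrained("out/llama-slim")  # plain HF modules after unwrapping
```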


# -------------------------------------------------------------------------
# Grid-search convenience
# -------------------------------------------------------------------------

@dataclass
class LlamaGrid:
    head_multiple_grid: Optional[Sequence[int]] = (1, 2, 4, 8)
    ffn_snap_grid: Sequence[int] = (1, 32, 64, 128)

def llama_search_best_export(
    model_with_gates: nn.Module,
    *,
    export_fn: Callable[[nn.Module, CoreExportPolicy, int], nn.Module],
    num_q_heads: int,
    num_kv_heads: int,
    step: int,
    batch_shape: Tuple[int, int],  # (B, S) for text
    grid: Optional[LlamaGrid] = None,
    device: str = "cuda",
    measure_settings=None,
    make_policy: Optional[Callable[[int, int], object]] = None,
):
    """
    Convenience wrapper for LLaMA-style search.
    Uses the same `grid_search_latency` as ViT; we just feed head/FFN grids.
    """
    g = grid or LlamaGrid()
    head_grid = g.head_multiple_grid or [1]
    ffn_grid = list(g.ffn_snap_grid)

    return grid_search_latency(
        model_with_gates,
        export_fn,
        head_multiples=head_grid,
        ffn_snaps=ffn_grid,
        step=step,
        batch_shape=batch_shape,  # the adapter's runner should interpret this as (B, S)
        measure_settings=measure_settings,
        device=device,
        make_policy=make_policy,
    )
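And a hedged sketch of driving the latency grid search (all values illustrative; the result type comes from core/search_export.py):

```python
result = llama_search_best_export(
    gated_model,                       # from LlamaAdapter.attach_gates(...)
    export_fn=LlamaAdapter.export_pruned,
    num_q_heads=32,
    num_kv_heads=8,
    step=1_000,
    batch_shape=(1, 512),              # (B, S) latency probe shape
    grid=LlamaGrid(head_multiple_grid=(4, 8), ffn_snap_grid=(64, 128)),
)
```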
huggingface/registry.py ADDED

File without changes

huggingface/vit.py ADDED

@@ -0,0 +1,383 @@
"""HuggingFace ViT adapter

Bridges the family-agnostic core (gates/export/proxy/train) to ViT-like models
from Hugging Face (`ViTModel`, `ViTForImageClassification`, DeiT, etc.).

Responsibilities
----------------
- Attach gates to attention heads and to grouped MLP hidden units
- Provide logits getters for student/teacher
- Export helpers: keep-all (remove gates) and pruned (slice weights + metadata)

This adapter intentionally keeps the core unaware of ViT internals.
"""
from __future__ import annotations

# Ensure the repo root is on sys.path for absolute imports (core, adapters, data)
import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))

from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple

import copy
import torch
import torch.nn as nn

# NOTE: absolute imports so running `-m examples.run_vit_optimize` works without a package install
from core.gates import HeadGate, GroupGate
from core.export import (
    ExportPolicy as CoreExportPolicy,
    Rounding as CoreRounding,
    keep_group_indices_from_gate,
    keep_element_indices_from_gate,
    slice_linear,
)

from core.utils import deepcopy_eval_cpu
from core.search_export import grid_search_latency

# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------

@dataclass
class ViTGatingConfig:
    tau: float = 1.5
    init_logit: float = 3.0
    head_gating: bool = True
    ffn_group: int = 16
    ffn_gating: bool = True
    hard_eval: bool = True  # use hard masks in eval mode during forward


def _encoder_layers(m: nn.Module):
    """
    Return the sequence of Transformer blocks for HF ViT.
    Supports:
      - ViTModel: m.encoder.layer
      - ViTForImageClassification: m.vit.encoder.layer
    """
    # ViTModel path
    enc = getattr(m, "encoder", None)
    if enc is not None and hasattr(enc, "layer"):
        return enc.layer

    # ViTForImageClassification path
    vit = getattr(m, "vit", None)
    if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
        return vit.encoder.layer

    raise ValueError("Provided model does not look like a HF ViT (missing *.encoder.layer)")


# -----------------------------------------------------------------------------
# Gated attention wrapper
# -----------------------------------------------------------------------------

class GatedSelfAttentionHF(nn.Module):
    """A thin wrapper around HF ViT self-attention that multiplies in per-head gates.

    It keeps references to the underlying query/key/value `nn.Linear` layers and
    the output projection, while exposing a `HeadGate` as `head_gate`.
    """

    def __init__(self, attn_container: nn.Module, num_heads: int, head_dim: int, cfg: ViTGatingConfig):
        super().__init__()
        base_attn = attn_container.attention  # ViTSdpaSelfAttention or ViTSelfAttention
        out_proj = attn_container.output.dense

        self.base_attn = base_attn
        self.out_proj = out_proj

        self.q_proj = base_attn.query
        self.k_proj = base_attn.key
        self.v_proj = base_attn.value

        self.num_heads = int(num_heads)
        self.head_dim = int(head_dim)
        self.drop_p = getattr(base_attn, "dropout", nn.Dropout(0.0)).p

        self.head_gate = HeadGate(
            num_heads=self.num_heads, head_dim=self.head_dim,
            tau=cfg.tau, init_logit=cfg.init_logit, hard_during_eval=cfg.hard_eval,
        )

    @property
    def logits(self) -> torch.Tensor:
        return self.head_gate.logits

    def kept_heads_soft(self) -> torch.Tensor:
        return self.head_gate.probs().sum()

    def forward(self, hidden_states, head_mask=None):
        B, N, _ = hidden_states.shape
        H, Dh = self.num_heads, self.head_dim

        wdev = self.q_proj.weight.device
        if hidden_states.device != wdev:
            hidden_states = hidden_states.to(wdev, non_blocking=True)

        q_lin = self.q_proj(hidden_states)
        k_lin = self.k_proj(hidden_states)
        v_lin = self.v_proj(hidden_states)

        q = q_lin.view(B, N, H, Dh).transpose(1, 2)
        k = k_lin.view(B, N, H, Dh).transpose(1, 2)
        v = v_lin.view(B, N, H, Dh).transpose(1, 2)

        logits = self.head_gate.logits
        tau = float(self.head_gate.tau)
        if self.training:
            # straight-through relaxed-Bernoulli sample: hard forward, soft backward
            u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
            s = u.log() - (1 - u).log()
            y = torch.sigmoid((logits + s) / tau)
            g_head = ((y > 0.5).to(y.dtype) - y).detach() + y
        else:
            if getattr(self.head_gate, "hard_during_eval", True):
                g_head = (logits > 0).to(logits.dtype)
            else:
                g_head = torch.sigmoid(logits / tau)
        g = g_head.view(1, H, 1, 1)

        q = q * g; k = k * g; v = v * g

        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.drop_p if self.training else 0.0
        )  # [B, H, N, Dh]

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, H * Dh)
        attn_out = self.out_proj(attn_out)
        return attn_out, None
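The training branch of `forward` is a straight-through relaxed-Bernoulli gate: logistic noise is added to the logits, the sigmoid is hard-thresholded in the forward pass, and gradients flow through the soft relaxation. A standalone sketch of the same estimator:

```python
import torch

logits = torch.zeros(12, requires_grad=True)   # 12 heads, p(keep) = 0.5 each
tau = 1.5

u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
s = u.log() - (1 - u).log()                    # Logistic(0, 1) noise
y = torch.sigmoid((logits + s) / tau)          # relaxed (soft) gate
g = ((y > 0.5).to(y.dtype) - y).detach() + y   # hard forward, soft backward

g.sum().backward()
print(g)             # exact 0/1 values in the forward pass
print(logits.grad)   # nonzero: gradients flow via the soft relaxation
```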


# -----------------------------------------------------------------------------
# Adapter
# -----------------------------------------------------------------------------

class ViTAdapter:
    def __init__(self, model: nn.Module):
        self.model = model
        _ = _encoder_layers(model)  # validate the model shape early

    # ---------- Gating attachment ----------
    def attach_gates(self, cfg: ViTGatingConfig) -> nn.Module:
        m = self.model
        H = int(getattr(m.config, "num_attention_heads", 12))
        D = int(getattr(m.config, "hidden_size", 768))
        Dh = D // H

        for layer in _encoder_layers(m):
            # Attention heads
            if cfg.head_gating:
                attn_container = layer.attention
                if not isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
                    gated = GatedSelfAttentionHF(attn_container, H, Dh, cfg)
                    attn_container.attention = gated

            # FFN hidden units (grouped)
            if cfg.ffn_gating:
                inter = layer.intermediate
                d_ff = int(inter.dense.out_features)
                assert d_ff % cfg.ffn_group == 0, f"FFN size {d_ff} not divisible by group {cfg.ffn_group}"
                if not hasattr(inter, "neuron_gate"):
                    inter.neuron_gate = GroupGate(
                        num_groups=d_ff // cfg.ffn_group, group_size=cfg.ffn_group,
                        tau=cfg.tau, init_logit=cfg.init_logit, hard_during_eval=cfg.hard_eval,
                    )
                # Monkey-patch forward to apply the mask after activation (keeps HF shapes)
                if not hasattr(inter, "_orig_forward"):
                    inter._orig_forward = inter.forward

                def _gated_forward(this, x):
                    h = this.dense(x)
                    h = this.intermediate_act_fn(h)
                    msk = this.neuron_gate.mask(this.training).view(1, 1, -1)
                    return h * msk

                inter.forward = _gated_forward.__get__(inter, inter.__class__)
        return m
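A hedged end-to-end attachment sketch for ViT-Base (the checkpoint id matches this repo's `base_id`):

```python
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
gated = ViTAdapter(model).attach_gates(ViTGatingConfig(ffn_group=16))
# per block: 12 head gates + 3072 / 16 = 192 FFN group gates
```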

    # ---------- Logits helpers ----------
    @staticmethod
    def get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
        out = model(pixel_values=x)
        if hasattr(out, "logits"):
            return out.logits  # ViTForImageClassification path
        if hasattr(out, "last_hidden_state"):  # ViTModel path (needs an external head)
            if head is None:
                raise ValueError("Provide a classification head when using ViTModel without logits.")
            cls_tok = out.last_hidden_state[:, 0, :]
            if next(head.parameters(), torch.tensor([], device=cls_tok.device)).device != cls_tok.device:
                head = head.to(cls_tok.device)
            return head(cls_tok)
        raise ValueError("Model output lacks logits and last_hidden_state.")
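A quick smoke test of the classification path (checkpoint id from this repo's metadata):

```python
import torch
from transformers import ViTForImageClassification

m = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = ViTAdapter.get_logits(m, x)
print(logits.shape)  # torch.Size([2, 1000])
```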

    # ---------- Exporters ----------
    @staticmethod
    @torch.no_grad()
    def export_keepall(model_with_gates: nn.Module) -> nn.Module:
        slim = deepcopy_eval_cpu(model_with_gates)
        for layer in _encoder_layers(slim):
            # Attention: unwrap the gate
            attn_container = layer.attention
            if isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
                gat = attn_container.attention
                new_attn = copy.deepcopy(gat.base_attn)
                # restore HF metadata if present
                if hasattr(new_attn, "num_attention_heads"):
                    new_attn.num_attention_heads = int(gat.num_heads)
                if hasattr(new_attn, "attention_head_size"):
                    new_attn.attention_head_size = int(gat.head_dim)
                if hasattr(new_attn, "all_head_size"):
                    new_attn.all_head_size = int(gat.num_heads * gat.head_dim)
                attn_container.attention = new_attn
            # FFN: restore the original forward and drop the gate
            inter = layer.intermediate
            if hasattr(inter, "_orig_forward"):
                inter.forward = inter._orig_forward
                delattr(inter, "_orig_forward")
            if hasattr(inter, "neuron_gate"):
                delattr(inter, "neuron_gate")
        return slim

    @staticmethod
    @torch.no_grad()
    def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
        # Support both CoreExportPolicy (single rounding) and ViTExportPolicy (per-axis)
        if isinstance(policy, ViTExportPolicy):
            head_rounding = policy.head_rounding
            ffn_rounding = policy.ffn_rounding
            warmup_steps = policy.warmup_steps
        else:
            # fall back to a single rounding for both axes
            head_rounding = getattr(policy, "rounding", None)
            ffn_rounding = getattr(policy, "rounding", None)
            warmup_steps = int(getattr(policy, "warmup_steps", 0))

        slim = deepcopy_eval_cpu(model_with_gates)
        warm = (step < warmup_steps)  # warmup is also honored via `policy`/`step` in the helpers below

        for layer in _encoder_layers(slim):
            # --- Attention heads ---
            attn_container = layer.attention
            gat = getattr(attn_container, "attention", None)
            if isinstance(gat, GatedSelfAttentionHF):
                # decide head indices via our helper; it honors warmup through `step`
                grp_idx = keep_group_indices_from_gate(
                    gat.head_gate,
                    policy=policy,
                    step=step,
                    custom_rounding=head_rounding,
                )
                H_keep = int(grp_idx.numel())
                Dh = int(gat.head_dim)

                ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in grp_idx]).long()
                gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
                gat.k_proj = slice_linear(gat.k_proj, keep_out=ch_idx)
                gat.v_proj = slice_linear(gat.v_proj, keep_out=ch_idx)
                attn_container.output.dense = slice_linear(attn_container.output.dense, keep_in=ch_idx)

                new_attn = copy.deepcopy(gat.base_attn)
                new_attn.query = gat.q_proj
                new_attn.key = gat.k_proj
                new_attn.value = gat.v_proj
                if hasattr(new_attn, "num_attention_heads"):
                    new_attn.num_attention_heads = H_keep
                if hasattr(new_attn, "attention_head_size"):
                    new_attn.attention_head_size = Dh
                if hasattr(new_attn, "all_head_size"):
                    new_attn.all_head_size = H_keep * Dh
                attn_container.attention = new_attn

            # --- FFN groups ---
            inter, out = layer.intermediate, layer.output
            g = getattr(inter, "neuron_gate", None)
            if g is not None:
                grp_idx = keep_group_indices_from_gate(
                    g,
                    policy=policy,
                    step=step,
                    custom_rounding=ffn_rounding,
                )
                group = int(g.group_size)
                keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
                inter.dense = slice_linear(inter.dense, keep_out=keep_exp)
                out.dense = slice_linear(out.dense, keep_in=keep_exp)

            # restore the clean class-level forward & drop gate state
            inter.forward = inter.__class__.forward.__get__(inter, inter.__class__)
            if hasattr(inter, "neuron_gate"):
                delattr(inter, "neuron_gate")
            if hasattr(inter, "_orig_forward"):
                delattr(inter, "_orig_forward")

        return slim


# -----------------------------------------------------------------------------
# Export policy
# -----------------------------------------------------------------------------

@dataclass
class ViTExportPolicy:
    """ViT-specific export policy that allows different rounding for heads vs FFN."""
    warmup_steps: int = 0
    head_rounding: CoreRounding = CoreRounding()
    ffn_rounding: CoreRounding = CoreRounding()
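A hedged export sketch (`gated` is a model produced by `ViTAdapter.attach_gates`; step and warmup values are illustrative):

```python
policy = ViTExportPolicy(
    warmup_steps=0,
    head_rounding=CoreRounding(),  # per-axis rounding; field semantics live in core/export.py
    ffn_rounding=CoreRounding(),
)
slim = ViTAdapter.export_pruned(gated, policy, step=500)  # plain HF ViT modules, sliced weights
```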


@dataclass
class ViTGrid:
    head_multiple_grid: Optional[Sequence[int]] = (2, 4, 8)
    ffn_snap_grid: Sequence[int] = (1, 8)
    # Alternative defaults:
    # head_multiple_grid: Optional[Sequence[int]] = None  # --> search 1..num_heads
    # ffn_snap_grid: Sequence[int] = (1, 2, 4, 8, 16)


def vit_search_best_export(
    model_with_gates: nn.Module,
    *,
    export_fn: ExportFn,
    num_heads: int,
    step: int,
    batch_shape: Tuple[int, int, int, int],
    grid: Optional[ViTGrid] = None,
    device: str = "cuda",
    measure_settings: Optional[ProfileSettings] = None,
    make_policy: Optional[Callable[[int, int], object]] = None,
) -> SearchResult:
    """Convenience wrapper for ViT-style search.

    If `make_policy` is not provided, the caller's adapter should accept a
    policy with separate head/FFN rounding; see `adapters.huggingface.vit.ViTExportPolicy`.
    (`ExportFn`, `ProfileSettings`, and `SearchResult` stay as string annotations
    thanks to `from __future__ import annotations`.)
    """
    g = grid or ViTGrid()
    head_grid = g.head_multiple_grid or list(range(1, int(num_heads) + 1))
    ffn_grid = list(g.ffn_snap_grid)

    return grid_search_latency(
        model_with_gates,
        export_fn,
        head_multiples=head_grid,
        ffn_snaps=ffn_grid,
        step=step,
        batch_shape=batch_shape,
        measure_settings=measure_settings,
        device=device,
        make_policy=make_policy,
    )
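A hedged sketch of the ViT-side search call (values illustrative; `batch_shape` is (B, C, H, W) for images):

```python
res = vit_search_best_export(
    gated,                              # from ViTAdapter.attach_gates(...)
    export_fn=ViTAdapter.export_pruned,
    num_heads=12,
    step=500,
    batch_shape=(8, 3, 224, 224),
    grid=ViTGrid(head_multiple_grid=(2, 4), ffn_snap_grid=(1, 8)),
)
```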
model_index.json ADDED

@@ -0,0 +1,5 @@

{
  "task": "image-classification",
  "base_id": "google/vit-base-patch16-224",
  "variant": "gated-student"
}
pytorch_model.bin ADDED

@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:eca9442ba47bd27888b3dc0b0df757113779d2c21182d5626cf1d54643fe637c
size 343618083