miosipov committed
Commit cf5d751 · verified · 1 parent: d4219ed

Upload folder using huggingface_hub

Files changed (39)
  1. README.md +41 -0
  2. config.json +2026 -0
  3. core/.ipynb_checkpoints/distill-checkpoint.py +184 -0
  4. core/.ipynb_checkpoints/finetune-checkpoint.py +267 -0
  5. core/.ipynb_checkpoints/profiler-checkpoint.py +236 -0
  6. core/.ipynb_checkpoints/proxy_cost-checkpoint.py +771 -0
  7. core/.ipynb_checkpoints/train-checkpoint.py +327 -0
  8. core/.ipynb_checkpoints/utils-checkpoint.py +190 -0
  9. core/__init__.py +0 -0
  10. core/__pycache__/__init__.cpython-310.pyc +0 -0
  11. core/__pycache__/distill.cpython-310.pyc +0 -0
  12. core/__pycache__/export.cpython-310.pyc +0 -0
  13. core/__pycache__/finetune.cpython-310.pyc +0 -0
  14. core/__pycache__/gates.cpython-310.pyc +0 -0
  15. core/__pycache__/profiler.cpython-310.pyc +0 -0
  16. core/__pycache__/proxy_cost.cpython-310.pyc +0 -0
  17. core/__pycache__/search_export.cpython-310.pyc +0 -0
  18. core/__pycache__/train.cpython-310.pyc +0 -0
  19. core/__pycache__/utils.cpython-310.pyc +0 -0
  20. core/distill.py +183 -0
  21. core/export.py +220 -0
  22. core/finetune.py +267 -0
  23. core/gates.py +389 -0
  24. core/profiler.py +236 -0
  25. core/proxy_cost.py +771 -0
  26. core/search_export.py +76 -0
  27. core/train.py +327 -0
  28. core/utils.py +190 -0
  29. custom_code.py +1 -0
  30. huggingface/.ipynb_checkpoints/llama-checkpoint.py +607 -0
  31. huggingface/.ipynb_checkpoints/vit-checkpoint.py +383 -0
  32. huggingface/__init__.py +0 -0
  33. huggingface/__pycache__/__init__.cpython-310.pyc +0 -0
  34. huggingface/__pycache__/vit.cpython-310.pyc +0 -0
  35. huggingface/llama.py +607 -0
  36. huggingface/registry.py +0 -0
  37. huggingface/vit.py +383 -0
  38. model_index.json +5 -0
  39. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,41 @@
---
library_name: pytorch
tags:
- resnet
- pruning
- knowledge-distillation
- speedup
license: apache-2.0
dataset: imagenet-1k
pipeline_tag: image-classification
---

# hawada/vit-base-patch16-224-rtx4090-gated

This repository contains two variants:

- **Gated student** (with learned pruning gates) – requires custom code.
- **Slim student** (post-prune/export) – loads with standard code (LLM) or bundled code (ResNet).

## Inference (LLM, slim)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained('hawada/vit-base-patch16-224-rtx4090-slim')
mdl = AutoModelForCausalLM.from_pretrained('hawada/vit-base-patch16-224-rtx4090-slim', torch_dtype='auto').eval()
x = tok('Hello', return_tensors='pt')
print(tok.decode(mdl.generate(**x, max_new_tokens=16)[0]))
```

## Notes

- The **gated** repo includes lightweight custom code (adapters/…, core/…) needed to attach/load gates.
- The **slim** LLM is exported to standard HF architecture for out-of-the-box loading.
- For ResNet, both repos include minimal custom code to define the module.

## Training metadata

```json
{
  "base_id": "google/vit-base-patch16-224",
  "variant": "gated-student",
  "repo_slim": "hawada/vit-base-patch16-224-rtx4090-slim"
}
```
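The card's snippet above covers only the LLM path, while this repo's `pipeline_tag` is `image-classification`. Below is a minimal offline sketch of the corresponding ViT inference pattern. The tiny `ViTConfig` is a hypothetical stand-in (made-up dimensions, random weights) so the example runs without downloading anything; for the real slim student one would presumably call `AutoModelForImageClassification.from_pretrained('hawada/vit-base-patch16-224-rtx4090-slim')`, assuming it exports to a standard checkpoint as the notes suggest.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# Tiny stand-in config so the sketch runs offline (assumed, not the real model);
# swap in AutoModelForImageClassification.from_pretrained(<slim repo id>)
# to load the actual exported student.
cfg = ViTConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                intermediate_size=64, image_size=32, patch_size=16, num_labels=10)
mdl = ViTForImageClassification(cfg).eval()

pixels = torch.rand(1, 3, 32, 32)  # stands in for a preprocessed RGB image batch
with torch.no_grad():
    logits = mdl(pixel_values=pixels).logits
pred = logits.argmax(-1).item()
print(cfg.id2label[pred])  # prints one of 'LABEL_0'..'LABEL_9' (random weights)
```

With the real checkpoint, the matching `AutoImageProcessor` would handle resizing and normalization, and `id2label` would carry the ImageNet-1k class names from `config.json`.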
config.json ADDED
@@ -0,0 +1,2026 @@
{
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "tench, Tinca tinca",
    "1": "goldfish, Carassius auratus",
    "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
    "3": "tiger shark, Galeocerdo cuvieri",
    "4": "hammerhead, hammerhead shark",
    "5": "electric ray, crampfish, numbfish, torpedo",
    "6": "stingray",
    "7": "cock",
    "8": "hen",
    "9": "ostrich, Struthio camelus",
    "10": "brambling, Fringilla montifringilla",
    "11": "goldfinch, Carduelis carduelis",
    "12": "house finch, linnet, Carpodacus mexicanus",
    "13": "junco, snowbird",
    "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
    "15": "robin, American robin, Turdus migratorius",
    "16": "bulbul",
    "17": "jay",
    "18": "magpie",
    "19": "chickadee",
    "20": "water ouzel, dipper",
    "21": "kite",
    "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
    "23": "vulture",
    "24": "great grey owl, great gray owl, Strix nebulosa",
    "25": "European fire salamander, Salamandra salamandra",
    "26": "common newt, Triturus vulgaris",
    "27": "eft",
    "28": "spotted salamander, Ambystoma maculatum",
    "29": "axolotl, mud puppy, Ambystoma mexicanum",
    "30": "bullfrog, Rana catesbeiana",
    "31": "tree frog, tree-frog",
    "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
    "33": "loggerhead, loggerhead turtle, Caretta caretta",
    "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
    "35": "mud turtle",
    "36": "terrapin",
    "37": "box turtle, box tortoise",
    "38": "banded gecko",
    "39": "common iguana, iguana, Iguana iguana",
    "40": "American chameleon, anole, Anolis carolinensis",
    "41": "whiptail, whiptail lizard",
    "42": "agama",
    "43": "frilled lizard, Chlamydosaurus kingi",
    "44": "alligator lizard",
    "45": "Gila monster, Heloderma suspectum",
    "46": "green lizard, Lacerta viridis",
    "47": "African chameleon, Chamaeleo chamaeleon",
    "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
    "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
    "50": "American alligator, Alligator mississipiensis",
    "51": "triceratops",
    "52": "thunder snake, worm snake, Carphophis amoenus",
    "53": "ringneck snake, ring-necked snake, ring snake",
    "54": "hognose snake, puff adder, sand viper",
    "55": "green snake, grass snake",
    "56": "king snake, kingsnake",
    "57": "garter snake, grass snake",
    "58": "water snake",
    "59": "vine snake",
    "60": "night snake, Hypsiglena torquata",
    "61": "boa constrictor, Constrictor constrictor",
    "62": "rock python, rock snake, Python sebae",
    "63": "Indian cobra, Naja naja",
    "64": "green mamba",
    "65": "sea snake",
    "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
    "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
    "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
    "69": "trilobite",
    "70": "harvestman, daddy longlegs, Phalangium opilio",
    "71": "scorpion",
    "72": "black and gold garden spider, Argiope aurantia",
    "73": "barn spider, Araneus cavaticus",
    "74": "garden spider, Aranea diademata",
    "75": "black widow, Latrodectus mactans",
    "76": "tarantula",
    "77": "wolf spider, hunting spider",
    "78": "tick",
    "79": "centipede",
    "80": "black grouse",
    "81": "ptarmigan",
    "82": "ruffed grouse, partridge, Bonasa umbellus",
    "83": "prairie chicken, prairie grouse, prairie fowl",
    "84": "peacock",
    "85": "quail",
    "86": "partridge",
    "87": "African grey, African gray, Psittacus erithacus",
    "88": "macaw",
    "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
    "90": "lorikeet",
    "91": "coucal",
    "92": "bee eater",
    "93": "hornbill",
    "94": "hummingbird",
    "95": "jacamar",
    "96": "toucan",
    "97": "drake",
    "98": "red-breasted merganser, Mergus serrator",
    "99": "goose",
    "100": "black swan, Cygnus atratus",
    "101": "tusker",
    "102": "echidna, spiny anteater, anteater",
    "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
    "104": "wallaby, brush kangaroo",
    "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
    "106": "wombat",
    "107": "jellyfish",
    "108": "sea anemone, anemone",
    "109": "brain coral",
    "110": "flatworm, platyhelminth",
    "111": "nematode, nematode worm, roundworm",
    "112": "conch",
    "113": "snail",
    "114": "slug",
    "115": "sea slug, nudibranch",
    "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
    "117": "chambered nautilus, pearly nautilus, nautilus",
    "118": "Dungeness crab, Cancer magister",
    "119": "rock crab, Cancer irroratus",
    "120": "fiddler crab",
    "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
    "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
    "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
    "124": "crayfish, crawfish, crawdad, crawdaddy",
    "125": "hermit crab",
    "126": "isopod",
    "127": "white stork, Ciconia ciconia",
    "128": "black stork, Ciconia nigra",
    "129": "spoonbill",
    "130": "flamingo",
    "131": "little blue heron, Egretta caerulea",
    "132": "American egret, great white heron, Egretta albus",
    "133": "bittern",
    "134": "crane",
    "135": "limpkin, Aramus pictus",
    "136": "European gallinule, Porphyrio porphyrio",
    "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
    "138": "bustard",
    "139": "ruddy turnstone, Arenaria interpres",
    "140": "red-backed sandpiper, dunlin, Erolia alpina",
    "141": "redshank, Tringa totanus",
    "142": "dowitcher",
    "143": "oystercatcher, oyster catcher",
    "144": "pelican",
    "145": "king penguin, Aptenodytes patagonica",
    "146": "albatross, mollymawk",
    "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
    "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
    "149": "dugong, Dugong dugon",
    "150": "sea lion",
    "151": "Chihuahua",
    "152": "Japanese spaniel",
    "153": "Maltese dog, Maltese terrier, Maltese",
    "154": "Pekinese, Pekingese, Peke",
    "155": "Shih-Tzu",
    "156": "Blenheim spaniel",
    "157": "papillon",
    "158": "toy terrier",
    "159": "Rhodesian ridgeback",
    "160": "Afghan hound, Afghan",
    "161": "basset, basset hound",
    "162": "beagle",
    "163": "bloodhound, sleuthhound",
    "164": "bluetick",
    "165": "black-and-tan coonhound",
    "166": "Walker hound, Walker foxhound",
    "167": "English foxhound",
    "168": "redbone",
    "169": "borzoi, Russian wolfhound",
    "170": "Irish wolfhound",
    "171": "Italian greyhound",
    "172": "whippet",
    "173": "Ibizan hound, Ibizan Podenco",
    "174": "Norwegian elkhound, elkhound",
    "175": "otterhound, otter hound",
    "176": "Saluki, gazelle hound",
    "177": "Scottish deerhound, deerhound",
    "178": "Weimaraner",
    "179": "Staffordshire bullterrier, Staffordshire bull terrier",
    "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
    "181": "Bedlington terrier",
    "182": "Border terrier",
    "183": "Kerry blue terrier",
    "184": "Irish terrier",
    "185": "Norfolk terrier",
    "186": "Norwich terrier",
    "187": "Yorkshire terrier",
    "188": "wire-haired fox terrier",
    "189": "Lakeland terrier",
    "190": "Sealyham terrier, Sealyham",
    "191": "Airedale, Airedale terrier",
    "192": "cairn, cairn terrier",
    "193": "Australian terrier",
    "194": "Dandie Dinmont, Dandie Dinmont terrier",
    "195": "Boston bull, Boston terrier",
    "196": "miniature schnauzer",
    "197": "giant schnauzer",
    "198": "standard schnauzer",
    "199": "Scotch terrier, Scottish terrier, Scottie",
    "200": "Tibetan terrier, chrysanthemum dog",
    "201": "silky terrier, Sydney silky",
    "202": "soft-coated wheaten terrier",
    "203": "West Highland white terrier",
    "204": "Lhasa, Lhasa apso",
    "205": "flat-coated retriever",
    "206": "curly-coated retriever",
    "207": "golden retriever",
    "208": "Labrador retriever",
    "209": "Chesapeake Bay retriever",
    "210": "German short-haired pointer",
    "211": "vizsla, Hungarian pointer",
    "212": "English setter",
    "213": "Irish setter, red setter",
    "214": "Gordon setter",
    "215": "Brittany spaniel",
    "216": "clumber, clumber spaniel",
    "217": "English springer, English springer spaniel",
    "218": "Welsh springer spaniel",
    "219": "cocker spaniel, English cocker spaniel, cocker",
    "220": "Sussex spaniel",
    "221": "Irish water spaniel",
    "222": "kuvasz",
    "223": "schipperke",
    "224": "groenendael",
    "225": "malinois",
    "226": "briard",
    "227": "kelpie",
    "228": "komondor",
    "229": "Old English sheepdog, bobtail",
    "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
    "231": "collie",
    "232": "Border collie",
    "233": "Bouvier des Flandres, Bouviers des Flandres",
    "234": "Rottweiler",
    "235": "German shepherd, German shepherd dog, German police dog, alsatian",
    "236": "Doberman, Doberman pinscher",
    "237": "miniature pinscher",
    "238": "Greater Swiss Mountain dog",
    "239": "Bernese mountain dog",
    "240": "Appenzeller",
    "241": "EntleBucher",
    "242": "boxer",
    "243": "bull mastiff",
    "244": "Tibetan mastiff",
    "245": "French bulldog",
    "246": "Great Dane",
    "247": "Saint Bernard, St Bernard",
    "248": "Eskimo dog, husky",
    "249": "malamute, malemute, Alaskan malamute",
    "250": "Siberian husky",
    "251": "dalmatian, coach dog, carriage dog",
    "252": "affenpinscher, monkey pinscher, monkey dog",
    "253": "basenji",
    "254": "pug, pug-dog",
    "255": "Leonberg",
    "256": "Newfoundland, Newfoundland dog",
    "257": "Great Pyrenees",
    "258": "Samoyed, Samoyede",
    "259": "Pomeranian",
    "260": "chow, chow chow",
    "261": "keeshond",
    "262": "Brabancon griffon",
    "263": "Pembroke, Pembroke Welsh corgi",
    "264": "Cardigan, Cardigan Welsh corgi",
    "265": "toy poodle",
    "266": "miniature poodle",
    "267": "standard poodle",
    "268": "Mexican hairless",
    "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
    "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
    "271": "red wolf, maned wolf, Canis rufus, Canis niger",
    "272": "coyote, prairie wolf, brush wolf, Canis latrans",
    "273": "dingo, warrigal, warragal, Canis dingo",
    "274": "dhole, Cuon alpinus",
    "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
    "276": "hyena, hyaena",
    "277": "red fox, Vulpes vulpes",
    "278": "kit fox, Vulpes macrotis",
    "279": "Arctic fox, white fox, Alopex lagopus",
    "280": "grey fox, gray fox, Urocyon cinereoargenteus",
    "281": "tabby, tabby cat",
    "282": "tiger cat",
    "283": "Persian cat",
    "284": "Siamese cat, Siamese",
    "285": "Egyptian cat",
    "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
    "287": "lynx, catamount",
    "288": "leopard, Panthera pardus",
    "289": "snow leopard, ounce, Panthera uncia",
    "290": "jaguar, panther, Panthera onca, Felis onca",
    "291": "lion, king of beasts, Panthera leo",
    "292": "tiger, Panthera tigris",
    "293": "cheetah, chetah, Acinonyx jubatus",
    "294": "brown bear, bruin, Ursus arctos",
    "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
    "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
    "297": "sloth bear, Melursus ursinus, Ursus ursinus",
    "298": "mongoose",
    "299": "meerkat, mierkat",
    "300": "tiger beetle",
    "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
    "302": "ground beetle, carabid beetle",
    "303": "long-horned beetle, longicorn, longicorn beetle",
    "304": "leaf beetle, chrysomelid",
    "305": "dung beetle",
    "306": "rhinoceros beetle",
    "307": "weevil",
    "308": "fly",
    "309": "bee",
    "310": "ant, emmet, pismire",
    "311": "grasshopper, hopper",
    "312": "cricket",
    "313": "walking stick, walkingstick, stick insect",
    "314": "cockroach, roach",
    "315": "mantis, mantid",
    "316": "cicada, cicala",
    "317": "leafhopper",
    "318": "lacewing, lacewing fly",
    "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
    "320": "damselfly",
    "321": "admiral",
    "322": "ringlet, ringlet butterfly",
    "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
    "324": "cabbage butterfly",
    "325": "sulphur butterfly, sulfur butterfly",
    "326": "lycaenid, lycaenid butterfly",
    "327": "starfish, sea star",
    "328": "sea urchin",
    "329": "sea cucumber, holothurian",
    "330": "wood rabbit, cottontail, cottontail rabbit",
    "331": "hare",
    "332": "Angora, Angora rabbit",
    "333": "hamster",
    "334": "porcupine, hedgehog",
    "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
    "336": "marmot",
    "337": "beaver",
    "338": "guinea pig, Cavia cobaya",
    "339": "sorrel",
    "340": "zebra",
    "341": "hog, pig, grunter, squealer, Sus scrofa",
    "342": "wild boar, boar, Sus scrofa",
    "343": "warthog",
    "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
    "345": "ox",
    "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
    "347": "bison",
    "348": "ram, tup",
    "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
    "350": "ibex, Capra ibex",
    "351": "hartebeest",
    "352": "impala, Aepyceros melampus",
    "353": "gazelle",
    "354": "Arabian camel, dromedary, Camelus dromedarius",
    "355": "llama",
    "356": "weasel",
    "357": "mink",
    "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
    "359": "black-footed ferret, ferret, Mustela nigripes",
    "360": "otter",
    "361": "skunk, polecat, wood pussy",
    "362": "badger",
    "363": "armadillo",
    "364": "three-toed sloth, ai, Bradypus tridactylus",
    "365": "orangutan, orang, orangutang, Pongo pygmaeus",
    "366": "gorilla, Gorilla gorilla",
    "367": "chimpanzee, chimp, Pan troglodytes",
    "368": "gibbon, Hylobates lar",
    "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
    "370": "guenon, guenon monkey",
    "371": "patas, hussar monkey, Erythrocebus patas",
    "372": "baboon",
    "373": "macaque",
    "374": "langur",
    "375": "colobus, colobus monkey",
    "376": "proboscis monkey, Nasalis larvatus",
    "377": "marmoset",
    "378": "capuchin, ringtail, Cebus capucinus",
    "379": "howler monkey, howler",
    "380": "titi, titi monkey",
    "381": "spider monkey, Ateles geoffroyi",
    "382": "squirrel monkey, Saimiri sciureus",
    "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
    "384": "indri, indris, Indri indri, Indri brevicaudatus",
    "385": "Indian elephant, Elephas maximus",
    "386": "African elephant, Loxodonta africana",
    "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
    "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
    "389": "barracouta, snoek",
    "390": "eel",
    "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
    "392": "rock beauty, Holocanthus tricolor",
    "393": "anemone fish",
    "394": "sturgeon",
    "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
    "396": "lionfish",
    "397": "puffer, pufferfish, blowfish, globefish",
    "398": "abacus",
    "399": "abaya",
    "400": "academic gown, academic robe, judge's robe",
    "401": "accordion, piano accordion, squeeze box",
    "402": "acoustic guitar",
    "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
    "404": "airliner",
    "405": "airship, dirigible",
    "406": "altar",
    "407": "ambulance",
    "408": "amphibian, amphibious vehicle",
    "409": "analog clock",
    "410": "apiary, bee house",
    "411": "apron",
    "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
    "413": "assault rifle, assault gun",
    "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
    "415": "bakery, bakeshop, bakehouse",
    "416": "balance beam, beam",
    "417": "balloon",
    "418": "ballpoint, ballpoint pen, ballpen, Biro",
    "419": "Band Aid",
    "420": "banjo",
    "421": "bannister, banister, balustrade, balusters, handrail",
    "422": "barbell",
    "423": "barber chair",
    "424": "barbershop",
    "425": "barn",
    "426": "barometer",
    "427": "barrel, cask",
    "428": "barrow, garden cart, lawn cart, wheelbarrow",
    "429": "baseball",
    "430": "basketball",
    "431": "bassinet",
    "432": "bassoon",
    "433": "bathing cap, swimming cap",
    "434": "bath towel",
    "435": "bathtub, bathing tub, bath, tub",
    "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
    "437": "beacon, lighthouse, beacon light, pharos",
    "438": "beaker",
    "439": "bearskin, busby, shako",
    "440": "beer bottle",
    "441": "beer glass",
    "442": "bell cote, bell cot",
    "443": "bib",
    "444": "bicycle-built-for-two, tandem bicycle, tandem",
    "445": "bikini, two-piece",
    "446": "binder, ring-binder",
    "447": "binoculars, field glasses, opera glasses",
    "448": "birdhouse",
    "449": "boathouse",
    "450": "bobsled, bobsleigh, bob",
    "451": "bolo tie, bolo, bola tie, bola",
    "452": "bonnet, poke bonnet",
    "453": "bookcase",
    "454": "bookshop, bookstore, bookstall",
    "455": "bottlecap",
    "456": "bow",
    "457": "bow tie, bow-tie, bowtie",
    "458": "brass, memorial tablet, plaque",
    "459": "brassiere, bra, bandeau",
    "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
    "461": "breastplate, aegis, egis",
    "462": "broom",
    "463": "bucket, pail",
    "464": "buckle",
    "465": "bulletproof vest",
    "466": "bullet train, bullet",
    "467": "butcher shop, meat market",
    "468": "cab, hack, taxi, taxicab",
    "469": "caldron, cauldron",
    "470": "candle, taper, wax light",
    "471": "cannon",
    "472": "canoe",
    "473": "can opener, tin opener",
    "474": "cardigan",
    "475": "car mirror",
    "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
    "477": "carpenter's kit, tool kit",
    "478": "carton",
    "479": "car wheel",
    "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
    "481": "cassette",
    "482": "cassette player",
    "483": "castle",
    "484": "catamaran",
    "485": "CD player",
    "486": "cello, violoncello",
    "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
    "488": "chain",
    "489": "chainlink fence",
    "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
    "491": "chain saw, chainsaw",
    "492": "chest",
    "493": "chiffonier, commode",
    "494": "chime, bell, gong",
    "495": "china cabinet, china closet",
    "496": "Christmas stocking",
    "497": "church, church building",
    "498": "cinema, movie theater, movie theatre, movie house, picture palace",
    "499": "cleaver, meat cleaver, chopper",
    "500": "cliff dwelling",
    "501": "cloak",
    "502": "clog, geta, patten, sabot",
    "503": "cocktail shaker",
    "504": "coffee mug",
    "505": "coffeepot",
    "506": "coil, spiral, volute, whorl, helix",
    "507": "combination lock",
    "508": "computer keyboard, keypad",
    "509": "confectionery, confectionary, candy store",
    "510": "container ship, containership, container vessel",
    "511": "convertible",
    "512": "corkscrew, bottle screw",
    "513": "cornet, horn, trumpet, trump",
    "514": "cowboy boot",
    "515": "cowboy hat, ten-gallon hat",
    "516": "cradle",
    "517": "crane",
    "518": "crash helmet",
    "519": "crate",
    "520": "crib, cot",
    "521": "Crock Pot",
    "522": "croquet ball",
    "523": "crutch",
    "524": "cuirass",
    "525": "dam, dike, dyke",
    "526": "desk",
    "527": "desktop computer",
    "528": "dial telephone, dial phone",
    "529": "diaper, nappy, napkin",
    "530": "digital clock",
    "531": "digital watch",
    "532": "dining table, board",
    "533": "dishrag, dishcloth",
    "534": "dishwasher, dish washer, dishwashing machine",
    "535": "disk brake, disc brake",
    "536": "dock, dockage, docking facility",
    "537": "dogsled, dog sled, dog sleigh",
    "538": "dome",
    "539": "doormat, welcome mat",
    "540": "drilling platform, offshore rig",
    "541": "drum, membranophone, tympan",
    "542": "drumstick",
    "543": "dumbbell",
    "544": "Dutch oven",
    "545": "electric fan, blower",
    "546": "electric guitar",
    "547": "electric locomotive",
    "548": "entertainment center",
    "549": "envelope",
    "550": "espresso maker",
    "551": "face powder",
    "552": "feather boa, boa",
    "553": "file, file cabinet, filing cabinet",
    "554": "fireboat",
    "555": "fire engine, fire truck",
    "556": "fire screen, fireguard",
    "557": "flagpole, flagstaff",
    "558": "flute, transverse flute",
    "559": "folding chair",
    "560": "football helmet",
    "561": "forklift",
    "562": "fountain",
    "563": "fountain pen",
    "564": "four-poster",
    "565": "freight car",
    "566": "French horn, horn",
    "567": "frying pan, frypan, skillet",
    "568": "fur coat",
    "569": "garbage truck, dustcart",
    "570": "gasmask, respirator, gas helmet",
    "571": "gas pump, gasoline pump, petrol pump, island dispenser",
    "572": "goblet",
    "573": "go-kart",
    "574": "golf ball",
    "575": "golfcart, golf cart",
    "576": "gondola",
    "577": "gong, tam-tam",
    "578": "gown",
    "579": "grand piano, grand",
    "580": "greenhouse, nursery, glasshouse",
    "581": "grille, radiator grille",
    "582": "grocery store, grocery, food market, market",
    "583": "guillotine",
    "584": "hair slide",
    "585": "hair spray",
    "586": "half track",
    "587": "hammer",
    "588": "hamper",
    "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
    "590": "hand-held computer, hand-held microcomputer",
    "591": "handkerchief, hankie, hanky, hankey",
    "592": "hard disc, hard disk, fixed disk",
    "593": "harmonica, mouth organ, harp, mouth harp",
    "594": "harp",
    "595": "harvester, reaper",
    "596": "hatchet",
    "597": "holster",
    "598": "home theater, home theatre",
    "599": "honeycomb",
    "600": "hook, claw",
    "601": "hoopskirt, crinoline",
    "602": "horizontal bar, high bar",
    "603": "horse cart, horse-cart",
    "604": "hourglass",
    "605": "iPod",
    "606": "iron, smoothing iron",
    "607": "jack-o'-lantern",
    "608": "jean, blue jean, denim",
    "609": "jeep, landrover",
    "610": "jersey, T-shirt, tee shirt",
    "611": "jigsaw puzzle",
    "612": "jinrikisha, ricksha, rickshaw",
    "613": "joystick",
    "614": "kimono",
    "615": "knee pad",
    "616": "knot",
    "617": "lab coat, laboratory coat",
    "618": "ladle",
    "619": "lampshade, lamp shade",
    "620": "laptop, laptop computer",
    "621": "lawn mower, mower",
    "622": "lens cap, lens cover",
    "623": "letter opener, paper knife, paperknife",
    "624": "library",
    "625": "lifeboat",
    "626": "lighter, light, igniter, ignitor",
    "627": "limousine, limo",
    "628": "liner, ocean liner",
    "629": "lipstick, lip rouge",
    "630": "Loafer",
    "631": "lotion",
    "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
    "633": "loupe, jeweler's loupe",
    "634": "lumbermill, sawmill",
    "635": "magnetic compass",
    "636": "mailbag, postbag",
    "637": "mailbox, letter box",
    "638": "maillot",
    "639": "maillot, tank suit",
    "640": "manhole cover",
    "641": "maraca",
    "642": "marimba, xylophone",
    "643": "mask",
    "644": "matchstick",
    "645": "maypole",
    "646": "maze, labyrinth",
    "647": "measuring cup",
    "648": "medicine chest, medicine cabinet",
    "649": "megalith, megalithic structure",
    "650": "microphone, mike",
    "651": "microwave, microwave oven",
    "652": "military uniform",
    "653": "milk can",
    "654": "minibus",
    "655": "miniskirt, mini",
    "656": "minivan",
    "657": "missile",
    "658": "mitten",
    "659": "mixing bowl",
    "660": "mobile home, manufactured home",
672
+ "661": "Model T",
673
+ "662": "modem",
674
+ "663": "monastery",
675
+ "664": "monitor",
676
+ "665": "moped",
677
+ "666": "mortar",
678
+ "667": "mortarboard",
679
+ "668": "mosque",
680
+ "669": "mosquito net",
681
+ "670": "motor scooter, scooter",
682
+ "671": "mountain bike, all-terrain bike, off-roader",
683
+ "672": "mountain tent",
684
+ "673": "mouse, computer mouse",
685
+ "674": "mousetrap",
686
+ "675": "moving van",
687
+ "676": "muzzle",
688
+ "677": "nail",
689
+ "678": "neck brace",
690
+ "679": "necklace",
691
+ "680": "nipple",
692
+ "681": "notebook, notebook computer",
693
+ "682": "obelisk",
694
+ "683": "oboe, hautboy, hautbois",
695
+ "684": "ocarina, sweet potato",
696
+ "685": "odometer, hodometer, mileometer, milometer",
697
+ "686": "oil filter",
698
+ "687": "organ, pipe organ",
699
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
700
+ "689": "overskirt",
701
+ "690": "oxcart",
702
+ "691": "oxygen mask",
703
+ "692": "packet",
704
+ "693": "paddle, boat paddle",
705
+ "694": "paddlewheel, paddle wheel",
706
+ "695": "padlock",
707
+ "696": "paintbrush",
708
+ "697": "pajama, pyjama, pj's, jammies",
709
+ "698": "palace",
710
+ "699": "panpipe, pandean pipe, syrinx",
711
+ "700": "paper towel",
712
+ "701": "parachute, chute",
713
+ "702": "parallel bars, bars",
714
+ "703": "park bench",
715
+ "704": "parking meter",
716
+ "705": "passenger car, coach, carriage",
717
+ "706": "patio, terrace",
718
+ "707": "pay-phone, pay-station",
719
+ "708": "pedestal, plinth, footstall",
720
+ "709": "pencil box, pencil case",
721
+ "710": "pencil sharpener",
722
+ "711": "perfume, essence",
723
+ "712": "Petri dish",
724
+ "713": "photocopier",
725
+ "714": "pick, plectrum, plectron",
726
+ "715": "pickelhaube",
727
+ "716": "picket fence, paling",
728
+ "717": "pickup, pickup truck",
729
+ "718": "pier",
730
+ "719": "piggy bank, penny bank",
731
+ "720": "pill bottle",
732
+ "721": "pillow",
733
+ "722": "ping-pong ball",
734
+ "723": "pinwheel",
735
+ "724": "pirate, pirate ship",
736
+ "725": "pitcher, ewer",
737
+ "726": "plane, carpenter's plane, woodworking plane",
738
+ "727": "planetarium",
739
+ "728": "plastic bag",
740
+ "729": "plate rack",
741
+ "730": "plow, plough",
742
+ "731": "plunger, plumber's helper",
743
+ "732": "Polaroid camera, Polaroid Land camera",
744
+ "733": "pole",
745
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
746
+ "735": "poncho",
747
+ "736": "pool table, billiard table, snooker table",
748
+ "737": "pop bottle, soda bottle",
749
+ "738": "pot, flowerpot",
750
+ "739": "potter's wheel",
751
+ "740": "power drill",
752
+ "741": "prayer rug, prayer mat",
753
+ "742": "printer",
754
+ "743": "prison, prison house",
755
+ "744": "projectile, missile",
756
+ "745": "projector",
757
+ "746": "puck, hockey puck",
758
+ "747": "punching bag, punch bag, punching ball, punchball",
759
+ "748": "purse",
760
+ "749": "quill, quill pen",
761
+ "750": "quilt, comforter, comfort, puff",
762
+ "751": "racer, race car, racing car",
763
+ "752": "racket, racquet",
764
+ "753": "radiator",
765
+ "754": "radio, wireless",
766
+ "755": "radio telescope, radio reflector",
767
+ "756": "rain barrel",
768
+ "757": "recreational vehicle, RV, R.V.",
769
+ "758": "reel",
770
+ "759": "reflex camera",
771
+ "760": "refrigerator, icebox",
772
+ "761": "remote control, remote",
773
+ "762": "restaurant, eating house, eating place, eatery",
774
+ "763": "revolver, six-gun, six-shooter",
775
+ "764": "rifle",
776
+ "765": "rocking chair, rocker",
777
+ "766": "rotisserie",
778
+ "767": "rubber eraser, rubber, pencil eraser",
779
+ "768": "rugby ball",
780
+ "769": "rule, ruler",
781
+ "770": "running shoe",
782
+ "771": "safe",
783
+ "772": "safety pin",
784
+ "773": "saltshaker, salt shaker",
785
+ "774": "sandal",
786
+ "775": "sarong",
787
+ "776": "sax, saxophone",
788
+ "777": "scabbard",
789
+ "778": "scale, weighing machine",
790
+ "779": "school bus",
791
+ "780": "schooner",
792
+ "781": "scoreboard",
793
+ "782": "screen, CRT screen",
794
+ "783": "screw",
795
+ "784": "screwdriver",
796
+ "785": "seat belt, seatbelt",
797
+ "786": "sewing machine",
798
+ "787": "shield, buckler",
799
+ "788": "shoe shop, shoe-shop, shoe store",
800
+ "789": "shoji",
801
+ "790": "shopping basket",
802
+ "791": "shopping cart",
803
+ "792": "shovel",
804
+ "793": "shower cap",
805
+ "794": "shower curtain",
806
+ "795": "ski",
807
+ "796": "ski mask",
808
+ "797": "sleeping bag",
809
+ "798": "slide rule, slipstick",
810
+ "799": "sliding door",
811
+ "800": "slot, one-armed bandit",
812
+ "801": "snorkel",
813
+ "802": "snowmobile",
814
+ "803": "snowplow, snowplough",
815
+ "804": "soap dispenser",
816
+ "805": "soccer ball",
817
+ "806": "sock",
818
+ "807": "solar dish, solar collector, solar furnace",
819
+ "808": "sombrero",
820
+ "809": "soup bowl",
821
+ "810": "space bar",
822
+ "811": "space heater",
823
+ "812": "space shuttle",
824
+ "813": "spatula",
825
+ "814": "speedboat",
826
+ "815": "spider web, spider's web",
827
+ "816": "spindle",
828
+ "817": "sports car, sport car",
829
+ "818": "spotlight, spot",
830
+ "819": "stage",
831
+ "820": "steam locomotive",
832
+ "821": "steel arch bridge",
833
+ "822": "steel drum",
834
+ "823": "stethoscope",
835
+ "824": "stole",
836
+ "825": "stone wall",
837
+ "826": "stopwatch, stop watch",
838
+ "827": "stove",
839
+ "828": "strainer",
840
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
841
+ "830": "stretcher",
842
+ "831": "studio couch, day bed",
843
+ "832": "stupa, tope",
844
+ "833": "submarine, pigboat, sub, U-boat",
845
+ "834": "suit, suit of clothes",
846
+ "835": "sundial",
847
+ "836": "sunglass",
848
+ "837": "sunglasses, dark glasses, shades",
849
+ "838": "sunscreen, sunblock, sun blocker",
850
+ "839": "suspension bridge",
851
+ "840": "swab, swob, mop",
852
+ "841": "sweatshirt",
853
+ "842": "swimming trunks, bathing trunks",
854
+ "843": "swing",
855
+ "844": "switch, electric switch, electrical switch",
856
+ "845": "syringe",
857
+ "846": "table lamp",
858
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
859
+ "848": "tape player",
860
+ "849": "teapot",
861
+ "850": "teddy, teddy bear",
862
+ "851": "television, television system",
863
+ "852": "tennis ball",
864
+ "853": "thatch, thatched roof",
865
+ "854": "theater curtain, theatre curtain",
866
+ "855": "thimble",
867
+ "856": "thresher, thrasher, threshing machine",
868
+ "857": "throne",
869
+ "858": "tile roof",
870
+ "859": "toaster",
871
+ "860": "tobacco shop, tobacconist shop, tobacconist",
872
+ "861": "toilet seat",
873
+ "862": "torch",
874
+ "863": "totem pole",
875
+ "864": "tow truck, tow car, wrecker",
876
+ "865": "toyshop",
877
+ "866": "tractor",
878
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
879
+ "868": "tray",
880
+ "869": "trench coat",
881
+ "870": "tricycle, trike, velocipede",
882
+ "871": "trimaran",
883
+ "872": "tripod",
884
+ "873": "triumphal arch",
885
+ "874": "trolleybus, trolley coach, trackless trolley",
886
+ "875": "trombone",
887
+ "876": "tub, vat",
888
+ "877": "turnstile",
889
+ "878": "typewriter keyboard",
890
+ "879": "umbrella",
891
+ "880": "unicycle, monocycle",
892
+ "881": "upright, upright piano",
893
+ "882": "vacuum, vacuum cleaner",
894
+ "883": "vase",
895
+ "884": "vault",
896
+ "885": "velvet",
897
+ "886": "vending machine",
898
+ "887": "vestment",
899
+ "888": "viaduct",
900
+ "889": "violin, fiddle",
901
+ "890": "volleyball",
902
+ "891": "waffle iron",
903
+ "892": "wall clock",
904
+ "893": "wallet, billfold, notecase, pocketbook",
905
+ "894": "wardrobe, closet, press",
906
+ "895": "warplane, military plane",
907
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
908
+ "897": "washer, automatic washer, washing machine",
909
+ "898": "water bottle",
910
+ "899": "water jug",
911
+ "900": "water tower",
912
+ "901": "whiskey jug",
913
+ "902": "whistle",
914
+ "903": "wig",
915
+ "904": "window screen",
916
+ "905": "window shade",
917
+ "906": "Windsor tie",
918
+ "907": "wine bottle",
919
+ "908": "wing",
920
+ "909": "wok",
921
+ "910": "wooden spoon",
922
+ "911": "wool, woolen, woollen",
923
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
924
+ "913": "wreck",
925
+ "914": "yawl",
926
+ "915": "yurt",
927
+ "916": "web site, website, internet site, site",
928
+ "917": "comic book",
929
+ "918": "crossword puzzle, crossword",
930
+ "919": "street sign",
931
+ "920": "traffic light, traffic signal, stoplight",
932
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
933
+ "922": "menu",
934
+ "923": "plate",
935
+ "924": "guacamole",
936
+ "925": "consomme",
937
+ "926": "hot pot, hotpot",
938
+ "927": "trifle",
939
+ "928": "ice cream, icecream",
940
+ "929": "ice lolly, lolly, lollipop, popsicle",
941
+ "930": "French loaf",
942
+ "931": "bagel, beigel",
943
+ "932": "pretzel",
944
+ "933": "cheeseburger",
945
+ "934": "hotdog, hot dog, red hot",
946
+ "935": "mashed potato",
947
+ "936": "head cabbage",
948
+ "937": "broccoli",
949
+ "938": "cauliflower",
950
+ "939": "zucchini, courgette",
951
+ "940": "spaghetti squash",
952
+ "941": "acorn squash",
953
+ "942": "butternut squash",
954
+ "943": "cucumber, cuke",
955
+ "944": "artichoke, globe artichoke",
956
+ "945": "bell pepper",
957
+ "946": "cardoon",
958
+ "947": "mushroom",
959
+ "948": "Granny Smith",
960
+ "949": "strawberry",
961
+ "950": "orange",
962
+ "951": "lemon",
963
+ "952": "fig",
964
+ "953": "pineapple, ananas",
965
+ "954": "banana",
966
+ "955": "jackfruit, jak, jack",
967
+ "956": "custard apple",
968
+ "957": "pomegranate",
969
+ "958": "hay",
970
+ "959": "carbonara",
971
+ "960": "chocolate sauce, chocolate syrup",
972
+ "961": "dough",
973
+ "962": "meat loaf, meatloaf",
974
+ "963": "pizza, pizza pie",
975
+ "964": "potpie",
976
+ "965": "burrito",
977
+ "966": "red wine",
978
+ "967": "espresso",
979
+ "968": "cup",
980
+ "969": "eggnog",
981
+ "970": "alp",
982
+ "971": "bubble",
983
+ "972": "cliff, drop, drop-off",
984
+ "973": "coral reef",
985
+ "974": "geyser",
986
+ "975": "lakeside, lakeshore",
987
+ "976": "promontory, headland, head, foreland",
988
+ "977": "sandbar, sand bar",
989
+ "978": "seashore, coast, seacoast, sea-coast",
990
+ "979": "valley, vale",
991
+ "980": "volcano",
992
+ "981": "ballplayer, baseball player",
993
+ "982": "groom, bridegroom",
994
+ "983": "scuba diver",
995
+ "984": "rapeseed",
996
+ "985": "daisy",
997
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
998
+ "987": "corn",
999
+ "988": "acorn",
1000
+ "989": "hip, rose hip, rosehip",
1001
+ "990": "buckeye, horse chestnut, conker",
1002
+ "991": "coral fungus",
1003
+ "992": "agaric",
1004
+ "993": "gyromitra",
1005
+ "994": "stinkhorn, carrion fungus",
1006
+ "995": "earthstar",
1007
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1008
+ "997": "bolete",
1009
+ "998": "ear, spike, capitulum",
1010
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1011
+ },
1012
+ "image_size": 224,
1013
+ "initializer_range": 0.02,
1014
+ "intermediate_size": 3072,
1015
+ "label2id": {
1016
+ "Afghan hound, Afghan": 160,
1017
+ "African chameleon, Chamaeleo chamaeleon": 47,
1018
+ "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
1019
+ "African elephant, Loxodonta africana": 386,
1020
+ "African grey, African gray, Psittacus erithacus": 87,
1021
+ "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
1022
+ "Airedale, Airedale terrier": 191,
1023
+ "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
1024
+ "American alligator, Alligator mississipiensis": 50,
1025
+ "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
1026
+ "American chameleon, anole, Anolis carolinensis": 40,
1027
+ "American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
1028
+ "American egret, great white heron, Egretta albus": 132,
1029
+ "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
1030
+ "Angora, Angora rabbit": 332,
1031
+ "Appenzeller": 240,
1032
+ "Arabian camel, dromedary, Camelus dromedarius": 354,
1033
+ "Arctic fox, white fox, Alopex lagopus": 279,
1034
+ "Australian terrier": 193,
1035
+ "Band Aid": 419,
1036
+ "Bedlington terrier": 181,
1037
+ "Bernese mountain dog": 239,
1038
+ "Blenheim spaniel": 156,
1039
+ "Border collie": 232,
1040
+ "Border terrier": 182,
1041
+ "Boston bull, Boston terrier": 195,
1042
+ "Bouvier des Flandres, Bouviers des Flandres": 233,
1043
+ "Brabancon griffon": 262,
1044
+ "Brittany spaniel": 215,
1045
+ "CD player": 485,
1046
+ "Cardigan, Cardigan Welsh corgi": 264,
1047
+ "Chesapeake Bay retriever": 209,
1048
+ "Chihuahua": 151,
1049
+ "Christmas stocking": 496,
1050
+ "Crock Pot": 521,
1051
+ "Dandie Dinmont, Dandie Dinmont terrier": 194,
1052
+ "Doberman, Doberman pinscher": 236,
1053
+ "Dungeness crab, Cancer magister": 118,
1054
+ "Dutch oven": 544,
1055
+ "Egyptian cat": 285,
1056
+ "English foxhound": 167,
1057
+ "English setter": 212,
1058
+ "English springer, English springer spaniel": 217,
1059
+ "EntleBucher": 241,
1060
+ "Eskimo dog, husky": 248,
1061
+ "European fire salamander, Salamandra salamandra": 25,
1062
+ "European gallinule, Porphyrio porphyrio": 136,
1063
+ "French bulldog": 245,
1064
+ "French horn, horn": 566,
1065
+ "French loaf": 930,
1066
+ "German shepherd, German shepherd dog, German police dog, alsatian": 235,
1067
+ "German short-haired pointer": 210,
1068
+ "Gila monster, Heloderma suspectum": 45,
1069
+ "Gordon setter": 214,
1070
+ "Granny Smith": 948,
1071
+ "Great Dane": 246,
1072
+ "Great Pyrenees": 257,
1073
+ "Greater Swiss Mountain dog": 238,
1074
+ "Ibizan hound, Ibizan Podenco": 173,
1075
+ "Indian cobra, Naja naja": 63,
1076
+ "Indian elephant, Elephas maximus": 385,
1077
+ "Irish setter, red setter": 213,
1078
+ "Irish terrier": 184,
1079
+ "Irish water spaniel": 221,
1080
+ "Irish wolfhound": 170,
1081
+ "Italian greyhound": 171,
1082
+ "Japanese spaniel": 152,
1083
+ "Kerry blue terrier": 183,
1084
+ "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
1085
+ "Labrador retriever": 208,
1086
+ "Lakeland terrier": 189,
1087
+ "Leonberg": 255,
1088
+ "Lhasa, Lhasa apso": 204,
1089
+ "Loafer": 630,
1090
+ "Madagascar cat, ring-tailed lemur, Lemur catta": 383,
1091
+ "Maltese dog, Maltese terrier, Maltese": 153,
1092
+ "Mexican hairless": 268,
1093
+ "Model T": 661,
1094
+ "Newfoundland, Newfoundland dog": 256,
1095
+ "Norfolk terrier": 185,
1096
+ "Norwegian elkhound, elkhound": 174,
1097
+ "Norwich terrier": 186,
1098
+ "Old English sheepdog, bobtail": 229,
1099
+ "Pekinese, Pekingese, Peke": 154,
1100
+ "Pembroke, Pembroke Welsh corgi": 263,
1101
+ "Persian cat": 283,
1102
+ "Petri dish": 712,
1103
+ "Polaroid camera, Polaroid Land camera": 732,
1104
+ "Pomeranian": 259,
1105
+ "Rhodesian ridgeback": 159,
1106
+ "Rottweiler": 234,
1107
+ "Saint Bernard, St Bernard": 247,
1108
+ "Saluki, gazelle hound": 176,
1109
+ "Samoyed, Samoyede": 258,
1110
+ "Scotch terrier, Scottish terrier, Scottie": 199,
1111
+ "Scottish deerhound, deerhound": 177,
1112
+ "Sealyham terrier, Sealyham": 190,
1113
+ "Shetland sheepdog, Shetland sheep dog, Shetland": 230,
1114
+ "Shih-Tzu": 155,
1115
+ "Siamese cat, Siamese": 284,
1116
+ "Siberian husky": 250,
1117
+ "Staffordshire bullterrier, Staffordshire bull terrier": 179,
1118
+ "Sussex spaniel": 220,
1119
+ "Tibetan mastiff": 244,
1120
+ "Tibetan terrier, chrysanthemum dog": 200,
1121
+ "Walker hound, Walker foxhound": 166,
1122
+ "Weimaraner": 178,
1123
+ "Welsh springer spaniel": 218,
1124
+ "West Highland white terrier": 203,
1125
+ "Windsor tie": 906,
1126
+ "Yorkshire terrier": 187,
1127
+ "abacus": 398,
1128
+ "abaya": 399,
1129
+ "academic gown, academic robe, judge's robe": 400,
1130
+ "accordion, piano accordion, squeeze box": 401,
1131
+ "acorn": 988,
1132
+ "acorn squash": 941,
1133
+ "acoustic guitar": 402,
1134
+ "admiral": 321,
1135
+ "affenpinscher, monkey pinscher, monkey dog": 252,
1136
+ "agama": 42,
1137
+ "agaric": 992,
1138
+ "aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
1139
+ "airliner": 404,
1140
+ "airship, dirigible": 405,
1141
+ "albatross, mollymawk": 146,
1142
+ "alligator lizard": 44,
1143
+ "alp": 970,
1144
+ "altar": 406,
1145
+ "ambulance": 407,
1146
+ "amphibian, amphibious vehicle": 408,
1147
+ "analog clock": 409,
1148
+ "anemone fish": 393,
1149
+ "ant, emmet, pismire": 310,
1150
+ "apiary, bee house": 410,
1151
+ "apron": 411,
1152
+ "armadillo": 363,
1153
+ "artichoke, globe artichoke": 944,
1154
+ "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
1155
+ "assault rifle, assault gun": 413,
1156
+ "axolotl, mud puppy, Ambystoma mexicanum": 29,
1157
+ "baboon": 372,
1158
+ "backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
1159
+ "badger": 362,
1160
+ "bagel, beigel": 931,
1161
+ "bakery, bakeshop, bakehouse": 415,
1162
+ "balance beam, beam": 416,
1163
+ "bald eagle, American eagle, Haliaeetus leucocephalus": 22,
1164
+ "balloon": 417,
1165
+ "ballplayer, baseball player": 981,
1166
+ "ballpoint, ballpoint pen, ballpen, Biro": 418,
1167
+ "banana": 954,
1168
+ "banded gecko": 38,
1169
+ "banjo": 420,
1170
+ "bannister, banister, balustrade, balusters, handrail": 421,
1171
+ "barbell": 422,
1172
+ "barber chair": 423,
1173
+ "barbershop": 424,
1174
+ "barn": 425,
1175
+ "barn spider, Araneus cavaticus": 73,
1176
+ "barometer": 426,
1177
+ "barracouta, snoek": 389,
1178
+ "barrel, cask": 427,
1179
+ "barrow, garden cart, lawn cart, wheelbarrow": 428,
1180
+ "baseball": 429,
1181
+ "basenji": 253,
1182
+ "basketball": 430,
1183
+ "basset, basset hound": 161,
1184
+ "bassinet": 431,
1185
+ "bassoon": 432,
1186
+ "bath towel": 434,
1187
+ "bathing cap, swimming cap": 433,
1188
+ "bathtub, bathing tub, bath, tub": 435,
1189
+ "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
1190
+ "beacon, lighthouse, beacon light, pharos": 437,
1191
+ "beagle": 162,
1192
+ "beaker": 438,
1193
+ "bearskin, busby, shako": 439,
1194
+ "beaver": 337,
1195
+ "bee": 309,
1196
+ "bee eater": 92,
1197
+ "beer bottle": 440,
1198
+ "beer glass": 441,
1199
+ "bell cote, bell cot": 442,
1200
+ "bell pepper": 945,
1201
+ "bib": 443,
1202
+ "bicycle-built-for-two, tandem bicycle, tandem": 444,
1203
+ "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
1204
+ "bikini, two-piece": 445,
1205
+ "binder, ring-binder": 446,
1206
+ "binoculars, field glasses, opera glasses": 447,
1207
+ "birdhouse": 448,
1208
+ "bison": 347,
1209
+ "bittern": 133,
1210
+ "black and gold garden spider, Argiope aurantia": 72,
1211
+ "black grouse": 80,
1212
+ "black stork, Ciconia nigra": 128,
1213
+ "black swan, Cygnus atratus": 100,
1214
+ "black widow, Latrodectus mactans": 75,
1215
+ "black-and-tan coonhound": 165,
1216
+ "black-footed ferret, ferret, Mustela nigripes": 359,
1217
+ "bloodhound, sleuthhound": 163,
1218
+ "bluetick": 164,
1219
+ "boa constrictor, Constrictor constrictor": 61,
1220
+ "boathouse": 449,
1221
+ "bobsled, bobsleigh, bob": 450,
1222
+ "bolete": 997,
1223
+ "bolo tie, bolo, bola tie, bola": 451,
1224
+ "bonnet, poke bonnet": 452,
1225
+ "book jacket, dust cover, dust jacket, dust wrapper": 921,
1226
+ "bookcase": 453,
1227
+ "bookshop, bookstore, bookstall": 454,
1228
+ "borzoi, Russian wolfhound": 169,
1229
+ "bottlecap": 455,
1230
+ "bow": 456,
1231
+ "bow tie, bow-tie, bowtie": 457,
1232
+ "box turtle, box tortoise": 37,
1233
+ "boxer": 242,
1234
+ "brain coral": 109,
1235
+ "brambling, Fringilla montifringilla": 10,
1236
+ "brass, memorial tablet, plaque": 458,
1237
+ "brassiere, bra, bandeau": 459,
1238
+ "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
1239
+ "breastplate, aegis, egis": 461,
1240
+ "briard": 226,
1241
+ "broccoli": 937,
1242
+ "broom": 462,
1243
+ "brown bear, bruin, Ursus arctos": 294,
1244
+ "bubble": 971,
1245
+ "bucket, pail": 463,
1246
+ "buckeye, horse chestnut, conker": 990,
1247
+ "buckle": 464,
1248
+ "bulbul": 16,
1249
+ "bull mastiff": 243,
1250
+ "bullet train, bullet": 466,
1251
+ "bulletproof vest": 465,
1252
+ "bullfrog, Rana catesbeiana": 30,
1253
+ "burrito": 965,
1254
+ "bustard": 138,
1255
+ "butcher shop, meat market": 467,
1256
+ "butternut squash": 942,
1257
+ "cab, hack, taxi, taxicab": 468,
1258
+ "cabbage butterfly": 324,
1259
+ "cairn, cairn terrier": 192,
1260
+ "caldron, cauldron": 469,
1261
+ "can opener, tin opener": 473,
1262
+ "candle, taper, wax light": 470,
1263
+ "cannon": 471,
1264
+ "canoe": 472,
1265
+ "capuchin, ringtail, Cebus capucinus": 378,
1266
+ "car mirror": 475,
1267
+ "car wheel": 479,
1268
+ "carbonara": 959,
1269
+ "cardigan": 474,
1270
+ "cardoon": 946,
1271
+ "carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
1272
+ "carpenter's kit, tool kit": 477,
1273
+ "carton": 478,
1274
+ "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
1275
+ "cassette": 481,
1276
+ "cassette player": 482,
1277
+ "castle": 483,
1278
+ "catamaran": 484,
1279
+ "cauliflower": 938,
1280
+ "cello, violoncello": 486,
1281
+ "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
1282
+ "centipede": 79,
1283
+ "chain": 488,
1284
+ "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
1285
+ "chain saw, chainsaw": 491,
1286
+ "chainlink fence": 489,
1287
+ "chambered nautilus, pearly nautilus, nautilus": 117,
1288
+ "cheeseburger": 933,
1289
+ "cheetah, chetah, Acinonyx jubatus": 293,
1290
+ "chest": 492,
1291
+ "chickadee": 19,
1292
+ "chiffonier, commode": 493,
1293
+ "chime, bell, gong": 494,
1294
+ "chimpanzee, chimp, Pan troglodytes": 367,
1295
+ "china cabinet, china closet": 495,
1296
+ "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
1297
+ "chocolate sauce, chocolate syrup": 960,
1298
+ "chow, chow chow": 260,
1299
+ "church, church building": 497,
1300
+ "cicada, cicala": 316,
1301
+ "cinema, movie theater, movie theatre, movie house, picture palace": 498,
1302
+ "cleaver, meat cleaver, chopper": 499,
1303
+ "cliff dwelling": 500,
1304
+ "cliff, drop, drop-off": 972,
1305
+ "cloak": 501,
1306
+ "clog, geta, patten, sabot": 502,
1307
+ "clumber, clumber spaniel": 216,
1308
+ "cock": 7,
1309
+ "cocker spaniel, English cocker spaniel, cocker": 219,
1310
+ "cockroach, roach": 314,
1311
+ "cocktail shaker": 503,
1312
+ "coffee mug": 504,
1313
+ "coffeepot": 505,
1314
+ "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
1315
+ "coil, spiral, volute, whorl, helix": 506,
1316
+ "collie": 231,
1317
+ "colobus, colobus monkey": 375,
1318
+ "combination lock": 507,
1319
+ "comic book": 917,
1320
+ "common iguana, iguana, Iguana iguana": 39,
1321
+ "common newt, Triturus vulgaris": 26,
1322
+ "computer keyboard, keypad": 508,
1323
+ "conch": 112,
1324
+ "confectionery, confectionary, candy store": 509,
1325
+ "consomme": 925,
1326
+ "container ship, containership, container vessel": 510,
1327
+ "convertible": 511,
1328
+ "coral fungus": 991,
1329
+ "coral reef": 973,
1330
+ "corkscrew, bottle screw": 512,
1331
+ "corn": 987,
1332
+ "cornet, horn, trumpet, trump": 513,
1333
+ "coucal": 91,
1334
+ "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
1335
+ "cowboy boot": 514,
1336
+ "cowboy hat, ten-gallon hat": 515,
1337
+ "coyote, prairie wolf, brush wolf, Canis latrans": 272,
1338
+ "cradle": 516,
1339
+ "crane": 517,
1340
+ "crash helmet": 518,
1341
+ "crate": 519,
1342
+ "crayfish, crawfish, crawdad, crawdaddy": 124,
1343
+ "crib, cot": 520,
1344
+ "cricket": 312,
1345
+ "croquet ball": 522,
1346
+ "crossword puzzle, crossword": 918,
1347
+ "crutch": 523,
1348
+ "cucumber, cuke": 943,
1349
+ "cuirass": 524,
1350
+ "cup": 968,
1351
+ "curly-coated retriever": 206,
1352
+ "custard apple": 956,
1353
+ "daisy": 985,
1354
+ "dalmatian, coach dog, carriage dog": 251,
1355
+ "dam, dike, dyke": 525,
1356
+ "damselfly": 320,
1357
+ "desk": 526,
1358
+ "desktop computer": 527,
1359
+ "dhole, Cuon alpinus": 274,
1360
+ "dial telephone, dial phone": 528,
1361
+ "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
1362
+ "diaper, nappy, napkin": 529,
1363
+ "digital clock": 530,
1364
+ "digital watch": 531,
1365
+ "dingo, warrigal, warragal, Canis dingo": 273,
1366
+ "dining table, board": 532,
1367
+ "dishrag, dishcloth": 533,
1368
+ "dishwasher, dish washer, dishwashing machine": 534,
1369
+ "disk brake, disc brake": 535,
1370
+ "dock, dockage, docking facility": 536,
1371
+ "dogsled, dog sled, dog sleigh": 537,
1372
+ "dome": 538,
1373
+ "doormat, welcome mat": 539,
1374
+ "dough": 961,
1375
+ "dowitcher": 142,
1376
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
1377
+ "drake": 97,
1378
+ "drilling platform, offshore rig": 540,
1379
+ "drum, membranophone, tympan": 541,
1380
+ "drumstick": 542,
1381
+ "dugong, Dugong dugon": 149,
1382
+ "dumbbell": 543,
1383
+ "dung beetle": 305,
1384
+ "ear, spike, capitulum": 998,
1385
+ "earthstar": 995,
1386
+ "echidna, spiny anteater, anteater": 102,
1387
+ "eel": 390,
1388
+ "eft": 27,
1389
+ "eggnog": 969,
1390
+ "electric fan, blower": 545,
1391
+ "electric guitar": 546,
1392
+ "electric locomotive": 547,
1393
+ "electric ray, crampfish, numbfish, torpedo": 5,
1394
+ "entertainment center": 548,
1395
+ "envelope": 549,
1396
+ "espresso": 967,
1397
+ "espresso maker": 550,
1398
+ "face powder": 551,
1399
+ "feather boa, boa": 552,
1400
+ "fiddler crab": 120,
1401
+ "fig": 952,
1402
+ "file, file cabinet, filing cabinet": 553,
1403
+ "fire engine, fire truck": 555,
1404
+ "fire screen, fireguard": 556,
1405
+ "fireboat": 554,
1406
+ "flagpole, flagstaff": 557,
1407
+ "flamingo": 130,
1408
+ "flat-coated retriever": 205,
1409
+ "flatworm, platyhelminth": 110,
1410
+ "flute, transverse flute": 558,
1411
+ "fly": 308,
1412
+ "folding chair": 559,
1413
+ "football helmet": 560,
1414
+ "forklift": 561,
1415
+ "fountain": 562,
1416
+ "fountain pen": 563,
1417
+ "four-poster": 564,
1418
+ "fox squirrel, eastern fox squirrel, Sciurus niger": 335,
1419
+ "freight car": 565,
1420
+ "frilled lizard, Chlamydosaurus kingi": 43,
1421
+ "frying pan, frypan, skillet": 567,
1422
+ "fur coat": 568,
1423
+ "gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
1424
+ "garbage truck, dustcart": 569,
1425
+ "garden spider, Aranea diademata": 74,
1426
+ "garter snake, grass snake": 57,
1427
+ "gas pump, gasoline pump, petrol pump, island dispenser": 571,
1428
+ "gasmask, respirator, gas helmet": 570,
1429
+ "gazelle": 353,
1430
+ "geyser": 974,
1431
+ "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
1432
+ "giant schnauzer": 197,
1433
+ "gibbon, Hylobates lar": 368,
1434
+ "go-kart": 573,
1435
+ "goblet": 572,
1436
+ "golden retriever": 207,
1437
+ "goldfinch, Carduelis carduelis": 11,
1438
+ "goldfish, Carassius auratus": 1,
1439
+ "golf ball": 574,
1440
+ "golfcart, golf cart": 575,
1441
+ "gondola": 576,
1442
+ "gong, tam-tam": 577,
1443
+ "goose": 99,
1444
+ "gorilla, Gorilla gorilla": 366,
1445
+ "gown": 578,
1446
+ "grand piano, grand": 579,
1447
+ "grasshopper, hopper": 311,
1448
+ "great grey owl, great gray owl, Strix nebulosa": 24,
1449
+ "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
1450
+ "green lizard, Lacerta viridis": 46,
1451
+ "green mamba": 64,
1452
+ "green snake, grass snake": 55,
+ "greenhouse, nursery, glasshouse": 580,
+ "grey fox, gray fox, Urocyon cinereoargenteus": 280,
+ "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
+ "grille, radiator grille": 581,
+ "grocery store, grocery, food market, market": 582,
+ "groenendael": 224,
+ "groom, bridegroom": 982,
+ "ground beetle, carabid beetle": 302,
+ "guacamole": 924,
+ "guenon, guenon monkey": 370,
+ "guillotine": 583,
+ "guinea pig, Cavia cobaya": 338,
+ "gyromitra": 993,
+ "hair slide": 584,
+ "hair spray": 585,
+ "half track": 586,
+ "hammer": 587,
+ "hammerhead, hammerhead shark": 4,
+ "hamper": 588,
+ "hamster": 333,
+ "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
+ "hand-held computer, hand-held microcomputer": 590,
+ "handkerchief, hankie, hanky, hankey": 591,
+ "hard disc, hard disk, fixed disk": 592,
+ "hare": 331,
+ "harmonica, mouth organ, harp, mouth harp": 593,
+ "harp": 594,
+ "hartebeest": 351,
+ "harvester, reaper": 595,
+ "harvestman, daddy longlegs, Phalangium opilio": 70,
+ "hatchet": 596,
+ "hay": 958,
+ "head cabbage": 936,
+ "hen": 8,
+ "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
+ "hermit crab": 125,
+ "hip, rose hip, rosehip": 989,
+ "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
+ "hog, pig, grunter, squealer, Sus scrofa": 341,
+ "hognose snake, puff adder, sand viper": 54,
+ "holster": 597,
+ "home theater, home theatre": 598,
+ "honeycomb": 599,
+ "hook, claw": 600,
+ "hoopskirt, crinoline": 601,
+ "horizontal bar, high bar": 602,
+ "hornbill": 93,
+ "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
+ "horse cart, horse-cart": 603,
+ "hot pot, hotpot": 926,
+ "hotdog, hot dog, red hot": 934,
+ "hourglass": 604,
+ "house finch, linnet, Carpodacus mexicanus": 12,
+ "howler monkey, howler": 379,
+ "hummingbird": 94,
+ "hyena, hyaena": 276,
+ "iPod": 605,
+ "ibex, Capra ibex": 350,
+ "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
+ "ice cream, icecream": 928,
+ "ice lolly, lolly, lollipop, popsicle": 929,
+ "impala, Aepyceros melampus": 352,
+ "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
+ "indri, indris, Indri indri, Indri brevicaudatus": 384,
+ "iron, smoothing iron": 606,
+ "isopod": 126,
+ "jacamar": 95,
+ "jack-o'-lantern": 607,
+ "jackfruit, jak, jack": 955,
+ "jaguar, panther, Panthera onca, Felis onca": 290,
+ "jay": 17,
+ "jean, blue jean, denim": 608,
+ "jeep, landrover": 609,
+ "jellyfish": 107,
+ "jersey, T-shirt, tee shirt": 610,
+ "jigsaw puzzle": 611,
+ "jinrikisha, ricksha, rickshaw": 612,
+ "joystick": 613,
+ "junco, snowbird": 13,
+ "keeshond": 261,
+ "kelpie": 227,
+ "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
+ "kimono": 614,
+ "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
+ "king penguin, Aptenodytes patagonica": 145,
+ "king snake, kingsnake": 56,
+ "kit fox, Vulpes macrotis": 278,
+ "kite": 21,
+ "knee pad": 615,
+ "knot": 616,
+ "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
+ "komondor": 228,
+ "kuvasz": 222,
+ "lab coat, laboratory coat": 617,
+ "lacewing, lacewing fly": 318,
+ "ladle": 618,
+ "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
+ "lakeside, lakeshore": 975,
+ "lampshade, lamp shade": 619,
+ "langur": 374,
+ "laptop, laptop computer": 620,
+ "lawn mower, mower": 621,
+ "leaf beetle, chrysomelid": 304,
+ "leafhopper": 317,
+ "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
+ "lemon": 951,
+ "lens cap, lens cover": 622,
+ "leopard, Panthera pardus": 288,
+ "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
+ "letter opener, paper knife, paperknife": 623,
+ "library": 624,
+ "lifeboat": 625,
+ "lighter, light, igniter, ignitor": 626,
+ "limousine, limo": 627,
+ "limpkin, Aramus pictus": 135,
+ "liner, ocean liner": 628,
+ "lion, king of beasts, Panthera leo": 291,
+ "lionfish": 396,
+ "lipstick, lip rouge": 629,
+ "little blue heron, Egretta caerulea": 131,
+ "llama": 355,
+ "loggerhead, loggerhead turtle, Caretta caretta": 33,
+ "long-horned beetle, longicorn, longicorn beetle": 303,
+ "lorikeet": 90,
+ "lotion": 631,
+ "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
+ "loupe, jeweler's loupe": 633,
+ "lumbermill, sawmill": 634,
+ "lycaenid, lycaenid butterfly": 326,
+ "lynx, catamount": 287,
+ "macaque": 373,
+ "macaw": 88,
+ "magnetic compass": 635,
+ "magpie": 18,
+ "mailbag, postbag": 636,
+ "mailbox, letter box": 637,
+ "maillot": 638,
+ "maillot, tank suit": 639,
+ "malamute, malemute, Alaskan malamute": 249,
+ "malinois": 225,
+ "manhole cover": 640,
+ "mantis, mantid": 315,
+ "maraca": 641,
+ "marimba, xylophone": 642,
+ "marmoset": 377,
+ "marmot": 336,
+ "mashed potato": 935,
+ "mask": 643,
+ "matchstick": 644,
+ "maypole": 645,
+ "maze, labyrinth": 646,
+ "measuring cup": 647,
+ "meat loaf, meatloaf": 962,
+ "medicine chest, medicine cabinet": 648,
+ "meerkat, mierkat": 299,
+ "megalith, megalithic structure": 649,
+ "menu": 922,
+ "microphone, mike": 650,
+ "microwave, microwave oven": 651,
+ "military uniform": 652,
+ "milk can": 653,
+ "miniature pinscher": 237,
+ "miniature poodle": 266,
+ "miniature schnauzer": 196,
+ "minibus": 654,
+ "miniskirt, mini": 655,
+ "minivan": 656,
+ "mink": 357,
+ "missile": 657,
+ "mitten": 658,
+ "mixing bowl": 659,
+ "mobile home, manufactured home": 660,
+ "modem": 662,
+ "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
+ "monastery": 663,
+ "mongoose": 298,
+ "monitor": 664,
+ "moped": 665,
+ "mortar": 666,
+ "mortarboard": 667,
+ "mosque": 668,
+ "mosquito net": 669,
+ "motor scooter, scooter": 670,
+ "mountain bike, all-terrain bike, off-roader": 671,
+ "mountain tent": 672,
+ "mouse, computer mouse": 673,
+ "mousetrap": 674,
+ "moving van": 675,
+ "mud turtle": 35,
+ "mushroom": 947,
+ "muzzle": 676,
+ "nail": 677,
+ "neck brace": 678,
+ "necklace": 679,
+ "nematode, nematode worm, roundworm": 111,
+ "night snake, Hypsiglena torquata": 60,
+ "nipple": 680,
+ "notebook, notebook computer": 681,
+ "obelisk": 682,
+ "oboe, hautboy, hautbois": 683,
+ "ocarina, sweet potato": 684,
+ "odometer, hodometer, mileometer, milometer": 685,
+ "oil filter": 686,
+ "orange": 950,
+ "orangutan, orang, orangutang, Pongo pygmaeus": 365,
+ "organ, pipe organ": 687,
+ "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
+ "ostrich, Struthio camelus": 9,
+ "otter": 360,
+ "otterhound, otter hound": 175,
+ "overskirt": 689,
+ "ox": 345,
+ "oxcart": 690,
+ "oxygen mask": 691,
+ "oystercatcher, oyster catcher": 143,
+ "packet": 692,
+ "paddle, boat paddle": 693,
+ "paddlewheel, paddle wheel": 694,
+ "padlock": 695,
+ "paintbrush": 696,
+ "pajama, pyjama, pj's, jammies": 697,
+ "palace": 698,
+ "panpipe, pandean pipe, syrinx": 699,
+ "paper towel": 700,
+ "papillon": 157,
+ "parachute, chute": 701,
+ "parallel bars, bars": 702,
+ "park bench": 703,
+ "parking meter": 704,
+ "partridge": 86,
+ "passenger car, coach, carriage": 705,
+ "patas, hussar monkey, Erythrocebus patas": 371,
+ "patio, terrace": 706,
+ "pay-phone, pay-station": 707,
+ "peacock": 84,
+ "pedestal, plinth, footstall": 708,
+ "pelican": 144,
+ "pencil box, pencil case": 709,
+ "pencil sharpener": 710,
+ "perfume, essence": 711,
+ "photocopier": 713,
+ "pick, plectrum, plectron": 714,
+ "pickelhaube": 715,
+ "picket fence, paling": 716,
+ "pickup, pickup truck": 717,
+ "pier": 718,
+ "piggy bank, penny bank": 719,
+ "pill bottle": 720,
+ "pillow": 721,
+ "pineapple, ananas": 953,
+ "ping-pong ball": 722,
+ "pinwheel": 723,
+ "pirate, pirate ship": 724,
+ "pitcher, ewer": 725,
+ "pizza, pizza pie": 963,
+ "plane, carpenter's plane, woodworking plane": 726,
+ "planetarium": 727,
+ "plastic bag": 728,
+ "plate": 923,
+ "plate rack": 729,
+ "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
+ "plow, plough": 730,
+ "plunger, plumber's helper": 731,
+ "pole": 733,
+ "polecat, fitch, foulmart, foumart, Mustela putorius": 358,
+ "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
+ "pomegranate": 957,
+ "poncho": 735,
+ "pool table, billiard table, snooker table": 736,
+ "pop bottle, soda bottle": 737,
+ "porcupine, hedgehog": 334,
+ "pot, flowerpot": 738,
+ "potpie": 964,
+ "potter's wheel": 739,
+ "power drill": 740,
+ "prairie chicken, prairie grouse, prairie fowl": 83,
+ "prayer rug, prayer mat": 741,
+ "pretzel": 932,
+ "printer": 742,
+ "prison, prison house": 743,
+ "proboscis monkey, Nasalis larvatus": 376,
+ "projectile, missile": 744,
+ "projector": 745,
+ "promontory, headland, head, foreland": 976,
+ "ptarmigan": 81,
+ "puck, hockey puck": 746,
+ "puffer, pufferfish, blowfish, globefish": 397,
+ "pug, pug-dog": 254,
+ "punching bag, punch bag, punching ball, punchball": 747,
+ "purse": 748,
+ "quail": 85,
+ "quill, quill pen": 749,
+ "quilt, comforter, comfort, puff": 750,
+ "racer, race car, racing car": 751,
+ "racket, racquet": 752,
+ "radiator": 753,
+ "radio telescope, radio reflector": 755,
+ "radio, wireless": 754,
+ "rain barrel": 756,
+ "ram, tup": 348,
+ "rapeseed": 984,
+ "recreational vehicle, RV, R.V.": 757,
+ "red fox, Vulpes vulpes": 277,
+ "red wine": 966,
+ "red wolf, maned wolf, Canis rufus, Canis niger": 271,
+ "red-backed sandpiper, dunlin, Erolia alpina": 140,
+ "red-breasted merganser, Mergus serrator": 98,
+ "redbone": 168,
+ "redshank, Tringa totanus": 141,
+ "reel": 758,
+ "reflex camera": 759,
+ "refrigerator, icebox": 760,
+ "remote control, remote": 761,
+ "restaurant, eating house, eating place, eatery": 762,
+ "revolver, six-gun, six-shooter": 763,
+ "rhinoceros beetle": 306,
+ "rifle": 764,
+ "ringlet, ringlet butterfly": 322,
+ "ringneck snake, ring-necked snake, ring snake": 53,
+ "robin, American robin, Turdus migratorius": 15,
+ "rock beauty, Holocanthus tricolor": 392,
+ "rock crab, Cancer irroratus": 119,
+ "rock python, rock snake, Python sebae": 62,
+ "rocking chair, rocker": 765,
+ "rotisserie": 766,
+ "rubber eraser, rubber, pencil eraser": 767,
+ "ruddy turnstone, Arenaria interpres": 139,
+ "ruffed grouse, partridge, Bonasa umbellus": 82,
+ "rugby ball": 768,
+ "rule, ruler": 769,
+ "running shoe": 770,
+ "safe": 771,
+ "safety pin": 772,
+ "saltshaker, salt shaker": 773,
+ "sandal": 774,
+ "sandbar, sand bar": 977,
+ "sarong": 775,
+ "sax, saxophone": 776,
+ "scabbard": 777,
+ "scale, weighing machine": 778,
+ "schipperke": 223,
+ "school bus": 779,
+ "schooner": 780,
+ "scoreboard": 781,
+ "scorpion": 71,
+ "screen, CRT screen": 782,
+ "screw": 783,
+ "screwdriver": 784,
+ "scuba diver": 983,
+ "sea anemone, anemone": 108,
+ "sea cucumber, holothurian": 329,
+ "sea lion": 150,
+ "sea slug, nudibranch": 115,
+ "sea snake": 65,
+ "sea urchin": 328,
+ "seashore, coast, seacoast, sea-coast": 978,
+ "seat belt, seatbelt": 785,
+ "sewing machine": 786,
+ "shield, buckler": 787,
+ "shoe shop, shoe-shop, shoe store": 788,
+ "shoji": 789,
+ "shopping basket": 790,
+ "shopping cart": 791,
+ "shovel": 792,
+ "shower cap": 793,
+ "shower curtain": 794,
+ "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
+ "sidewinder, horned rattlesnake, Crotalus cerastes": 68,
+ "silky terrier, Sydney silky": 201,
+ "ski": 795,
+ "ski mask": 796,
+ "skunk, polecat, wood pussy": 361,
+ "sleeping bag": 797,
+ "slide rule, slipstick": 798,
+ "sliding door": 799,
+ "slot, one-armed bandit": 800,
+ "sloth bear, Melursus ursinus, Ursus ursinus": 297,
+ "slug": 114,
+ "snail": 113,
+ "snorkel": 801,
+ "snow leopard, ounce, Panthera uncia": 289,
+ "snowmobile": 802,
+ "snowplow, snowplough": 803,
+ "soap dispenser": 804,
+ "soccer ball": 805,
+ "sock": 806,
+ "soft-coated wheaten terrier": 202,
+ "solar dish, solar collector, solar furnace": 807,
+ "sombrero": 808,
+ "sorrel": 339,
+ "soup bowl": 809,
+ "space bar": 810,
+ "space heater": 811,
+ "space shuttle": 812,
+ "spaghetti squash": 940,
+ "spatula": 813,
+ "speedboat": 814,
+ "spider monkey, Ateles geoffroyi": 381,
+ "spider web, spider's web": 815,
+ "spindle": 816,
+ "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
+ "spoonbill": 129,
+ "sports car, sport car": 817,
+ "spotlight, spot": 818,
+ "spotted salamander, Ambystoma maculatum": 28,
+ "squirrel monkey, Saimiri sciureus": 382,
+ "stage": 819,
+ "standard poodle": 267,
+ "standard schnauzer": 198,
+ "starfish, sea star": 327,
+ "steam locomotive": 820,
+ "steel arch bridge": 821,
+ "steel drum": 822,
+ "stethoscope": 823,
+ "stingray": 6,
+ "stinkhorn, carrion fungus": 994,
+ "stole": 824,
+ "stone wall": 825,
+ "stopwatch, stop watch": 826,
+ "stove": 827,
+ "strainer": 828,
+ "strawberry": 949,
+ "street sign": 919,
+ "streetcar, tram, tramcar, trolley, trolley car": 829,
+ "stretcher": 830,
+ "studio couch, day bed": 831,
+ "stupa, tope": 832,
+ "sturgeon": 394,
+ "submarine, pigboat, sub, U-boat": 833,
+ "suit, suit of clothes": 834,
+ "sulphur butterfly, sulfur butterfly": 325,
+ "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
+ "sundial": 835,
+ "sunglass": 836,
+ "sunglasses, dark glasses, shades": 837,
+ "sunscreen, sunblock, sun blocker": 838,
+ "suspension bridge": 839,
+ "swab, swob, mop": 840,
+ "sweatshirt": 841,
+ "swimming trunks, bathing trunks": 842,
+ "swing": 843,
+ "switch, electric switch, electrical switch": 844,
+ "syringe": 845,
+ "tabby, tabby cat": 281,
+ "table lamp": 846,
+ "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
+ "tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
+ "tape player": 848,
+ "tarantula": 76,
+ "teapot": 849,
+ "teddy, teddy bear": 850,
+ "television, television system": 851,
+ "tench, Tinca tinca": 0,
+ "tennis ball": 852,
+ "terrapin": 36,
+ "thatch, thatched roof": 853,
+ "theater curtain, theatre curtain": 854,
+ "thimble": 855,
+ "three-toed sloth, ai, Bradypus tridactylus": 364,
+ "thresher, thrasher, threshing machine": 856,
+ "throne": 857,
+ "thunder snake, worm snake, Carphophis amoenus": 52,
+ "tick": 78,
+ "tiger beetle": 300,
+ "tiger cat": 282,
+ "tiger shark, Galeocerdo cuvieri": 3,
+ "tiger, Panthera tigris": 292,
+ "tile roof": 858,
+ "timber wolf, grey wolf, gray wolf, Canis lupus": 269,
+ "titi, titi monkey": 380,
+ "toaster": 859,
+ "tobacco shop, tobacconist shop, tobacconist": 860,
+ "toilet seat": 861,
+ "toilet tissue, toilet paper, bathroom tissue": 999,
+ "torch": 862,
+ "totem pole": 863,
+ "toucan": 96,
+ "tow truck, tow car, wrecker": 864,
+ "toy poodle": 265,
+ "toy terrier": 158,
+ "toyshop": 865,
+ "tractor": 866,
+ "traffic light, traffic signal, stoplight": 920,
+ "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
+ "tray": 868,
+ "tree frog, tree-frog": 31,
+ "trench coat": 869,
+ "triceratops": 51,
+ "tricycle, trike, velocipede": 870,
+ "trifle": 927,
+ "trilobite": 69,
+ "trimaran": 871,
+ "tripod": 872,
+ "triumphal arch": 873,
+ "trolleybus, trolley coach, trackless trolley": 874,
+ "trombone": 875,
+ "tub, vat": 876,
+ "turnstile": 877,
+ "tusker": 101,
+ "typewriter keyboard": 878,
+ "umbrella": 879,
+ "unicycle, monocycle": 880,
+ "upright, upright piano": 881,
+ "vacuum, vacuum cleaner": 882,
+ "valley, vale": 979,
+ "vase": 883,
+ "vault": 884,
+ "velvet": 885,
+ "vending machine": 886,
+ "vestment": 887,
+ "viaduct": 888,
+ "vine snake": 59,
+ "violin, fiddle": 889,
+ "vizsla, Hungarian pointer": 211,
+ "volcano": 980,
+ "volleyball": 890,
+ "vulture": 23,
+ "waffle iron": 891,
+ "walking stick, walkingstick, stick insect": 313,
+ "wall clock": 892,
+ "wallaby, brush kangaroo": 104,
+ "wallet, billfold, notecase, pocketbook": 893,
+ "wardrobe, closet, press": 894,
+ "warplane, military plane": 895,
+ "warthog": 343,
+ "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
+ "washer, automatic washer, washing machine": 897,
+ "water bottle": 898,
+ "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
+ "water jug": 899,
+ "water ouzel, dipper": 20,
+ "water snake": 58,
+ "water tower": 900,
+ "weasel": 356,
+ "web site, website, internet site, site": 916,
+ "weevil": 307,
+ "whippet": 172,
+ "whiptail, whiptail lizard": 41,
+ "whiskey jug": 901,
+ "whistle": 902,
+ "white stork, Ciconia ciconia": 127,
+ "white wolf, Arctic wolf, Canis lupus tundrarum": 270,
+ "wig": 903,
+ "wild boar, boar, Sus scrofa": 342,
+ "window screen": 904,
+ "window shade": 905,
+ "wine bottle": 907,
+ "wing": 908,
+ "wire-haired fox terrier": 188,
+ "wok": 909,
+ "wolf spider, hunting spider": 77,
+ "wombat": 106,
+ "wood rabbit, cottontail, cottontail rabbit": 330,
+ "wooden spoon": 910,
+ "wool, woolen, woollen": 911,
+ "worm fence, snake fence, snake-rail fence, Virginia fence": 912,
+ "wreck": 913,
+ "yawl": 914,
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
+ "yurt": 915,
+ "zebra": 340,
+ "zucchini, courgette": 939
+ },
+ "layer_norm_eps": 1e-12,
+ "model_type": "vit",
+ "num_attention_heads": 12,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "patch_size": 16,
+ "pooler_act": "tanh",
+ "pooler_output_size": 768,
+ "qkv_bias": true,
+ "transformers_version": "4.57.1"
+ }
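The mapping above is the `label2id` side of the config: class name to class index. At inference time you usually need the inverse direction, `id2label`, to turn an argmax index back into a name. A minimal sketch, using a hypothetical three-entry excerpt of the full 1000-class map:

```python
# Hypothetical excerpt of the label2id mapping shown above; the real config
# carries all 1000 ImageNet classes in the same format.
label2id = {
    "tench, Tinca tinca": 0,
    "zebra": 340,
    "zucchini, courgette": 939,
}

# Invert once to id2label so a predicted index can be mapped back to a name.
id2label = {v: k for k, v in label2id.items()}

predicted_index = 340  # e.g., logits.argmax(-1) for one image
print(id2label[predicted_index])  # -> zebra
```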
core/.ipynb_checkpoints/distill-checkpoint.py ADDED
@@ -0,0 +1,184 @@
+ """Knowledge-distillation utilities (model-family agnostic).
+
+ This module provides:
+ - Losses: KL distillation, soft cross-entropy, cosine feature loss
+ - Helper to obtain logits from models with/without built-in heads
+ - Lightweight classification head for backbone models (e.g., ViTModel)
+ - Simple evaluators (agreement %, KL) and diagnostics
+
+ Adapters may override `adapter_get_logits(model, x)` if a family needs a
+ custom extraction (e.g., language models with past_key_values).
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Callable, Optional, Protocol, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # -----------------------------------------------------------------------------
+ # Config
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class KDConfig:
+     temperature: float = 2.0
+     alpha: float = 1.0  # multiplier for KL term; task loss handled outside
+
+
+ # -----------------------------------------------------------------------------
+ # Losses
+ # -----------------------------------------------------------------------------
+
+ def kl_divergence(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
+     """Batchmean KL(teacher / T || student / T), scaled by T^2 (Hinton-style).
+
+     Note: F.kl_div(input, target) computes KL(target || input), so the
+     teacher distribution is the target here.
+     """
+     p_s = F.log_softmax(student_logits / T, dim=-1)
+     p_t = F.softmax(teacher_logits / T, dim=-1)
+     return F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)
+
+
+ def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, cfg: KDConfig) -> torch.Tensor:
+     return cfg.alpha * kl_divergence(student_logits, teacher_logits, T=cfg.temperature)
+
+
+ def soft_ce(student_logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
+     """Soft cross-entropy: expects `soft_targets` already normalized."""
+     logp = F.log_softmax(student_logits, dim=-1)
+     return -(soft_targets * logp).sum(dim=-1).mean()
+
+
+ def cosine_feature_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
+     """1 - cosine similarity averaged over batch and time/patch dims."""
+     s = F.normalize(student_feats, dim=-1)
+     t = F.normalize(teacher_feats, dim=-1)
+     return (1.0 - (s * t).sum(dim=-1)).mean()
+
+ def mse_reg(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
+     """MSE between raw logits, scaled by T^2 so it is comparable to the KD term."""
+     mse = F.mse_loss(student_logits, teacher_logits, reduction="mean")
+     return mse * (T * T)
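The temperature-scaled KL above can be sanity-checked without torch. The following is a dependency-free sketch of the same Hinton formula for a single logit vector (the `softmax`/`kl_hinton` names are illustrative, not part of the module):

```python
import math

def softmax(xs, T=1.0):
    # Temperature-scaled softmax over one logit vector (max-shifted for stability).
    zs = [x / T for x in xs]
    m = max(zs)
    es = [math.exp(z - m) for z in zs]
    s = sum(es)
    return [e / s for e in es]

def kl_hinton(student_logits, teacher_logits, T=2.0):
    # KL(teacher_T || student_T) * T^2, mirroring kl_divergence() above.
    p_s = softmax(student_logits, T)
    p_t = softmax(teacher_logits, T)
    return T * T * sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s))

print(round(kl_hinton([2.0, 0.5, -1.0], [2.0, 0.5, -1.0]), 6))  # identical logits -> 0.0
```

The `T * T` factor keeps the gradient magnitude of the KD term roughly independent of the temperature, which is why `mse_reg` applies the same scaling.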
+
+ # -----------------------------------------------------------------------------
+ # Logit extraction
+ # -----------------------------------------------------------------------------
+
+ class LogitsProvider(Protocol):
+     def __call__(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor: ...
+
+
+ class ClsHead(nn.Module):
+     """Minimal classification head: LN + Linear.
+
+     Useful when the backbone outputs hidden states (e.g., ViTModel) and you
+     want logits comparable to a teacher with a classification head.
+     """
+
+     def __init__(self, hidden_size: int, num_classes: int = 1000, base_head: Optional[nn.Module] = None):
+         super().__init__()
+         self.norm = nn.LayerNorm(hidden_size)
+         self.fc = nn.Linear(hidden_size, num_classes)
+         if base_head is not None:
+             # Try to load weights if shapes match (e.g., from HF classifier)
+             try:
+                 self.load_state_dict(base_head.state_dict(), strict=False)
+             except Exception:
+                 pass
+
+     def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
+         return self.fc(self.norm(cls_token))
+
+
+ @torch.no_grad()
+ def infer_hidden_size(model: nn.Module, sample: torch.Tensor) -> int:
+     # Run a tiny forward to inspect hidden size when unknown
+     model.eval()
+     out = model(pixel_values=sample)
+     if hasattr(out, "last_hidden_state"):
+         return int(out.last_hidden_state.shape[-1])
+     if hasattr(out, "logits"):
+         return int(out.logits.shape[-1])
+     raise RuntimeError("Cannot infer hidden size; provide explicitly.")
+
+
+ def default_get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
+     """Family-agnostic logits extractor.
+
+     - If model output has `.logits`, return it.
+     - Else expects `.last_hidden_state` and uses [CLS] via provided `head`.
+     """
+     out = model(pixel_values=x)
+     if hasattr(out, "logits"):
+         return out.logits
+     if hasattr(out, "last_hidden_state"):
+         if head is None:
+             raise ValueError("Backbone returned hidden states; supply a classification head.")
+         cls_tok = out.last_hidden_state[:, 0, :]
+         return head(cls_tok)
+     raise ValueError("Model output lacks logits and last_hidden_state.")
+
+
+ # -----------------------------------------------------------------------------
+ # Evaluators & diagnostics
+ # -----------------------------------------------------------------------------
+
+ @torch.inference_mode()
+ def logits_std(model: nn.Module, loader, *, get_logits: LogitsProvider, batches: int = 10, device: str = "cuda") -> Tuple[float, int]:
+     s = 0.0
+     k = 0
+     for x in loader:
+         if k >= batches:
+             break
+         x = x.to(device)
+         y = get_logits(model, x)
+         s += y.std().item()
+         k += 1
+     return (s / max(1, k), k)
+
+
+ @torch.inference_mode()
+ def agreement_metrics(
+     student: nn.Module,
+     teacher: nn.Module,
+     loader,
+     *,
+     get_student_logits: LogitsProvider,
+     get_teacher_logits: LogitsProvider,
+     batches: int = 20,
+     T: float = 1.0,
+     device: str = "cuda",
+ ) -> dict:
+     kl_sum = 0.0
+     n = 0
+     top1 = 0
+     tot = 0
+     for i, x in enumerate(loader):
+         if i >= batches:
+             break
+         x = x.to(device)
+         t = get_teacher_logits(teacher, x)
+         s = get_student_logits(student, x)
+         p_s = F.log_softmax(s / T, dim=-1)
+         p_t = F.softmax(t / T, dim=-1)
+         kl_sum += (F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)).item()
+         top1 += (s.argmax(-1) == t.argmax(-1)).sum().item()
+         tot += x.size(0)
+         n += 1
+     return {"kl_TT": kl_sum / max(1, n), "top1_agreement": top1 / max(1, tot)}
+
+
+ # -----------------------------------------------------------------------------
+ # Small trainer helpers
+ # -----------------------------------------------------------------------------
+
+ class DualEMA:
+     """Simple exponential moving average for a scalar (e.g., lambda or latency)."""
+
+     def __init__(self, beta: float = 0.9, value: float = 0.0):
+         self.beta = float(beta)
+         self.value = float(value)
+
+     def update(self, x: float) -> float:
+         self.value = self.beta * self.value + (1 - self.beta) * float(x)
+         return self.value
core/.ipynb_checkpoints/finetune-checkpoint.py ADDED
@@ -0,0 +1,267 @@
+ # core/finetune.py
+ """Post-pruning fine-tuning utilities (distillation)."""
+
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ from typing import Callable, Optional, Tuple, Iterable
+
+ import torch
+ import torch.nn as nn
+
+ from core.distill import KDConfig, kd_loss, mse_reg
+ from core.utils import ensure_trainable_parameters
+
+ import copy
+
+
+ @dataclass
+ class FinetuneConfig:
+     epochs: int = 5
+     lr: float = 3e-4
+     wd: float = 0.0
+     # default_factory avoids sharing one mutable KDConfig across instances
+     # (a plain KDConfig(...) default is rejected by dataclasses on Python 3.11+)
+     kd: KDConfig = field(default_factory=lambda: KDConfig(temperature=2.0, alpha=1.0))
+     amp: bool = True
+     # "auto" -> bf16 if supported else fp16; "bf16" | "fp16" | "off" also allowed
+     amp_dtype: str = "auto"
+     device: str = "cuda"
+     log_every: int = 200
+     # diagnostics
+     grad_check_every: int = 50
+     grad_warn_if_zero_steps: int = 2  # consecutive checks with zero grad -> warn
+     mse_weight: float = 0.0
+
+ def _autocast_and_scaler(amp: bool, amp_dtype: str) -> Tuple[torch.autocast, Optional[torch.amp.GradScaler], bool, str]:
+     """
+     Returns (autocast_ctx, scaler_or_None, use_scaler_bool, amp_mode_str)
+     - BF16 -> autocast(bfloat16), NO GradScaler
+     - FP16 -> autocast(float16), GradScaler ENABLED
+     - OFF  -> disabled autocast, NO GradScaler
+     """
+     if not amp or amp_dtype == "off":
+         ctx = torch.amp.autocast(device_type="cuda", enabled=False)
+         return ctx, None, False, "OFF"
+
+     if amp_dtype == "auto":
+         use_bf16 = torch.cuda.is_bf16_supported()
+     elif amp_dtype == "bf16":
+         use_bf16 = True
+     elif amp_dtype == "fp16":
+         use_bf16 = False
+     else:
+         raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")
+
+     if use_bf16:
+         ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
+         return ctx, None, False, "BF16"
+     else:
+         ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True)
+         try:
+             scaler = torch.amp.GradScaler("cuda", enabled=True)
+         except TypeError:
+             scaler = torch.cuda.amp.GradScaler(enabled=True)
+         return ctx, scaler, True, "FP16"
+
+
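The mode/scaler decision above can be restated without torch so it is testable on a machine with no GPU. This is an illustrative sketch that mirrors the branch structure only (`resolve_amp_mode` and `bf16_supported` are not part of the module):

```python
def resolve_amp_mode(amp, amp_dtype, bf16_supported):
    # Mirrors _autocast_and_scaler's branching; returns (mode, use_scaler).
    if not amp or amp_dtype == "off":
        return "OFF", False
    if amp_dtype == "auto":
        use_bf16 = bf16_supported
    elif amp_dtype == "bf16":
        use_bf16 = True
    elif amp_dtype == "fp16":
        use_bf16 = False
    else:
        raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")
    # BF16 has the dynamic range of FP32, so no loss scaling is needed;
    # FP16 underflows easily, hence the GradScaler.
    return ("BF16", False) if use_bf16 else ("FP16", True)

print(resolve_amp_mode(True, "auto", True))   # -> ('BF16', False)
print(resolve_amp_mode(True, "auto", False))  # -> ('FP16', True)
```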
+ def _images_from_batch(batch):
+     if isinstance(batch, dict):
+         # `or` on tensors raises "Boolean value of Tensor is ambiguous",
+         # so test key membership instead of chaining .get() with `or`.
+         for key in ("pixel_values", "input"):
+             if key in batch:
+                 return batch[key]
+         return None
+     if isinstance(batch, (tuple, list)):
+         return batch[0]
+     return batch
+
+
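A quick pure-Python check of the batch-unpacking rules (dict keys, then first tuple element, else pass-through), with tensors replaced by plain strings since only the container logic is exercised; the standalone `images_from_batch` name here is illustrative:

```python
def images_from_batch(batch):
    # Same container logic as _images_from_batch above, torch-free.
    if isinstance(batch, dict):
        for key in ("pixel_values", "input"):
            if key in batch:
                return batch[key]
        return None
    if isinstance(batch, (tuple, list)):
        return batch[0]
    return batch

print(images_from_batch({"pixel_values": "imgs"}))  # -> imgs
print(images_from_batch(("imgs", "labels")))        # -> imgs
print(images_from_batch("imgs"))                    # -> imgs
```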
+ def _param_iter_trainable(model: nn.Module) -> Iterable[torch.nn.Parameter]:
+     for p in model.parameters():
+         if p.requires_grad:
+             yield p
+
+
+ def _grad_norm_and_nonzero(params: Iterable[torch.nn.Parameter]) -> Tuple[float, int]:
+     total_sq, nonzero = 0.0, 0
+     for p in params:
+         g = p.grad
+         if g is None:
+             continue
+         if g.is_sparse:
+             g = g.coalesce().values()
+         gn = float(g.detach().norm().cpu())
+         if gn > 0.0:
+             nonzero += 1
+         total_sq += gn * gn
+     return (total_sq ** 0.5), nonzero
+
+
+ @torch.no_grad()
+ def recalibrate_bn_stats(model, loader, max_batches=200, device="cuda"):
+     model.train()  # use training mode to update running stats
+     seen = 0
+     for i, batch in enumerate(loader):
+         if i >= max_batches:
+             break
+         x = batch[0] if isinstance(batch, (tuple, list)) else batch
+         if not torch.is_tensor(x):
+             continue
+         x = x.to(device, non_blocking=True)
+         model(x)
+         seen += x.size(0)
+     return seen
+
+
+ def finetune_student(
+     student: nn.Module,
+     teacher: nn.Module,
+     train_loader,
+     *,
+     get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+     get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+     cfg: FinetuneConfig = FinetuneConfig(),
+     val_loader=None,
+     on_step: Optional[Callable[[int, float], None]] = None,
+     save_best=False,
+ ) -> nn.Module:
+     """Fine-tune a pruned student against a frozen teacher using KD."""
+     dev = cfg.device
+     student = student.to(dev)
+     teacher = teacher.to(dev).eval()
+     for p in teacher.parameters():
+         p.requires_grad_(False)
+     for p in student.parameters():
+         p.requires_grad_(True)
+
+     # Make sure we can actually train
+     ensure_trainable_parameters(student, requires_grad=True)
+     trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
+     if trainable == 0:
+         raise RuntimeError("No trainable parameters in student — cannot finetune.")
+
+     opt = torch.optim.AdamW(
+         _param_iter_trainable(student),
+         lr=cfg.lr,
+         weight_decay=cfg.wd,
+     )
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * len(train_loader), eta_min=3e-5)
+
+
+     autocast_ctx, scaler, use_scaler, amp_mode = _autocast_and_scaler(cfg.amp, cfg.amp_dtype)
+     print(f"[AMP] Mode={amp_mode} | GradScaler={'ON' if use_scaler else 'OFF'} | "
+           f"KD: T={cfg.kd.temperature} alpha={cfg.kd.alpha} | LR={cfg.lr} WD={cfg.wd} | Trainable params={trainable:,}")
+
+     zero_grad_streak = 0
+     global_step = 0
+
+     T_max = cfg.kd.temperature
+     T_min = 2.0
+     # Copy so the temperature schedule below does not mutate the caller's config
+     kd_conf = copy.copy(cfg.kd)
+
+     best_state = None
+     best_val = float("inf")
+
+     for ep in range(cfg.epochs):
+         student.train()
+         running, seen = 0.0, 0
+
+         for i, batch in enumerate(train_loader):
+
+             step = ep * len(train_loader) + i  # global step for T scheduling
+             max_steps = cfg.epochs * len(train_loader)
+             kd_conf.temperature = T_max - (step / max_steps) * (T_max - T_min)
+
+             # print(f"Step {step}/{max_steps}, T_min={T_min}, T={kd_conf.temperature}, T_max={T_max}")
+
+             x = _images_from_batch(batch)
+             if not torch.is_tensor(x):
+                 raise ValueError("Train loader must yield tensors or (tensor, target) tuples.")
+             x = x.to(dev, non_blocking=True)
+
+             with torch.no_grad():
+                 t = get_teacher_logits(teacher, x)
+                 # Force numerically stable dtype for the loss
+                 t = t.float()
+
+             # ---- forward student under autocast
+             with autocast_ctx:
+                 s = get_student_logits(student, x)
+
+             # ---- compute KD loss in FP32 (outside autocast) for stability
+             s32 = s.float()
+             mse = cfg.mse_weight * mse_reg(s32, t, kd_conf.temperature)
+             loss = kd_loss(s32, t, kd_conf) + mse
+
+             opt.zero_grad(set_to_none=True)
+             if use_scaler:
+                 scaler.scale(loss).backward()
+                 scaler.step(opt)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 opt.step()
+
+             # ---- diagnostics
+             bs = x.size(0)
+             running += float(loss.detach()) * bs
+             seen += bs
+             global_step += 1
+
+             if cfg.grad_check_every and (global_step % cfg.grad_check_every == 0):
+                 gnorm, n_nonzero = _grad_norm_and_nonzero(_param_iter_trainable(student))
+                 if n_nonzero == 0 or gnorm == 0.0:
+                     zero_grad_streak += 1
+                     if zero_grad_streak >= cfg.grad_warn_if_zero_steps:
+                         print(f"[WARN] Step {global_step}: zero gradients detected "
+                               f"(nonzero={n_nonzero}, grad_norm={gnorm:.3e}). "
+                               f"Check get_student_logits, requires_grad, AMP settings, and data pipeline.")
+                 else:
+                     zero_grad_streak = 0
+
+             if cfg.log_every and (i + 1) % cfg.log_every == 0:
+                 print(f"Step {i+1}/{len(train_loader)} (ep {ep+1}/{cfg.epochs}): "
+                       f"running loss = {running / max(1, seen):.4f}")
+
+             if on_step is not None:
+                 on_step(global_step, float(loss.detach()))
+
+             # free ASAP
+             del s, s32, t, loss
+
+         # ---- validation
+         if val_loader is not None:
+             _ = recalibrate_bn_stats(student, train_loader, max_batches=1000, device=cfg.device)
+             student.eval()
+             val_loss, vseen = 0.0, 0
229
+ with torch.no_grad():
230
+ for vbatch in val_loader:
231
+ vx = _images_from_batch(vbatch)
232
+ if not torch.is_tensor(vx):
233
+ raise ValueError("Val loader must yield tensors or (tensor, target) tuples.")
234
+ vx = vx.to(dev, non_blocking=True)
235
+
236
+ vt = get_teacher_logits(teacher, vx).float()
237
+ with autocast_ctx:
238
+ vs = get_student_logits(student, vx)
239
+
240
+ vs32 = vs.float()
241
+ vmse = cfg.mse_weight*mse_reg(vs32, vt, kd_conf.temperature)
242
+ vloss = kd_loss(vs32, vt, kd_conf) + vmse
243
+ val_loss += float(vloss.detach()) * vx.size(0)
244
+ vseen += vx.size(0)
245
+
246
+ mean_val = val_loss / max(1, vseen)
247
+ print("\n------------------------------------------------")
248
+ print(f"Epoch {ep+1}/{cfg.epochs}: T={kd_conf.temperature:.2f}, train={running / max(1, seen):.6f}, "
249
+ f"val={mean_val:.6f}")
250
+
251
+ if save_best and (mean_val < best_val):
252
+ best_val = mean_val
253
+ best_state = copy.deepcopy(student.state_dict())
254
+
255
+ print("------------------------------------------------\n")
256
+
257
+ else:
258
+ print(f"Epoch {ep+1}/{cfg.epochs}: train={running / max(1, seen):.6f}")
259
+
260
+ scheduler.step()
261
+
262
+ if save_best and val_loader is not None and best_state is not None:
263
+ student.load_state_dict(best_state)
264
+
265
+ student.eval()
266
+ return student
267
+
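The KD temperature schedule in the training loop above is a plain linear ramp from the configured starting temperature down to `T_min = 2.0` over all steps. A minimal standalone sketch (the function name is illustrative, not part of the module):

```python
def kd_temperature(step: int, max_steps: int, t_max: float = 4.0, t_min: float = 2.0) -> float:
    """Linearly anneal the distillation temperature from t_max at step 0 to t_min at max_steps."""
    frac = step / max_steps
    return t_max - frac * (t_max - t_min)
```

A higher early temperature softens the teacher distribution while the student is far off; the ramp sharpens targets as training converges.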
core/.ipynb_checkpoints/profiler-checkpoint.py ADDED
@@ -0,0 +1,236 @@
"""Simple, robust latency measurement utilities.

This module provides GPU-friendly profilers with warmup, multiple repeats,
median/percentile reporting, and optional outlier rejection via MAD.

Design goals:
- Family-agnostic: take a callable `forward(model, x)` or rely on HF `.forward`
- Deterministic when desired; avoids autograd by default
- Works with CUDA or CPU; uses `torch.cuda.Event` for accurate GPU timing

Key APIs:
- measure_latency_ms(model, input_shape | input_tensor, ...)
- profile(model, sample, settings) -> {mean, p50, p90, p95, p99}
- LatencyProfiler(settings).measure(...)
- profile_many_shapes(model, shapes, settings)
"""
from __future__ import annotations

from dataclasses import dataclass
from statistics import median
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple

import contextlib
import math
import time

import torch
import torch.nn as nn


# -----------------------------------------------------------------------------
# Settings
# -----------------------------------------------------------------------------

@dataclass
class ProfileSettings:
    warmup: int = 10
    iters: int = 50
    percentile: Sequence[int] = (50, 90, 95, 99)
    sync_each_iter: bool = True
    use_inference_mode: bool = True
    cuda_graph: bool = False  # advanced users can enable with static shapes
    reject_outliers_mad: float = 0.0  # e.g., 3.5 to drop extreme spikes
    cudnn_benchmark: bool = True
    deterministic: bool = False  # sets cudnn.deterministic


# -----------------------------------------------------------------------------
# Context helpers
# -----------------------------------------------------------------------------

@contextlib.contextmanager
def _torch_backend_ctx(settings: ProfileSettings):
    prev_bench = torch.backends.cudnn.benchmark
    prev_det = torch.backends.cudnn.deterministic
    try:
        torch.backends.cudnn.benchmark = bool(settings.cudnn_benchmark)
        torch.backends.cudnn.deterministic = bool(settings.deterministic)
        yield
    finally:
        torch.backends.cudnn.benchmark = prev_bench
        torch.backends.cudnn.deterministic = prev_det


def _percentiles(sorted_vals: Sequence[float], qs: Sequence[int]) -> Dict[int, float]:
    n = len(sorted_vals)
    if n == 0:
        return {q: float("nan") for q in qs}
    out = {}
    for q in qs:
        if n == 1:
            out[q] = sorted_vals[0]
            continue
        k = (q / 100.0) * (n - 1)
        f = math.floor(k)
        c = min(n - 1, f + 1)
        if f == c:
            out[q] = sorted_vals[int(k)]
        else:
            d0 = sorted_vals[f] * (c - k)
            d1 = sorted_vals[c] * (k - f)
            out[q] = d0 + d1
    return out


def _apply_mad_filter(vals: Sequence[float], thresh: float) -> Sequence[float]:
    if thresh <= 0 or len(vals) < 5:
        return vals
    med = median(vals)
    dev = [abs(v - med) for v in vals]
    mad = median(dev) or 1e-12
    keep = [v for v, d in zip(vals, dev) if (d / mad) <= thresh]
    return keep if keep else vals


# -----------------------------------------------------------------------------
# Core measurement
# -----------------------------------------------------------------------------

@torch.inference_mode()
def measure_latency_ms(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Tuple[float, float]:
    """Return (mean_ms, p95_ms) over `iters` measurements.

    If `sample` is a shape tuple, a random tensor is created on-device.
    The default forward calls `model(pixel_values=x)` if available, else `model(x)`.
    """
    cfg = settings or ProfileSettings()

    with _torch_backend_ctx(cfg):
        m = model.to(device).eval()
        if isinstance(sample, torch.Tensor):
            x = sample.to(device)
        else:
            x = torch.randn(*sample, device=device)

        # Default forward: try the HF keyword signature first, fall back to positional.
        def _fwd(mod, inp):
            try:
                return mod(pixel_values=inp)
            except TypeError:
                return mod(inp)

        fn = forward_fn or _fwd

        # Warmup
        on_cuda = torch.cuda.is_available() and device.startswith("cuda")
        for _ in range(cfg.warmup):
            _ = fn(m, x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        times: list[float] = []
        if on_cuda:
            for _ in range(cfg.iters):
                t0 = torch.cuda.Event(enable_timing=True)
                t1 = torch.cuda.Event(enable_timing=True)
                t0.record()
                _ = fn(m, x)
                t1.record()
                if cfg.sync_each_iter:
                    torch.cuda.synchronize()
                times.append(t0.elapsed_time(t1))  # milliseconds
        else:
            for _ in range(cfg.iters):
                t0 = time.perf_counter()
                _ = fn(m, x)
                if cfg.sync_each_iter and torch.cuda.is_available():
                    torch.cuda.synchronize()
                t1 = time.perf_counter()
                times.append((t1 - t0) * 1000.0)

        times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
        mean_ms = sum(times) / max(1, len(times))
        p = _percentiles(times, cfg.percentile)
        p95 = p.get(95, times[int(0.95 * (len(times) - 1))] if times else float("nan"))
        return mean_ms, p95


# Higher level wrapper returning multiple percentiles
@torch.inference_mode()
def profile(
    model: nn.Module,
    sample: torch.Tensor | Tuple[int, ...],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[str, float]:
    cfg = settings or ProfileSettings()
    # Warm up (and sanity-check settings) via the core routine, then re-measure
    # on the same settings for the full percentile set.
    _ = measure_latency_ms(model, sample, settings=cfg, device=device, forward_fn=forward_fn)
    m = model.to(device).eval()
    if isinstance(sample, torch.Tensor):
        x = sample.to(device)
    else:
        x = torch.randn(*sample, device=device)

    # Same default forward as measure_latency_ms (keyword first, positional fallback).
    def _fwd(mod, inp):
        try:
            return mod(pixel_values=inp)
        except TypeError:
            return mod(inp)

    fn = forward_fn or _fwd

    times: list[float] = []
    if torch.cuda.is_available() and device.startswith("cuda"):
        for _ in range(cfg.iters):
            t0 = torch.cuda.Event(enable_timing=True)
            t1 = torch.cuda.Event(enable_timing=True)
            t0.record()
            _ = fn(m, x)
            t1.record()
            if cfg.sync_each_iter:
                torch.cuda.synchronize()
            times.append(t0.elapsed_time(t1))
    else:
        for _ in range(cfg.iters):
            t0 = time.perf_counter()
            _ = fn(m, x)
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)

    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
    percs = _percentiles(times, cfg.percentile)
    out = {"mean": sum(times) / max(1, len(times))}
    out.update({f"p{q}": v for q, v in percs.items()})
    return out


class LatencyProfiler:
    """Reusable profiler with fixed settings."""

    def __init__(self, settings: Optional[ProfileSettings] = None, device: str = "cuda"):
        self.settings = settings or ProfileSettings()
        self.device = device

    def measure(self, model: nn.Module, sample: torch.Tensor | Tuple[int, ...], *, forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None) -> Tuple[float, float]:
        return measure_latency_ms(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)

    def profile(self, model: nn.Module, sample: torch.Tensor | Tuple[int, ...], *, forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None) -> Dict[str, float]:
        return profile(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)


@torch.inference_mode()
def profile_many_shapes(
    model: nn.Module,
    shapes: Iterable[Tuple[int, ...]],
    *,
    settings: Optional[ProfileSettings] = None,
    device: str = "cuda",
    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
) -> Dict[Tuple[int, ...], Dict[str, float]]:
    out: Dict[Tuple[int, ...], Dict[str, float]] = {}
    for shp in shapes:
        out[tuple(shp)] = profile(model, shp, settings=settings, device=device, forward_fn=forward_fn)
    return out
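The linear-interpolation percentile rule used by `_percentiles` above can be exercised on its own. A dependency-free restatement (same logic, no torch required):

```python
import math

def percentiles(sorted_vals, qs):
    """Linear-interpolation percentiles over an ascending list, as in _percentiles."""
    n = len(sorted_vals)
    if n == 0:
        return {q: float("nan") for q in qs}
    out = {}
    for q in qs:
        if n == 1:
            out[q] = sorted_vals[0]
            continue
        k = (q / 100.0) * (n - 1)       # fractional rank
        f = math.floor(k)               # lower neighbor index
        c = min(n - 1, f + 1)           # upper neighbor index
        if f == c:
            out[q] = sorted_vals[int(k)]
        else:
            # weight each neighbor by its distance to the fractional rank
            out[q] = sorted_vals[f] * (c - k) + sorted_vals[c] * (k - f)
    return out
```

For five samples `[10, 20, 30, 40, 50]`, p50 lands exactly on the middle sample (30.0) and p90 interpolates between the last two samples (46.0).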
core/.ipynb_checkpoints/proxy_cost-checkpoint.py ADDED
@@ -0,0 +1,771 @@
# core/proxy_cost.py
"""Latency proxy models and a tiny LUT for hardware correction.

This file defines a family-agnostic interface plus concrete proxies (ViT, ResNet, LLM)
that estimate latency from *soft structure* (gates) and input size. All proxies accept
the trainer's `(model, batch) -> ms` call signature directly (batches may be dict/tuple/tensor).
A small, in-memory LUT can be populated from real measurements during training to correct
analytic estimates.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union, List

import torch
import torch.nn as nn

from .gates import iter_gates, _as_like  # _as_like is used by ViT proxy


# -----------------------------------------------------------------------------
# Small batch helpers (shared)
# -----------------------------------------------------------------------------

TensorOrBatch = Union[torch.Tensor, Tuple, List, Dict[str, Any]]

def _first_tensor(batch: TensorOrBatch) -> torch.Tensor:
    """Find the first tensor inside a batch-like structure."""
    if torch.is_tensor(batch):
        return batch
    if isinstance(batch, dict):
        # Common keys across tasks
        for k in ("input_ids", "pixel_values", "images", "x"):
            v = batch.get(k, None)
            if torch.is_tensor(v):
                return v
        # fallback: first tensor value
        for v in batch.values():
            if torch.is_tensor(v):
                return v
        raise ValueError("Batch dict has no tensor field I recognize.")
    if isinstance(batch, (list, tuple)):
        for v in batch:
            if torch.is_tensor(v):
                return v
        # torchvision pattern: ([aug1, aug2], label)
        if len(batch) and isinstance(batch[0], (list, tuple)):
            for v in batch[0]:
                if torch.is_tensor(v):
                    return v
    raise ValueError("Cannot find a tensor in the provided batch.")

def _ids_from_batch(batch: TensorOrBatch) -> torch.Tensor:
    """Return a 2D [B,S] tensor representing token ids for LLMs."""
    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
        return batch["input_ids"]
    t = _first_tensor(batch)
    if t.dim() >= 2:
        return t
    raise ValueError("Cannot infer [B,S] from batch; need 'input_ids' or a 2D tensor.")

def _nchw_from_batch(batch: TensorOrBatch) -> Tuple[int, int, int, int]:
    """Return NCHW shape from a batch or an explicit (N,C,H,W) tuple/list/tensor."""
    if isinstance(batch, (tuple, list)) and len(batch) == 4 and all(isinstance(x, int) for x in batch):
        return tuple(batch)  # type: ignore[return-value]
    x = _first_tensor(batch)
    if x.dim() != 4:
        raise ValueError(f"Expected NCHW tensor for CNN proxy; got tensor with shape {tuple(x.shape)}")
    N, C, H, W = map(int, x.shape)
    return (N, C, H, W)


# -----------------------------------------------------------------------------
# Base proxy + LUT
# -----------------------------------------------------------------------------

class LatencyProxy(nn.Module):
    """Abstract proxy producing a scalar latency-like value (ms).

    Subclasses implement `_predict_raw` and may define `_signature` keys used by
    a LUT to refine estimates with real measurements. Proxies accept either a
    batch-like object (dict/tuple/tensor) or an explicit shape tuple.
    """

    def __init__(self):
        super().__init__()

    def predict(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Batch-friendly entry point. `sample` may be a batch or explicit shape."""
        return self._predict_raw(model, sample, policy=policy, step=step, **kwargs)

    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:  # pragma: no cover - abstract
        raise NotImplementedError

    def signature(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
    ) -> Tuple:
        """Return a hashable signature describing the workload shape."""
        if torch.is_tensor(sample):
            shp = tuple(sample.shape)
        elif isinstance(sample, (tuple, list)):
            shp = tuple(sample)
        elif isinstance(sample, dict):
            # summarize the shapes of any tensors in dict
            shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
        else:
            shp = (str(type(sample)),)
        return (type(self).__name__, shp)


class LatencyLUT:
    """Tiny LUT mapping `(signature) -> measured_ms`."""

    def __init__(self):
        self._table: Dict[Tuple[Any, ...], float] = {}

    def update(self, signature: Tuple[Any, ...], measured_ms: float) -> None:
        self._table[signature] = float(measured_ms)

    def get(self, signature: Tuple[Any, ...]) -> Optional[float]:
        return self._table.get(signature)

    def blend(self, raw_estimate: torch.Tensor, signature: Tuple[Any, ...]) -> torch.Tensor:
        val = self.get(signature)
        if val is None:
            return raw_estimate
        # Put on same device/dtype as raw_estimate
        return _as_like(raw_estimate, val)


# -----------------------------------------------------------------------------
# ViT proxy (analytic + gates), with scale and per-term weights
# -----------------------------------------------------------------------------

@dataclass
class ViTProxyConfig:
    scale_ms: float = 1.0
    alpha_qkv: float = 1.0
    alpha_scores: float = 1.0
    alpha_out: float = 1.0
    alpha_mlp: float = 1.0


def _vit_layers(m):
    enc = getattr(m, "encoder", None)
    if enc is not None and hasattr(enc, "layer"):
        return enc.layer
    vit = getattr(m, "vit", None)
    if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
        return vit.encoder.layer
    raise TypeError("Expected a HF ViT with *.encoder.layer (ViTModel or ViTForImageClassification).")


class ViTLatencyProxy(LatencyProxy):
    """Latency proxy for ViT models. Accepts batches or (N,C,H,W) tuples."""

    def __init__(self, cfg: Optional[ViTProxyConfig] = None, lut: Optional[LatencyLUT] = None):
        super().__init__()
        self.cfg = cfg or ViTProxyConfig()
        self.lut = lut or LatencyLUT()

    # ---- helpers -------------------------------------------------------------
    @staticmethod
    def _input_spec(sample: TensorOrBatch) -> Tuple[int, int, int]:
        if isinstance(sample, (tuple, list)) and len(sample) == 4 and all(isinstance(x, int) for x in sample):
            B, C, H, W = sample
            return int(B), int(H), int(W)
        x = _first_tensor(sample)
        if x.dim() != 4:
            raise ValueError("ViTLatencyProxy expects a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
        B, C, H, W = x.shape
        return int(B), int(H), int(W)

    @staticmethod
    def _patch_hw(cfg) -> Tuple[int, int]:
        patch = getattr(cfg, "patch_size", 16)
        if isinstance(patch, (tuple, list)):
            return int(patch[0]), int(patch[1])
        return int(patch), int(patch)

    @staticmethod
    def _soft_heads_from_block(blk) -> Optional[torch.Tensor]:
        # Prefer a nested attention with kept_heads_soft()
        attn = getattr(getattr(blk, "attention", None), "attention", None)
        if attn is not None and hasattr(attn, "kept_heads_soft"):
            return attn.kept_heads_soft()
        return None

    @staticmethod
    def _find_ffn_gate(blk):
        inter = getattr(blk, "intermediate", None)
        if inter is None:
            return None
        # Common attribute names
        for nm in ("neuron_gate", "gate", "ffn_gate"):
            g = getattr(inter, nm, None)
            if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
                return g
        # Last resort: scan children
        for m in blk.modules():
            if hasattr(m, "logits") and hasattr(m, "tau"):
                return m
        return None

    # ---- proxy ---------------------------------------------------------------
    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
    ) -> torch.Tensor:
        anchor = next((p for p in model.parameters()), torch.tensor(0.0))

        B, H_img, W_img = self._input_spec(sample)
        cfg = getattr(model, "config", None)
        if cfg is None:
            raise ValueError("Model must expose a HuggingFace-like .config for ViT proxy")
        ph, pw = self._patch_hw(cfg)

        S = _as_like(anchor, 1 + (H_img // ph) * (W_img // pw))
        D = _as_like(anchor, int(getattr(cfg, "hidden_size", 768)))
        Hh = _as_like(anchor, int(getattr(cfg, "num_attention_heads", 12)))
        Dh = D // Hh

        warm = False
        if policy is not None and step is not None:
            warm = (step < int(getattr(policy, "warmup_steps", 0)))

        total_qkv = _as_like(anchor, 0.0)
        total_scores = _as_like(anchor, 0.0)
        total_out = _as_like(anchor, 0.0)
        total_mlp = _as_like(anchor, 0.0)

        default_hidden = _as_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        layers = _vit_layers(model)
        for blk in layers:
            heads_soft = Hh if warm else (self._soft_heads_from_block(blk) or Hh)

            # FFN hidden expectation
            if warm:
                hidden_soft = default_hidden
            else:
                g = self._find_ffn_gate(blk)
                if g is None:
                    hidden_soft = default_hidden
                else:
                    probs = torch.sigmoid(g.logits / g.tau)
                    group = int(getattr(g, "group", getattr(g, "group_size", 16)))
                    hidden_soft = probs.sum() * _as_like(anchor, group)

            D_kept = heads_soft * Dh

            total_qkv += 3 * S * D * D_kept
            total_scores += (S * S) * heads_soft * Dh
            total_out += S * D_kept * D
            total_mlp += 2 * S * D * hidden_soft

        raw = (
            self.cfg.alpha_qkv * total_qkv
            + self.cfg.alpha_scores * total_scores
            + self.cfg.alpha_out * total_out
            + self.cfg.alpha_mlp * total_mlp
        )
        raw_ms = raw * _as_like(anchor, float(self.cfg.scale_ms))

        # optional LUT correction
        sig = self.signature(model, sample, policy=policy, step=step)
        return self.lut.blend(raw_ms, sig)

    # A reasonable default signature for ViT workloads
    def signature(self, model: nn.Module, sample, *, policy=None, step: Optional[int] = None) -> Tuple:
        if torch.is_tensor(sample):
            shp = tuple(sample.shape)
        elif isinstance(sample, (tuple, list)):
            shp = tuple(sample)
        elif isinstance(sample, dict):
            shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
        else:
            shp = (str(type(sample)),)
        cfg = getattr(model, "config", None)
        heads = int(getattr(cfg, "num_attention_heads", 12))
        hidden = int(getattr(cfg, "hidden_size", 768))
        inter = int(getattr(cfg, "intermediate_size", 3072))
        return ("ViT", shp, heads, hidden, inter)

    @torch.no_grad()
    def calibrate(self, model: nn.Module, shape: tuple, measure_fn, *, device: str = "cuda") -> float:
        """Set proxy scale so that keep-all student matches measured ms.

        `measure_fn(model, shape_or_tensor)` should return `(mean_ms, p95_ms)`.
        """
        sample_t = torch.randn(shape, device=device)
        model = model.to(device).eval()
        mean_ms, _ = measure_fn(model, shape, device=device)
        soft_ms = self.predict(model, sample_t).item()
        self.cfg.scale_ms = float(mean_ms / max(soft_ms, 1e-9))
        return self.cfg.scale_ms


# ------------------------------ ResNet Proxy ------------------------------

@dataclass
class ResNetProxyConfig:
    scale_ms: float = 1.0
    alpha_conv: float = 1.0  # weight for conv FLOPs term


def _as_const_like_resnet(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


def _find_anchor_param(model: nn.Module) -> torch.Tensor:
    # Prefer any gate-like parameter; otherwise any parameter; else cpu scalar
    for m in model.modules():
        for nm in ("logits", "head_gate"):
            t = getattr(m, nm, None)
            if isinstance(t, torch.Tensor):
                return t
    for p in model.parameters():
        return p
    return torch.tensor(0.0)


def _kept_from_gate(module, anchor: torch.Tensor) -> Optional[torch.Tensor]:
    """Return expected kept channels for a BN gate: probs.sum() * group_size.
    If no gate is found, return None.
    """
    g = None
    for nm in ("gate", "neuron_gate", "channel_gate", "bn_gate"):
        if hasattr(module, nm):
            g = getattr(module, nm)
            break
    if g is None and hasattr(module, "logits") and hasattr(module, "tau"):
        g = module

    if g is None or not hasattr(g, "logits"):
        return None
    logits = g.logits
    tau = float(getattr(g, "tau", 1.5))
    group = int(getattr(g, "group", getattr(g, "group_size", 1)))
    if group <= 0:
        group = 1
    probs = torch.sigmoid(logits / tau)
    return probs.sum() * _as_const_like_resnet(anchor, group)


class ResNetLatencyProxy(LatencyProxy):
    """Latency proxy for ResNet-like backbones with BN gates.

    Approximates latency with a FLOPs-style sum over convs, using the *expected*
    kept channels after each BN gate (probs.sum()*group_size). Falls back to the
    full channel count when a gate is not found.

    Accepts a batch or an explicit (N,C,H,W) shape.
    """

    def __init__(self, cfg: Optional[ResNetProxyConfig] = None):
        super().__init__()
        self.cfg = cfg or ResNetProxyConfig()

    def _add_cost(self, cost_like: torch.Tensor, oc, ic, k, stride, H, W):
        alpha = _as_const_like_resnet(cost_like, self.cfg.alpha_conv)
        # update spatial dims with conv stride (roughly, ignoring padding effects)
        H = (H + stride - 1) // stride
        W = (W + stride - 1) // stride
        flops = (
            _as_const_like_resnet(cost_like, oc)
            * _as_const_like_resnet(cost_like, ic)
            * (k * k)
            * _as_const_like_resnet(cost_like, H)
            * _as_const_like_resnet(cost_like, W)
        )
        return cost_like + alpha * flops, H, W

    def _predict_raw(self, model: nn.Module, sample: TensorOrBatch, **_) -> torch.Tensor:
        N, C_in, H0, W0 = _nchw_from_batch(sample)
        anchor = _find_anchor_param(model)
        cost = _as_const_like_resnet(anchor, 0.0)
        H = _as_const_like_resnet(anchor, int(H0))
        W = _as_const_like_resnet(anchor, int(W0))

        # Stem
        conv1 = model.conv1
        bn1 = getattr(model, "bn1", None)
        k = conv1.kernel_size[0]
        s = conv1.stride[0]
        kept_out = None
        if bn1 is not None:
            kept = _kept_from_gate(bn1, anchor)
            if kept is not None:
                kept_out = kept
        oc_eff = kept_out if kept_out is not None else _as_const_like_resnet(anchor, conv1.out_channels)
        cost, H, W = self._add_cost(cost, oc_eff, _as_const_like_resnet(anchor, C_in), k, s, H, W)
        in_ch = oc_eff

        def _block_cost(block, in_ch, H, W, cost):
            # conv1 -> bn1 (explicit None checks: `or` on a tensor is fragile and
            # would discard a legitimately near-zero expected width)
            c1 = block.conv1
            b1 = block.bn1 if hasattr(block, "bn1") else None
            k1, s1 = c1.kernel_size[0], c1.stride[0]
            kept1 = _kept_from_gate(b1, anchor) if b1 is not None else None
            oc1_eff = kept1 if kept1 is not None else _as_const_like_resnet(anchor, c1.out_channels)
            cost, H, W = self._add_cost(cost, oc1_eff, in_ch, k1, s1, H, W)

            # conv2 -> bn2
            c2 = block.conv2
            b2 = block.bn2 if hasattr(block, "bn2") else None
            k2, s2 = c2.kernel_size[0], c2.stride[0]
            kept2 = _kept_from_gate(b2, anchor) if b2 is not None else None
            oc2_eff = kept2 if kept2 is not None else _as_const_like_resnet(anchor, c2.out_channels)
            cost, H, W = self._add_cost(cost, oc2_eff, oc1_eff, k2, s2, H, W)

            return oc2_eff, H, W, cost

        # Layers
        for lname in ("layer1", "layer2", "layer3", "layer4"):
            layer = getattr(model, lname, None)
            if layer is None:
                continue
            for blk in layer:
                in_ch, H, W, cost = _block_cost(blk, in_ch, H, W, cost)

        scale = _as_const_like_resnet(anchor, self.cfg.scale_ms)
        return cost * scale

    @torch.no_grad()
    def calibrate(self, model: nn.Module, keepall_export_fn, profiler_fn, sample: TensorOrBatch, device: str = "cuda") -> float:
        """Calibrate `scale_ms` so proxy(model_keepall) ~= real latency in ms."""
        keep = keepall_export_fn(model)
        sample_shape = _nchw_from_batch(sample)
        mean_ms, _ = profiler_fn(keep, sample_shape, device=device)
        soft = float(self.predict(model, sample).detach().cpu())
        self.cfg.scale_ms = mean_ms / max(soft, 1e-9)
        return mean_ms


# -----------------------------------------------------------------------------
# LLM proxy
# -----------------------------------------------------------------------------

"""
LatencyProxyLLM
---------------
A lightweight latency proxy for decoder-only HF LLMs (LLaMA/Mistral style).

- Estimates end-to-end latency (ms-like scalar) for a given (B, S, T):
  * Prefill on S tokens (build KV cache)
  * Cached decode for T steps
- Uses soft gate expectations:
  * Attention heads (HeadGate on GatedSelfAttentionLLM)
  * FFN hidden (SwiGLUWidthGate via .mlp.neuron_gate)
- Calibrate .scale_ms so proxy ≈ real latency of a keep-all model.

Public API
----------
- LatencyProxyLLM(...).predict(model, batch_or_shape)   # trainer entry
- LatencyProxyLLM(...).predict(model, B=?, S=?, T=?)    # explicit entry
- LatencyProxyLLM(...).debug_layer_view(...)
- calibrate_proxy_llm(...), calibrate_proxy_llm_from_batch(...)
"""

# ------------------------------------------------------------
# Shared tiny utils (device/dtype-safe constants)
# ------------------------------------------------------------
def _find_gate_param_or_fallback(model: nn.Module) -> torch.Tensor:
    """
    Return a tensor to anchor device/dtype for proxy constants.
    Prefer gate logits; else any parameter; else CPU fp32 scalar.
    """
    for m in model.modules():
        if hasattr(m, "head_gate") and hasattr(getattr(m, "head_gate"), "logits"):
            return m.head_gate.logits
        if hasattr(m, "neuron_gate") and hasattr(m.neuron_gate, "logits"):
            return m.neuron_gate.logits
        if hasattr(m, "logits") and isinstance(getattr(m, "logits"), torch.Tensor):
            return m.logits
    for p in model.parameters():
        return p
    return torch.tensor(0.0)

def _as_const_like(x_like: torch.Tensor, val):
    return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)


# ------------------------------------------------------------
# Proxy
# ------------------------------------------------------------
@dataclass
class _WarmupOnlyPolicy:
    """Tiny policy shim so you can pass warmup_steps to .predict()."""
    warmup_steps: int = 0

class LatencyProxyLLM(LatencyProxy):
    """
    LLM latency proxy (ms ~ weighted FLOPs/bandwidth terms) for prefill + cached decode.
512
+ Accepts either a batch or explicit B,S,T.
513
+ """
514
+
515
+ def __init__(
516
+ self,
517
+ *,
518
+ scale_ms: float = 1.0,
519
+ alpha_qkv: float = 1.0,
520
+ alpha_scores: float = 1.0,
521
+ alpha_out: float = 1.0,
522
+ alpha_mlp: float = 1.0,
523
+ gate_kv_in_proxy: bool = False,
524
+ default_T: int = 128,
525
+ ):
526
+ super().__init__()
527
+ self.scale_ms = float(scale_ms)
528
+ self.alpha_qkv = float(alpha_qkv)
529
+ self.alpha_scores = float(alpha_scores)
530
+ self.alpha_out = float(alpha_out)
531
+ self.alpha_mlp = float(alpha_mlp)
532
+ self.gate_kv_in_proxy = bool(gate_kv_in_proxy)
533
+ self.default_T = int(default_T)
534
+
535
+ # ---------- gate discovery ----------
536
+ @staticmethod
537
+ def _soft_heads_from_block_llm(blk) -> Optional[torch.Tensor]:
538
+ attn = getattr(blk, "self_attn", None)
539
+ if attn is None:
540
+ return None
541
+ if hasattr(attn, "kept_heads_soft") and callable(attn.kept_heads_soft):
542
+ return attn.kept_heads_soft()
543
+ logits, tau = None, None
544
+ if hasattr(attn, "head_gate") and hasattr(attn.head_gate, "logits"):
545
+ logits = attn.head_gate.logits
546
+ tau = float(getattr(attn.head_gate, "tau", getattr(attn, "tau", 1.5)))
547
+ elif hasattr(attn, "logits"):
548
+ logits = attn.logits
549
+ tau = float(getattr(attn, "tau", 1.5))
550
+ if logits is None:
551
+ return None
552
+ return torch.sigmoid(logits / tau).sum()
553
+
554
+ @staticmethod
555
+ def _find_ffn_gate_llm(blk):
556
+ mlp = getattr(blk, "mlp", None)
557
+ g = getattr(mlp, "neuron_gate", None) if mlp is not None else None
558
+ if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
559
+ return g
560
+ return None
561
+
562
+ def _soft_hidden_from_block_llm(self, blk, default_hidden, anchor, warm=False):
563
+ if warm:
564
+ return default_hidden
565
+ g = self._find_ffn_gate_llm(blk)
566
+ if g is None:
567
+ return default_hidden
568
+ probs = torch.sigmoid(g.logits / float(g.tau)) # [#groups]
569
+ group = int(getattr(g, "group", getattr(g, "group_size", 128)))
570
+ kept_hidden = probs.sum() * _as_const_like(anchor, group)
571
+ return kept_hidden
572
+
573
+ # ---------- main ----------
574
+ def predict( # trainer entry and explicit-shape entry unified
575
+ self,
576
+ model: nn.Module,
577
+ sample: Optional[TensorOrBatch] = None,
578
+ *,
579
+ B: Optional[int] = None,
580
+ S: Optional[int] = None,
581
+ T: Optional[int] = None,
582
+ policy: Optional[object] = None,
583
+ step: Optional[int] = None,
584
+ return_terms: bool = False,
585
+ ):
586
+ # Allow explicit B,S,(T) path
587
+ if B is not None and S is not None:
588
+ ids_B, ids_S = int(B), int(S)
589
+ ids_T = int(T) if T is not None else int(self.default_T)
590
+ else:
591
+ if sample is None:
592
+ raise ValueError("LatencyProxyLLM.predict needs either a batch sample or explicit B,S.")
593
+ if isinstance(sample, (tuple, list)) and len(sample) in (2, 3) and all(isinstance(x, int) for x in sample):
594
+ # explicit (B,S) or (B,S,T)
595
+ ids_B, ids_S = int(sample[0]), int(sample[1])
596
+ ids_T = int(sample[2]) if len(sample) == 3 else int(self.default_T)
597
+ else:
598
+ ids = _ids_from_batch(sample)
599
+ ids_B, ids_S = int(ids.size(0)), int(ids.size(1))
600
+ ids_T = int(self.default_T) if T is None else int(T)
601
+
602
+ anchor = _find_gate_param_or_fallback(model)
603
+
604
+ # scalar tensors (same device/dtype)
605
+ B_t = _as_const_like(anchor, ids_B)
606
+ S_t = _as_const_like(anchor, ids_S)
607
+ T_t = _as_const_like(anchor, ids_T)
608
+
609
+ cfg = model.config
610
+ D = _as_const_like(anchor, int(cfg.hidden_size))
611
+ Hh = _as_const_like(anchor, int(cfg.num_attention_heads))
612
+ Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hh))))
613
+ Dh = D // Hh
614
+
615
+ warmup_steps = int(getattr(policy, "warmup_steps", 0)) if policy is not None else 0
616
+ warm = bool(step is not None and step < warmup_steps)
617
+
618
+ total_qkv = anchor.new_zeros(())
619
+ total_scores = anchor.new_zeros(())
620
+ total_out = anchor.new_zeros(())
621
+ total_mlp = anchor.new_zeros(())
622
+
623
+ default_hidden = _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))
624
+
625
+ layers = getattr(getattr(model, "model", model), "layers", [])
626
+ for blk in layers:
627
+ heads_soft = Hh if warm else (self._soft_heads_from_block_llm(blk) or Hh)
628
+ Dq = heads_soft * Dh
629
+ # K/V effective width
630
+ if self.gate_kv_in_proxy:
631
+ Dkv = heads_soft * Dh
632
+ else:
633
+ Dkv = Hkv * Dh
634
+ hidden_soft = self._soft_hidden_from_block_llm(blk, default_hidden, anchor, warm=warm)
635
+
636
+ # Prefill + decode (simplified aggregation)
637
+ Seff = S_t + T_t
638
+
639
+ # q/k/v linear FLOP-like terms
640
+ total_qkv = total_qkv + (
641
+ # q
642
+ B_t * Seff * D * Dq +
643
+ # k + v
644
+ 2 * B_t * Seff * D * Dkv
645
+ )
646
+ # attention scores (prefill SxS + decode triangular)
647
+ total_scores = total_scores + (
648
+ B_t * (S_t * S_t) * heads_soft * Dh +
649
+ B_t * heads_soft * Dh * (T_t * S_t + (T_t * (T_t + 1)) // 2)
650
+ )
651
+ # out proj
652
+ total_out = total_out + B_t * Seff * Dq * D
653
+ # mlp
654
+ total_mlp = total_mlp + B_t * Seff * 2 * D * hidden_soft
655
+
656
+ flops_like = (
657
+ self.alpha_qkv * total_qkv
658
+ + self.alpha_scores * total_scores
659
+ + self.alpha_out * total_out
660
+ + self.alpha_mlp * total_mlp
661
+ )
662
+
663
+ ms = flops_like * _as_const_like(anchor, self.scale_ms)
664
+ if return_terms:
665
+ return ms, {
666
+ "qkv": float((self.alpha_qkv * total_qkv).detach().cpu()),
667
+ "scores": float((self.alpha_scores * total_scores).detach().cpu()),
668
+ "out": float((self.alpha_out * total_out).detach().cpu()),
669
+ "mlp": float((self.alpha_mlp * total_mlp).detach().cpu()),
670
+ }
671
+ return ms
672
+
673
+ # ---------- per-layer debug ----------
674
+ @torch.no_grad()
675
+ def debug_layer_view(
676
+ self,
677
+ model: nn.Module,
678
+ *,
679
+ B: int,
680
+ S: int,
681
+ T: int,
682
+ policy: Optional[object] = None,
683
+ step: Optional[int] = None,
684
+ ) -> list:
685
+ anchor = _find_gate_param_or_fallback(model)
686
+ cfg = getattr(model, "config", None)
687
+ D = _as_const_like(anchor, int(getattr(cfg, "hidden_size", 0)))
688
+ Hq = _as_const_like(anchor, int(getattr(cfg, "num_attention_heads", 0)))
689
+ Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hq))))
690
+ Dh = D // Hq
691
+
692
+ warm = False
693
+ if policy is not None and step is not None:
694
+ warm = (int(step) < int(getattr(policy, "warmup_steps", 0)))
695
+
696
+ rows = []
697
+ layers = getattr(getattr(model, "model", model), "layers", None) or []
698
+ for i, blk in enumerate(layers):
699
+ heads_soft = Hq if warm else (self._soft_heads_from_block_llm(blk) or Hq)
700
+ Dq = heads_soft * Dh
701
+ Dkv = (heads_soft * Dh) if self.gate_kv_in_proxy else (Hkv * Dh)
702
+ hidden_soft = self._soft_hidden_from_block_llm(
703
+ blk, _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D)))), anchor, warm=warm
704
+ )
705
+ rows.append({
706
+ "layer": i,
707
+ "heads_soft": float(heads_soft.detach().cpu()),
708
+ "Dq≈heads*Dh": float(Dq.detach().cpu()),
709
+ "Dkv_used": float(Dkv.detach().cpu()),
710
+ "ffn_hidden_soft": float(hidden_soft.detach().cpu()),
711
+ })
712
+ return rows
713
+
714
+
715
+ # ------------------------------------------------------------
716
+ # Calibration helpers for LLM
717
+ # ------------------------------------------------------------
718
+ @torch.inference_mode()
719
+ def calibrate_proxy_llm(
720
+ proxy: LatencyProxyLLM,
721
+ model: nn.Module,
722
+ *,
723
+ B: int,
724
+ S: int,
725
+ T: int,
726
+ export_keepall_fn,
727
+ device: str = "cuda",
728
+ warmup: int = 10,
729
+ iters: int = 30,
730
+ ) -> float:
731
+ """
732
+ Calibrate proxy.scale_ms so proxy.predict(...) matches real keep-all latency for (B,S,T).
733
+ Returns the measured real mean latency in ms.
734
+ """
735
+ keepall = export_keepall_fn(model).to(device).eval()
736
+
737
+ # Measure real latency (prefill + decode)
738
+ from core.measure import measure_latency_text_ms as _measure # adjust if your path differs
739
+ real_ms, _ = _measure(keepall, B=B, S=S, T=T, warmup=warmup, iters=iters, device=device)
740
+
741
+ # Soft/proxy latency on *gated* model
742
+ ms_like = proxy.predict(model, B=B, S=S, T=T)
743
+ soft_ms = float(ms_like.detach().item()) if torch.is_tensor(ms_like) else float(ms_like)
744
+
745
+ proxy.scale_ms = float(real_ms / max(soft_ms, 1e-9))
746
+ return real_ms
747
+
748
+
749
+ @torch.inference_mode()
750
+ def calibrate_proxy_llm_from_batch(
751
+ proxy: LatencyProxyLLM,
752
+ model: nn.Module,
753
+ batch: Dict[str, torch.Tensor],
754
+ *,
755
+ T: int,
756
+ export_keepall_fn,
757
+ device: str = "cuda",
758
+ warmup: int = 10,
759
+ iters: int = 30,
760
+ ) -> Tuple[int, int, int, float]:
761
+ """
762
+ Infers (B,S) from a batch like {'input_ids': [B,S], ...},
763
+ calibrates for (B,S,T), and returns (B,S,T, real_ms).
764
+ """
765
+ input_ids = batch["input_ids"]
766
+ B, S = int(input_ids.size(0)), int(input_ids.size(1))
767
+ ms = calibrate_proxy_llm(
768
+ proxy, model, B=B, S=S, T=T, export_keepall_fn=export_keepall_fn,
769
+ device=device, warmup=warmup, iters=iters
770
+ )
771
+ return B, S, T, ms
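For intuition, the proxy's attention-score term above combines an S×S prefill cost with a cached-decode cost of T·S + T(T+1)/2 query-key interactions (decode step t attends to the S prompt tokens plus the t tokens generated so far). A minimal pure-Python sketch of that count — function and argument names here are illustrative, not part of the repo:

```python
def attn_score_term(B, S, T, heads, head_dim):
    """FLOP-like count for attention scores: S*S prefill plus cached decode.

    Decode step t (1-indexed) attends to S prompt tokens plus t generated
    tokens, so summing over t = 1..T gives T*S + T*(T+1)/2 interactions.
    """
    prefill = B * S * S * heads * head_dim
    decode = B * heads * head_dim * (T * S + T * (T + 1) // 2)
    return prefill + decode

print(attn_score_term(1, 2, 2, 1, 1))  # 4 prefill + 7 decode = 11
```

This matches the `total_scores` accumulation in `LatencyProxyLLM.predict` for a single layer, before the `alpha_scores` weighting and `scale_ms` calibration are applied.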
core/.ipynb_checkpoints/train-checkpoint.py ADDED
@@ -0,0 +1,327 @@
+"""Generic Lagrangian trainer (family-agnostic).
+
+This module provides a light framework to optimize *gated* students against
+teachers with a latency target enforced via a proxy + optional real probes.
+
+It does not assume ViT/ResNet/LLM specifics; adapters provide tiny callables.
+
+Key ingredients:
+- Two-phase update per step: (A) weights w.r.t. KD/task, (B) gates w.r.t. KD +
+  sparsity + latency penalty with a dual variable λ.
+- Optional periodic export + real-latency probe to correct λ.
+- Constraint projection for gates after each step.
+
+Adapters must provide:
+- get_student_logits(model, x) -> Tensor
+- get_teacher_logits(model, x) -> Tensor
+- export_keepall(model) -> nn.Module (clean copy without gates)
+- export_pruned(model, policy, step) -> nn.Module (transient copy for profiling)
+
+Core modules used:
+- `distill.KDConfig`, `distill.kd_loss`
+- `gates.combined_penalty`, `gates.PenaltyWeights`, `gates.project_gates_into_constraints`
+- `proxy_cost.LatencyProxy`
+- `profiler.measure_latency_ms`
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, Optional
+import gc
+
+import torch
+import torch.nn as nn
+
+from .distill import KDConfig, kd_loss, mse_reg
+from .gates import PenaltyWeights, Constraints, combined_penalty, project_gates_into_constraints, collect_param_groups
+from .proxy_cost import LatencyProxy
+from .profiler import measure_latency_ms
+
+# -----------------------------------------------------------------------------
+# Config
+# -----------------------------------------------------------------------------
+
+@dataclass
+class DualConfig:
+    lr: float = 0.05       # step for λ update
+    ema_beta: float = 0.5  # blend proxy-driven λ and real-probe λ
+    clip: float = 10.0
+
+
+@dataclass
+class TrainerConfig:
+    kd: KDConfig = KDConfig()
+    penalties: PenaltyWeights = PenaltyWeights(l0=0.0, keep_floor_ratio=0.0, bimodality=0.0)
+    constraints: Constraints = Constraints(min_keep_ratio=0.0, min_groups=1, max_groups_drop=None)
+
+    latency_target_ms: float = 30.0
+    real_probe_every: int = 0  # steps; 0 disables real probes
+    probe_batch_override: Optional[int] = None
+    gate_warmup_steps: int = 0  # freeze gates for early steps
+    mse_weight: float = 0.0
+
+    early_stopping_patience: int = 0
+    early_stopping_lambda: float = 1e-4
+
+    amp: bool = True
+    device: str = "cuda"
+
+    # Optimizers
+    lr_gate: float = 1e-2
+    lr_linear: float = 1e-4
+    lr_affine: float = 3e-4
+    wd_linear: float = 1e-4
+
+    # Mixed precision scaler
+    use_grad_scaler: bool = True
+
+    # Dual update
+    dual: DualConfig = DualConfig()
+
+
+# -----------------------------------------------------------------------------
+# Trainer
+# -----------------------------------------------------------------------------
+
+class LagrangeTrainer:
+    def __init__(
+        self,
+        student: nn.Module,
+        teacher: nn.Module,
+        proxy: LatencyProxy,
+        *,
+        adapter_get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+        adapter_get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+        adapter_export_keepall: Callable[[nn.Module], nn.Module],
+        adapter_export_pruned: Callable[[nn.Module, object, int], nn.Module],
+        export_policy: object,
+        cfg: TrainerConfig,
+    ) -> None:
+        self.student = student
+        self.teacher = teacher.eval()
+        for p in self.teacher.parameters():
+            p.requires_grad_(False)
+        self.proxy = proxy
+        self.get_s = adapter_get_student_logits
+        self.get_t = adapter_get_teacher_logits
+        self.export_keepall = adapter_export_keepall
+        self.export_pruned = adapter_export_pruned
+        self.export_policy = export_policy
+        self.cfg = cfg
+
+        # Build optimizers (grouped)
+        param_groups = collect_param_groups(
+            student,
+            lr_gate=cfg.lr_gate,
+            lr_linear=cfg.lr_linear,
+            lr_affine=cfg.lr_affine,
+            wd_linear=cfg.wd_linear,
+        )
+        # gates-only optimizer uses first group
+        self.opt_g = torch.optim.Adam([param_groups[0]], lr=param_groups[0]["lr"])  # type: ignore[arg-type]
+        # weights optimizer for the rest
+        self.opt_w = torch.optim.Adam(param_groups[1:])
+
+        self.scaler = torch.amp.GradScaler('cuda', enabled=(cfg.amp and cfg.use_grad_scaler))
+        self.lambda_: float = 0.0
+        self.mse_weight = cfg.mse_weight
+
+    # ---- internal helpers -----------------------------------------------------
+    def _zero_grads(self, params):
+        for p in params:
+            if p.grad is not None:
+                p.grad = None
+
+    def _has_grad(self, params) -> bool:
+        for p in params:
+            if p.grad is not None:
+                return True
+        return False
+
+    # ---- training -------------------------------------------------------------
+    def train_epoch(self, loader, *, real_policy=None, verbose_every: int = 50):
+        device = self.cfg.device
+        self.student.train().to(device)
+        self.teacher.to(device).eval()
+
+        running = 0.0
+        seen = 0
+        lam_real = self.lambda_
+        mean_ms = p95_ms = float("nan")  # defined once a real probe has run
+
+        total_steps = len(loader)
+
+        for step, batch in enumerate(loader, 1):
+            # Move batch to device in a type-safe way
+            batch = _move_batch_to_device(batch, device)
+
+            with torch.no_grad():
+                t_logits = self.get_t(self.teacher, batch)  # [B,1,V]
+                # match AMP compute dtype to avoid upcasting later
+                if self.cfg.amp:
+                    # infer autocast dtype from student params (bf16 or fp16)
+                    sparam = next(self.student.parameters())
+                    t_logits = t_logits.to(dtype=sparam.dtype, non_blocking=True)
+
+            # -------- Pass A: WEIGHTS (KD only) --------
+            self.opt_w.zero_grad(set_to_none=True)
+
+            with torch.amp.autocast('cuda', enabled=self.cfg.amp):
+                # Adapters receive the batch object (dict/tuple/tensor)
+                s_logits = self.get_s(self.student, batch)
+                mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
+                loss_w = kd_loss(s_logits, t_logits, self.cfg.kd) + mse
+
+            self.scaler.scale(loss_w).backward()
+            # Prevent gate params from changing in pass A
+            gate_params = self.opt_g.param_groups[0]["params"]
+            self._zero_grads(gate_params)
+
+            if any(p.grad is not None for pg in self.opt_w.param_groups for p in pg["params"]):
+                self.scaler.step(self.opt_w)
+                self.scaler.update()
+            else:
+                self.opt_w.zero_grad(set_to_none=True)
+
+            del s_logits
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            if step > int(self.cfg.gate_warmup_steps):
+                # -------- Pass B: GATES (KD + sparsity + λ * gap) --------
+                self.opt_g.zero_grad(set_to_none=True)
+                with torch.amp.autocast('cuda', enabled=self.cfg.amp):
+                    s_logits = self.get_s(self.student, batch)
+                    kd_g = kd_loss(s_logits, t_logits, self.cfg.kd)
+
+                    # Proxy gets the batch object too; family-specific proxy can read (B,S) etc.
+                    o1_ms = self.proxy.predict(self.student, batch)
+                    gap = torch.relu(o1_ms - float(self.cfg.latency_target_ms))
+                    reg = combined_penalty(self.student, self.cfg.penalties)
+                    mse = self.mse_weight * mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
+                    loss_g = kd_g + _to_tensor(self.lambda_, o1_ms) * gap + reg + mse
+
+                self.scaler.scale(loss_g).backward()
+                # Prevent non-gate params from changing in pass B
+                for pg in self.opt_w.param_groups:
+                    self._zero_grads(pg["params"])
+
+                if self._has_grad(self.opt_g.param_groups[0]["params"]):
+                    self.scaler.step(self.opt_g)
+                    self.scaler.update()
+                else:
+                    self.opt_g.zero_grad(set_to_none=True)
+            else:
+                o1_ms = self.proxy.predict(self.student, batch)
+                s_logits = loss_g = kd_g = reg = torch.tensor(0.0, device=device)
+
+            # -------- Dual (λ) update using proxy --------
+            with torch.no_grad():
+                lam_proxy = max(0.0, self.lambda_ + self.cfg.dual.lr * (float(o1_ms.detach()) - self.cfg.latency_target_ms))
+                self.lambda_ = 0.5 * (lam_real + lam_proxy)
+
+            # -------- Constraint projection, optional real probe --------
+            project_gates_into_constraints(self.student, self.cfg.constraints)
+
+            if self.cfg.real_probe_every and (step % int(self.cfg.real_probe_every) == 0):
+                # Build a probe shape for latency func if needed
+                try:
+                    from core.measure import measure_latency_text_ms  # text-friendly
+                    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
+                        B, S = int(batch["input_ids"].size(0)), int(batch["input_ids"].size(1))
+                    else:
+                        # Fallback: try tensor-like batch
+                        x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
+                        B = int(x0.size(0)); S = int(x0.size(1))
+                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
+                    mean_ms, p95_ms = measure_latency_text_ms(slim, B=B, S=S, T=128, device=device)
+                except Exception:
+                    # If the project has a different profiler, retain compatibility:
+                    from .profiler import measure_latency_ms
+                    x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
+                    shape = (int(x0.size(0)), *list(x0.shape[1:]))
+                    slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
+                    mean_ms, p95_ms = measure_latency_ms(slim, shape, device=device)
+
+                with torch.no_grad():
+                    lam_real = max(0.0, self.lambda_ + self.cfg.dual.lr * (mean_ms - self.cfg.latency_target_ms))
+
+            if (step % verbose_every) == 0:
+                print(
+                    f"Step {step}/{total_steps} | KL={float(loss_w.item()):.6f} | MSE={float(mse.item()):.6f} | "
+                    f"Gate={float(loss_g.item()):.6f} | "
+                    f"proxy={float(o1_ms.detach()):.3f}ms | real_mean={mean_ms:.3f}ms p95={p95_ms:.3f}ms | λ={self.lambda_:.6f}"
+                )
+
+            running += float(loss_g.detach())
+            seen += _batch_size(batch)
+
+            del s_logits, t_logits, o1_ms, kd_g, reg, loss_g, loss_w
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        print(f"Epoch loss {running / max(1, seen):.6f}")
+        return self.lambda_
+
+
+# -----------------------------------------------------------------------------
+# Helpers
+# -----------------------------------------------------------------------------
+
+def _to_tensor(val: float, like: torch.Tensor) -> torch.Tensor:
+    return torch.as_tensor(val, device=like.device, dtype=like.dtype)
+
+def _move_batch_to_device(batch, device: str):
+    """
+    Supports:
+    - dict with keys 'input_ids' and optional 'attention_mask'
+    - (x,) or (x, y) tuples/lists -> move each tensor-like to device
+    - single Tensor
+    Converts attention_mask to bool (preferred by HF SDPA).
+    """
+    if isinstance(batch, dict):
+        out = {}
+        for k, v in batch.items():
+            if torch.is_tensor(v):
+                v = v.to(device, non_blocking=True)
+                if k == "attention_mask" and v.dtype != torch.bool:
+                    v = v.to(torch.bool)
+            out[k] = v
+        return out
+
+    if isinstance(batch, (tuple, list)):
+        moved = []
+        for v in batch:
+            if torch.is_tensor(v):
+                v = v.to(device, non_blocking=True)
+            moved.append(v)
+        return type(batch)(moved)
+
+    if torch.is_tensor(batch):
+        return batch.to(device, non_blocking=True)
+
+    # Unknown type: return as-is (adapters/proxy should handle it)
+    return batch
+
+
+def _batch_size(batch) -> int:
+    """Best-effort batch size for logging/averages."""
+    if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
+        return int(batch["input_ids"].size(0))
+    if torch.is_tensor(batch):
+        return int(batch.size(0))
+    if isinstance(batch, (tuple, list)) and len(batch) and torch.is_tensor(batch[0]):
+        return int(batch[0].size(0))
+    return 1
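The dual (λ) update in `train_epoch` is projected gradient ascent: λ grows while the predicted latency exceeds the target and shrinks (floored at zero) once the constraint is met. A standalone sketch of just that rule — the function name and the example latency trace are illustrative, not from the repo:

```python
def dual_update(lam, proxy_ms, target_ms, lr=0.05):
    """Projected ascent on the Lagrange multiplier: lambda rises while the
    latency constraint is violated, decays toward 0 once it is satisfied."""
    return max(0.0, lam + lr * (proxy_ms - target_ms))

lam = 0.0
for ms in [40.0, 36.0, 31.0, 29.0, 24.0]:  # proxy latency drifting toward target
    lam = dual_update(lam, ms, target_ms=30.0)
```

The trainer additionally averages this proxy-driven update with a probe-driven one (`lam_real`), so occasional real measurements correct systematic proxy error without dominating every step.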
core/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities used across core and adapters.
2
+
3
+ Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
4
+ parameter grouping, model copying, etc.). Keep this file dependency-light.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
10
+
11
+ import copy
12
+ import random
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Device / dtype helpers
21
+ # -----------------------------------------------------------------------------
22
+
23
+ def as_like(x: torch.Tensor, val) -> torch.Tensor:
24
+ """Create a scalar/tensor constant on same device/dtype as `x`."""
25
+ return torch.as_tensor(val, device=x.device, dtype=x.dtype)
26
+
27
+
28
+ def first_param(module: nn.Module) -> torch.Tensor:
29
+ for p in module.parameters(recurse=True):
30
+ return p
31
+ return torch.tensor(0.0)
32
+
33
+
34
+ def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
35
+ return x.to(device=ref.device, dtype=ref.dtype)
36
+
37
+
38
+ # -----------------------------------------------------------------------------
39
+ # Seeding & determinism
40
+ # -----------------------------------------------------------------------------
41
+
42
+ def set_seed(seed: int = 42, deterministic: bool = False) -> None:
43
+ random.seed(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed)
47
+ if deterministic:
48
+ torch.backends.cudnn.deterministic = True
49
+ torch.backends.cudnn.benchmark = False
50
+
51
+
52
+ # -----------------------------------------------------------------------------
53
+ # Model parameter helpers
54
+ # -----------------------------------------------------------------------------
55
+
56
+ def freeze(module: nn.Module) -> None:
57
+ for p in module.parameters():
58
+ p.requires_grad_(False)
59
+
60
+
61
+ def unfreeze(module: nn.Module) -> None:
62
+ for p in module.parameters():
63
+ p.requires_grad_(True)
64
+
65
+
66
+ def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
67
+ if trainable_only:
68
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
69
+ return sum(p.numel() for p in module.parameters())
70
+
71
+
72
+ # -----------------------------------------------------------------------------
73
+ # Shape/signature helpers
74
+ # -----------------------------------------------------------------------------
75
+
76
+ def input_spec_vision(sample) -> Tuple[int, int, int]:
77
+ """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
78
+ if isinstance(sample, torch.Tensor):
79
+ B, C, H, W = sample.shape
80
+ return int(B), int(H), int(W)
81
+ if isinstance(sample, (tuple, list)) and len(sample) == 4:
82
+ B, C, H, W = sample
83
+ return int(B), int(H), int(W)
84
+ raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
85
+
86
+
87
+ # -----------------------------------------------------------------------------
88
+ # Rounding / multiples
89
+ # -----------------------------------------------------------------------------
90
+
91
+ def round_down_multiple(n: int, m: int) -> int:
92
+ if m is None or m <= 1:
93
+ return max(1, int(n))
94
+ n = int(n)
95
+ return max(m, (n // m) * m)
96
+
97
+
98
+ def clamp_int(v: int, lo: int, hi: int) -> int:
99
+ return max(lo, min(int(v), hi))
100
+
101
+
102
+ # -----------------------------------------------------------------------------
103
+ # Slicing helpers
104
+ # -----------------------------------------------------------------------------
105
+
106
+ @torch.no_grad()
107
+ def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
108
+ W = mat.weight.detach()
109
+ b = mat.bias.detach() if mat.bias is not None else None
110
+ if keep_out is not None:
111
+ idx_out = torch.as_tensor(keep_out, device=W.device)
112
+ W = W.index_select(0, idx_out)
113
+ if b is not None:
114
+ b = b.index_select(0, idx_out)
115
+ if keep_in is not None:
116
+ idx_in = torch.as_tensor(keep_in, device=W.device)
117
+ W = W.index_select(1, idx_in)
118
+ out_f, in_f = W.shape
119
+ new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device)
120
+ new.weight.copy_(W)
121
+ if b is not None:
122
+ new.bias.copy_(b)
123
+ return new
124
+
125
+
126
+ # -----------------------------------------------------------------------------
127
+ # Copying & detaching models
128
+ # -----------------------------------------------------------------------------
129
+
130
+ def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
131
+     m = copy.deepcopy(module).cpu().eval()
+     return m
+
+
+ # -----------------------------------------------------------------------------
+ # Gradient utilities
+ # -----------------------------------------------------------------------------
+
+ def zero_if_any(params: Iterable[torch.Tensor]) -> None:
+     for p in params:
+         if p.grad is not None:
+             p.grad = None
+
+
+ def any_grad(params: Iterable[torch.Tensor]) -> bool:
+     for p in params:
+         if p.grad is not None:
+             return True
+     return False
+
+
+ # -----------------------------------------------------------------------------
+ # For fine-tuning
+ # -----------------------------------------------------------------------------
+
+ def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
+     """
+     Rebuild all parameters as fresh nn.Parameter tensors (detach + clone),
+     which drops any 'inference tensor' tag and re-enables autograd.
+     """
+     for mod in module.modules():
+         for name, p in list(mod._parameters.items()):
+             if p is None:
+                 continue
+             new_p = nn.Parameter(p.detach().clone(), requires_grad=requires_grad)
+             setattr(mod, name, new_p)
+     return module
+
+
+ # -----------------------------------------------------------------------------
+ # Misc
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ExportRounding:
+     head_floor_post: int = 1
+     head_multiple_post: int = 1
+     ffn_min_keep_ratio_post: float = 0.0
+     ffn_snap_groups_post: int = 1
+
+
+ def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
+     # Normalize patch_size: keep tuples as tuples, coerce scalars to int
+     patch = getattr(cfg, "patch_size", 16)
+     return (
+         "ViT",
+         sample_shape,
+         int(getattr(cfg, "num_attention_heads", 12)),
+         int(getattr(cfg, "hidden_size", 768)),
+         int(getattr(cfg, "intermediate_size", 3072)),
+         tuple(patch) if isinstance(patch, (tuple, list)) else int(patch),
+     )
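The signature tuple returned by `shape_signature_vit` is what callers can key caches on: two models with the same shape signature can share profiled artifacts. A minimal pure-Python sketch of the same normalization (the helper name `shape_signature_vit_sketch` and the `SimpleNamespace` config are illustrative, not part of the library):

```python
from types import SimpleNamespace

def shape_signature_vit_sketch(cfg, sample_shape):
    # Mirrors shape_signature_vit above: normalize patch_size to int or tuple
    patch = getattr(cfg, "patch_size", 16)
    patch = tuple(patch) if isinstance(patch, (tuple, list)) else int(patch)
    return (
        "ViT",
        sample_shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        patch,
    )

cfg = SimpleNamespace(num_attention_heads=12, hidden_size=768,
                      intermediate_size=3072, patch_size=16)
sig = shape_signature_vit_sketch(cfg, (1, 3, 224, 224))
print(sig)  # ('ViT', (1, 3, 224, 224), 12, 768, 3072, 16)
```

Because the tuple contains only hashables, it can be used directly as a dict key.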
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (127 Bytes).
core/__pycache__/distill.cpython-310.pyc ADDED
Binary file (6.94 kB).
core/__pycache__/export.cpython-310.pyc ADDED
Binary file (7.31 kB).
core/__pycache__/finetune.cpython-310.pyc ADDED
Binary file (7.35 kB).
core/__pycache__/gates.cpython-310.pyc ADDED
Binary file (13.6 kB).
core/__pycache__/profiler.cpython-310.pyc ADDED
Binary file (7.68 kB).
core/__pycache__/proxy_cost.cpython-310.pyc ADDED
Binary file (22.8 kB).
core/__pycache__/search_export.cpython-310.pyc ADDED
Binary file (2.95 kB).
core/__pycache__/train.cpython-310.pyc ADDED
Binary file (9.12 kB).
core/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.98 kB).
core/distill.py ADDED
@@ -0,0 +1,183 @@
+ """Knowledge-distillation utilities (model-family agnostic).
+
+ This module provides:
+ - Losses: KL distillation, soft cross-entropy, cosine feature loss
+ - Helper to obtain logits from models with/without built-in heads
+ - Lightweight classification head for backbone models (e.g., ViTModel)
+ - Simple evaluators (agreement %, KL) and diagnostics
+
+ Adapters may override `adapter_get_logits(model, x)` if a family needs a
+ custom extraction (e.g., language models with past_key_values).
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional, Protocol, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # -----------------------------------------------------------------------------
+ # Config
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class KDConfig:
+     temperature: float = 2.0
+     alpha: float = 1.0  # multiplier for KL term; task loss handled outside
+
+
+ # -----------------------------------------------------------------------------
+ # Losses
+ # -----------------------------------------------------------------------------
+
+ def kl_divergence(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
+     """Batchmean KL(teacher/T || student/T), scaled by T^2 (Hinton-style).
+
+     Note `F.kl_div(input, target)` computes KL(target || input), so the
+     student supplies log-probabilities and the teacher the target probs.
+     """
+     p_s = F.log_softmax(student_logits / T, dim=-1)
+     p_t = F.softmax(teacher_logits / T, dim=-1)
+     return F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)
+
+
+ def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, cfg: KDConfig) -> torch.Tensor:
+     return cfg.alpha * kl_divergence(student_logits, teacher_logits, T=cfg.temperature)
+
+
+ def mse_reg(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
+     mse = F.mse_loss(student_logits, teacher_logits, reduction="mean")
+     return mse * (T * T)
+
+
+ def soft_ce(student_logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
+     """Soft cross-entropy: expects `soft_targets` already normalized."""
+     logp = F.log_softmax(student_logits, dim=-1)
+     return -(soft_targets * logp).sum(dim=-1).mean()
+
+
+ def cosine_feature_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
+     """1 - cosine similarity averaged over batch and time/patch dims."""
+     s = F.normalize(student_feats, dim=-1)
+     t = F.normalize(teacher_feats, dim=-1)
+     return (1.0 - (s * t).sum(dim=-1)).mean()
+
+
+ # -----------------------------------------------------------------------------
+ # Logit extraction
+ # -----------------------------------------------------------------------------
+
+ class LogitsProvider(Protocol):
+     def __call__(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor: ...
+
+
+ class ClsHead(nn.Module):
+     """Minimal classification head: LN + Linear.
+
+     Useful when the backbone outputs hidden states (e.g., ViTModel) and you
+     want logits comparable to a teacher with a classification head.
+     """
+
+     def __init__(self, hidden_size: int, num_classes: int = 1000, base_head: Optional[nn.Module] = None):
+         super().__init__()
+         self.norm = nn.LayerNorm(hidden_size)
+         self.fc = nn.Linear(hidden_size, num_classes)
+         if base_head is not None:
+             # Try to load weights if shapes match (e.g., from HF classifier)
+             try:
+                 self.load_state_dict(base_head.state_dict(), strict=False)
+             except Exception:
+                 pass
+
+     def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
+         return self.fc(self.norm(cls_token))
+
+
+ @torch.no_grad()
+ def infer_hidden_size(model: nn.Module, sample: torch.Tensor) -> int:
+     # Run a tiny forward to inspect hidden size when unknown
+     model.eval()
+     out = model(pixel_values=sample)
+     if hasattr(out, "last_hidden_state"):
+         return int(out.last_hidden_state.shape[-1])
+     if hasattr(out, "logits"):
+         return int(out.logits.shape[-1])
+     raise RuntimeError("Cannot infer hidden size; provide explicitly.")
+
+
+ def default_get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
+     """Default logits extractor for vision backbones (`pixel_values` input).
+
+     - If the model output has `.logits`, return it.
+     - Else expects `.last_hidden_state` and uses [CLS] via the provided `head`.
+     """
+     out = model(pixel_values=x)
+     if hasattr(out, "logits"):
+         return out.logits
+     if hasattr(out, "last_hidden_state"):
+         if head is None:
+             raise ValueError("Backbone returned hidden states; supply a classification head.")
+         cls_tok = out.last_hidden_state[:, 0, :]
+         return head(cls_tok)
+     raise ValueError("Model output lacks logits and last_hidden_state.")
+
+
+ # -----------------------------------------------------------------------------
+ # Evaluators & diagnostics
+ # -----------------------------------------------------------------------------
+
+ @torch.inference_mode()
+ def logits_std(model: nn.Module, loader, *, get_logits: LogitsProvider, batches: int = 10, device: str = "cuda") -> Tuple[float, int]:
+     s = 0.0
+     k = 0
+     for x in loader:
+         if k >= batches:
+             break
+         x = x.to(device)
+         y = get_logits(model, x)
+         s += y.std().item()
+         k += 1
+     return (s / max(1, k), k)
+
+
+ @torch.inference_mode()
+ def agreement_metrics(
+     student: nn.Module,
+     teacher: nn.Module,
+     loader,
+     *,
+     get_student_logits: LogitsProvider,
+     get_teacher_logits: LogitsProvider,
+     batches: int = 20,
+     T: float = 1.0,
+     device: str = "cuda",
+ ) -> dict:
+     kl_sum = 0.0
+     n = 0
+     top1 = 0
+     tot = 0
+     for i, x in enumerate(loader):
+         if i >= batches:
+             break
+         x = x.to(device)
+         t = get_teacher_logits(teacher, x)
+         s = get_student_logits(student, x)
+         p_s = F.log_softmax(s / T, dim=-1)
+         p_t = F.softmax(t / T, dim=-1)
+         kl_sum += (F.kl_div(p_s, p_t, reduction="batchmean") * (T * T)).item()
+         top1 += (s.argmax(-1) == t.argmax(-1)).sum().item()
+         tot += x.size(0)
+         n += 1
+     return {"kl_TT": kl_sum / max(1, n), "top1_agreement": top1 / max(1, tot)}
+
+
+ # -----------------------------------------------------------------------------
+ # Small trainer helpers
+ # -----------------------------------------------------------------------------
+
+ class DualEMA:
+     """Simple exponential moving average for a scalar (e.g., lambda or latency)."""
+
+     def __init__(self, beta: float = 0.9, value: float = 0.0):
+         self.beta = float(beta)
+         self.value = float(value)
+
+     def update(self, x: float) -> float:
+         self.value = self.beta * self.value + (1 - self.beta) * float(x)
+         return self.value
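The temperature-scaled KL used by `kl_divergence` can be checked by hand on a two-class example. A pure-Python sketch (no torch; `softmax` and `kl_hinton` are illustrative names, computing the same per-sample quantity for a batch of one):

```python
import math

def softmax(xs, T=1.0):
    # Temperature-scaled softmax over a single logit vector
    exps = [math.exp(x / T) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def kl_hinton(student, teacher, T=2.0):
    # KL(teacher_T || student_T) * T^2, matching kl_divergence above
    p_t = softmax(teacher, T)
    p_s = softmax(student, T)
    return (T * T) * sum(t * math.log(t / s) for t, s in zip(p_t, p_s))

teacher = [2.0, 0.0]
student = [1.0, 0.5]
print(round(kl_hinton(student, teacher, T=2.0), 3))
```

The T^2 factor keeps gradient magnitudes roughly constant as T varies, which is why `finetune.py` can anneal the temperature without rescaling `alpha`.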
core/export.py ADDED
@@ -0,0 +1,220 @@
+ """Core export utilities for hard-pruning and kernel-aligned rounding.
+
+ This module is *family-agnostic*. Adapters (e.g., ViT, ResNet, LLM) should:
+   1) decide which gates map to which structural dims (heads, hidden groups, channels),
+   2) obtain KEEP indices using helpers in this file, and
+   3) rebuild family-specific modules with the sliced weights.
+
+ Provided here:
+   - Rounding policies and helpers (floors, multiples, warmup keep-all)
+   - KEEP index selection from a `Gate` (or gate-like) object
+   - Generic weight slicers for Linear / Conv2d / Embedding
+   - Small safeguards for dtype/device and shape checks
+
+ The library avoids touching family internals here. Exporters in adapters should
+ use these primitives to assemble a clean pruned model.
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Optional, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from .gates import Gate, expand_group_indices
+
+ # -----------------------------------------------------------------------------
+ # Policies & rounding
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class Rounding:
+     """Rounding policy for a single gated axis.
+
+     Attributes
+     ----------
+     floor_groups : int
+         Minimum number of groups to keep after rounding.
+     multiple_groups : int
+         Snap the number of kept groups down to a multiple of this (>= 1).
+     min_keep_ratio : float
+         Optional fractional lower bound on expected keep; applied before rounding.
+     """
+
+     floor_groups: int = 1
+     multiple_groups: int = 1
+     min_keep_ratio: float = 0.0
+
+
+ @dataclass
+ class ExportPolicy:
+     """Export-time policy shared by families.
+
+     - `warmup_steps`: if current `step < warmup_steps`, keep-all.
+     - `rounding`: default rounding used unless adapter overrides per-axis.
+     """
+
+     warmup_steps: int = 0
+     # A dataclass instance is not a valid field default; use default_factory.
+     rounding: Rounding = field(default_factory=Rounding)
+
+
+ def _round_down_mult(n: int, m: int) -> int:
+     if m is None or m <= 1:
+         return max(1, int(n))
+     n = int(n)
+     return max(m, (n // m) * m)
+
+
+ def _compute_keep_k(
+     expected_kept: float,
+     total_groups: int,
+     *,
+     rounding: Rounding,
+ ) -> int:
+     # Start from nearest-integer expectation
+     k = int(round(expected_kept))
+     # Apply ratio floor, then absolute floor, then multiple snapping
+     k = max(k, int(rounding.min_keep_ratio * total_groups))
+     k = max(k, int(rounding.floor_groups))
+     k = min(k, total_groups)
+     k = _round_down_mult(k, int(rounding.multiple_groups))
+     return max(1, min(k, total_groups))
+
+
+ # -----------------------------------------------------------------------------
+ # KEEP index selection from a gate
+ # -----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def keep_group_indices_from_gate(
+     gate: Gate,
+     *,
+     policy: ExportPolicy,
+     step: Optional[int] = None,
+     custom_rounding: Optional[Rounding] = None,
+ ) -> torch.Tensor:
+     """Return sorted indices of groups to KEEP based on `gate` and policy.
+
+     If `step < warmup_steps`, returns all indices (keep-all). Otherwise, the
+     number of groups to keep is computed from the *expected keep* under the
+     current logits and snapped according to the rounding policy.
+     """
+     G = int(gate.num_groups)
+     if step is not None and step < int(policy.warmup_steps):
+         return torch.arange(G, device=gate.logits.device)
+
+     rounding = custom_rounding or policy.rounding
+     p = torch.sigmoid(gate.logits.detach().float() / float(gate.tau))
+     k = _compute_keep_k(expected_kept=float(p.sum()), total_groups=G, rounding=rounding)
+     idx = torch.topk(p, k, largest=True).indices.sort().values
+     return idx.to(torch.long)
+
+
+ @torch.no_grad()
+ def keep_element_indices_from_gate(
+     gate: Gate,
+     *,
+     policy: ExportPolicy,
+     step: Optional[int] = None,
+     custom_rounding: Optional[Rounding] = None,
+ ) -> torch.Tensor:
+     """Expand kept *group* indices into element indices using `group_size`."""
+     grp_idx = keep_group_indices_from_gate(gate, policy=policy, step=step, custom_rounding=custom_rounding)
+     return expand_group_indices(grp_idx, gate.group_size)
+
+
+ # -----------------------------------------------------------------------------
+ # Generic slicers
+ # -----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
+     """Create a new Linear with selected input/output features preserved.
+
+     - `keep_out` selects rows (output features)
+     - `keep_in` selects columns (input features)
+     """
+     W = mat.weight.detach()
+     b = mat.bias.detach() if mat.bias is not None else None
+
+     if keep_out is not None:
+         W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
+         if b is not None:
+             b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
+     if keep_in is not None:
+         W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
+
+     out_f, in_f = W.shape
+     new = nn.Linear(in_f, out_f, bias=(b is not None)).to(W.device)
+     new.weight.copy_(W)
+     if b is not None:
+         new.bias.copy_(b)
+     return new
+
+
+ @torch.no_grad()
+ def slice_conv2d(conv: nn.Conv2d, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Conv2d:
+     """Create a new Conv2d with selected in/out channels preserved.
+
+     Only supports standard conv2d (no groups/depthwise changes). For grouped
+     convs, the adapter should handle group alignment before calling this.
+     """
+     W = conv.weight.detach()
+     b = conv.bias.detach() if conv.bias is not None else None
+
+     if keep_out is not None:
+         W = W.index_select(0, torch.as_tensor(keep_out, device=W.device))
+         if b is not None:
+             b = b.index_select(0, torch.as_tensor(keep_out, device=b.device))
+     if keep_in is not None:
+         W = W.index_select(1, torch.as_tensor(keep_in, device=W.device))
+
+     out_c, in_c = W.shape[:2]
+     new = nn.Conv2d(
+         in_c,
+         out_c,
+         kernel_size=conv.kernel_size,
+         stride=conv.stride,
+         padding=conv.padding,
+         dilation=conv.dilation,
+         groups=1,
+         bias=(b is not None),
+         padding_mode=conv.padding_mode,
+     ).to(W.device)
+     new.weight.copy_(W)
+     if b is not None:
+         new.bias.copy_(b)
+     return new
+
+
+ @torch.no_grad()
+ def slice_embedding(emb: nn.Embedding, keep_rows: Optional[Sequence[int]] = None, keep_dim: Optional[Sequence[int]] = None) -> nn.Embedding:
+     """Create a new Embedding with selected rows (vocab) and/or dims kept."""
+     W = emb.weight.detach()
+     if keep_rows is not None:
+         W = W.index_select(0, torch.as_tensor(keep_rows, device=W.device))
+     if keep_dim is not None:
+         W = W.index_select(1, torch.as_tensor(keep_dim, device=W.device))
+     num, dim = W.shape
+     new = nn.Embedding(
+         num,
+         dim,
+         padding_idx=emb.padding_idx,
+         max_norm=emb.max_norm,
+         norm_type=emb.norm_type,
+         scale_grad_by_freq=emb.scale_grad_by_freq,
+         sparse=emb.sparse,
+         device=W.device,
+         dtype=W.dtype,
+     )
+     new.weight.copy_(W)
+     return new
+
+
+ # -----------------------------------------------------------------------------
+ # Small helpers for adapters
+ # -----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def concat_index_ranges(ranges: Sequence[Tuple[int, int]]) -> torch.Tensor:
+     """Given [(start, end_exclusive), ...], return concatenated 1D indices."""
+     parts = [torch.arange(a, b, dtype=torch.long) for a, b in ranges if b > a]
+     return torch.cat(parts, dim=0) if parts else torch.empty(0, dtype=torch.long)
+
+
+ @torch.no_grad()
+ def block_indices_from_groups(groups: Sequence[int], group_size: int) -> torch.Tensor:
+     """Convert sorted group ids to expanded feature indices."""
+     groups = torch.as_tensor(groups, dtype=torch.long)
+     return expand_group_indices(groups, int(group_size))
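The rounding pipeline in `_compute_keep_k` (expectation, then ratio floor, absolute floor, clamp, multiple-snap) can be illustrated without torch. A pure-Python mirror of that logic (`compute_keep_k` and `round_down_mult` are sketch names, not library API):

```python
def round_down_mult(n, m):
    # Snap n down to a multiple of m, never below m itself
    if m is None or m <= 1:
        return max(1, int(n))
    return max(m, (int(n) // m) * m)

def compute_keep_k(expected_kept, total_groups, *, floor_groups=1,
                   multiple_groups=1, min_keep_ratio=0.0):
    # Mirrors _compute_keep_k above: expectation -> floors -> clamp -> snap
    k = int(round(expected_kept))
    k = max(k, int(min_keep_ratio * total_groups))
    k = max(k, int(floor_groups))
    k = min(k, total_groups)
    k = round_down_mult(k, int(multiple_groups))
    return max(1, min(k, total_groups))

# 12 attention heads, expected keep 6.6, kernel-aligned to multiples of 4:
print(compute_keep_k(6.6, 12, multiple_groups=4))  # -> 4
```

Note the snap is always downward, so `multiple_groups=4` turns an expectation of 7 kept heads into 4, not 8; the floors are the only upward pressure.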
core/finetune.py ADDED
@@ -0,0 +1,267 @@
+ # core/finetune.py
+ """Post-pruning fine-tuning utilities (distillation)."""
+
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ from typing import Callable, Optional, Tuple, Iterable
+
+ import torch
+ import torch.nn as nn
+
+ from core.distill import KDConfig, kd_loss, mse_reg
+ from core.utils import ensure_trainable_parameters
+
+ import copy
+
+
+ @dataclass
+ class FinetuneConfig:
+     epochs: int = 5
+     lr: float = 3e-4
+     wd: float = 0.0
+     # Mutable defaults need default_factory, or the dataclass raises at import
+     kd: KDConfig = field(default_factory=lambda: KDConfig(temperature=2.0, alpha=1.0))
+     amp: bool = True
+     # "auto" -> bf16 if supported else fp16; "bf16" | "fp16" | "off" also allowed
+     amp_dtype: str = "auto"
+     device: str = "cuda"
+     log_every: int = 200
+     # diagnostics
+     grad_check_every: int = 50
+     grad_warn_if_zero_steps: int = 2  # consecutive checks with zero grad -> warn
+     mse_weight: float = 0.0
+
+
+ def _autocast_and_scaler(amp: bool, amp_dtype: str) -> Tuple[torch.autocast, Optional[torch.amp.GradScaler], bool, str]:
+     """
+     Returns (autocast_ctx, scaler_or_None, use_scaler_bool, amp_mode_str)
+     - BF16 -> autocast(bfloat16), NO GradScaler
+     - FP16 -> autocast(float16), GradScaler ENABLED
+     - OFF  -> disabled autocast, NO GradScaler
+     """
+     if not amp or amp_dtype == "off":
+         ctx = torch.amp.autocast(device_type="cuda", enabled=False)
+         return ctx, None, False, "OFF"
+
+     if amp_dtype == "auto":
+         use_bf16 = torch.cuda.is_bf16_supported()
+     elif amp_dtype == "bf16":
+         use_bf16 = True
+     elif amp_dtype == "fp16":
+         use_bf16 = False
+     else:
+         raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")
+
+     if use_bf16:
+         ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
+         return ctx, None, False, "BF16"
+     else:
+         ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True)
+         try:
+             scaler = torch.amp.GradScaler("cuda", enabled=True)
+         except TypeError:
+             scaler = torch.cuda.amp.GradScaler(enabled=True)
+         return ctx, scaler, True, "FP16"
+
+
+ def _images_from_batch(batch):
+     if isinstance(batch, dict):
+         # Avoid `x or y` here: the truthiness of a multi-element tensor raises.
+         if batch.get("pixel_values") is not None:
+             return batch["pixel_values"]
+         return batch.get("input")
+     if isinstance(batch, (tuple, list)):
+         return batch[0]
+     return batch
+
+
+ def _param_iter_trainable(model: nn.Module) -> Iterable[torch.nn.Parameter]:
+     for p in model.parameters():
+         if p.requires_grad:
+             yield p
+
+
+ def _grad_norm_and_nonzero(params: Iterable[torch.nn.Parameter]) -> Tuple[float, int]:
+     total_sq, nonzero = 0.0, 0
+     for p in params:
+         g = p.grad
+         if g is None:
+             continue
+         if g.is_sparse:
+             g = g.coalesce().values()
+         gn = float(g.detach().norm().cpu())
+         if gn > 0.0:
+             nonzero += 1
+         total_sq += gn * gn
+     return (total_sq ** 0.5), nonzero
+
+
+ @torch.no_grad()
+ def recalibrate_bn_stats(model, loader, max_batches=200, device="cuda"):
+     model.train()  # use training mode to update running stats
+     seen = 0
+     for i, batch in enumerate(loader):
+         if i >= max_batches:
+             break
+         x = batch[0] if isinstance(batch, (tuple, list)) else batch
+         if not torch.is_tensor(x):
+             continue
+         x = x.to(device, non_blocking=True)
+         model(x)
+         seen += x.size(0)
+     return seen
+
+
+ def finetune_student(
+     student: nn.Module,
+     teacher: nn.Module,
+     train_loader,
+     *,
+     get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+     get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
+     cfg: Optional[FinetuneConfig] = None,
+     val_loader=None,
+     on_step: Optional[Callable[[int, float], None]] = None,
+     save_best: bool = False,
+ ) -> nn.Module:
+     """Fine-tune a pruned student against a frozen teacher using KD."""
+     if cfg is None:
+         # Build per-call: cfg.kd.temperature is mutated below, so a shared
+         # default instance would leak state across calls.
+         cfg = FinetuneConfig()
+     dev = cfg.device
+     student = student.to(dev)
+     teacher = teacher.to(dev).eval()
+     for p in teacher.parameters():
+         p.requires_grad_(False)
+     for p in student.parameters():
+         p.requires_grad_(True)
+
+     # Make sure we can actually train
+     ensure_trainable_parameters(student, requires_grad=True)
+     trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
+     if trainable == 0:
+         raise RuntimeError("No trainable parameters in student — cannot finetune.")
+
+     opt = torch.optim.AdamW(
+         _param_iter_trainable(student),
+         lr=cfg.lr,
+         weight_decay=cfg.wd,
+     )
+     # T_max counts optimizer *steps*, so the scheduler is stepped once per batch below
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * len(train_loader), eta_min=3e-5)
+
+     autocast_ctx, scaler, use_scaler, amp_mode = _autocast_and_scaler(cfg.amp, cfg.amp_dtype)
+     print(f"[AMP] Mode={amp_mode} | GradScaler={'ON' if use_scaler else 'OFF'} | "
+           f"KD: T={cfg.kd.temperature} alpha={cfg.kd.alpha} | LR={cfg.lr} WD={cfg.wd} | Trainable params={trainable:,}")
+
+     zero_grad_streak = 0
+     global_step = 0
+
+     T_max = cfg.kd.temperature
+     T_min = 2.0
+     kd_conf = cfg.kd
+
+     best_state = None
+     best_val = float("inf")
+
+     for ep in range(cfg.epochs):
+         student.train()
+         running, seen = 0.0, 0
+
+         for i, batch in enumerate(train_loader):
+
+             step = ep * len(train_loader) + i  # global step for T scheduling
+             max_steps = cfg.epochs * len(train_loader)
+             kd_conf.temperature = T_max - (step / max_steps) * (T_max - T_min)
+
+             x = _images_from_batch(batch)
+             if not torch.is_tensor(x):
+                 raise ValueError("Train loader must yield tensors or (tensor, target) tuples.")
+             x = x.to(dev, non_blocking=True)
+
+             with torch.no_grad():
+                 t = get_teacher_logits(teacher, x)
+                 # Force numerically stable dtype for the loss
+                 t = t.float()
+
+             # ---- forward student under autocast
+             with autocast_ctx:
+                 s = get_student_logits(student, x)
+
+             # ---- compute KD loss in FP32 (outside autocast) for stability
+             s32 = s.float()
+             mse = cfg.mse_weight * mse_reg(s32, t, kd_conf.temperature)
+             loss = kd_loss(s32, t, kd_conf) + mse
+
+             opt.zero_grad(set_to_none=True)
+             if use_scaler:
+                 scaler.scale(loss).backward()
+                 scaler.step(opt)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 opt.step()
+             scheduler.step()  # per-step cosine schedule, matching T_max above
+
+             # ---- diagnostics
+             bs = x.size(0)
+             running += float(loss.detach()) * bs
+             seen += bs
+             global_step += 1
+
+             if cfg.grad_check_every and (global_step % cfg.grad_check_every == 0):
+                 gnorm, n_nonzero = _grad_norm_and_nonzero(_param_iter_trainable(student))
+                 if n_nonzero == 0 or gnorm == 0.0:
+                     zero_grad_streak += 1
+                     if zero_grad_streak >= cfg.grad_warn_if_zero_steps:
+                         print(f"[WARN] Step {global_step}: zero gradients detected "
+                               f"(nonzero={n_nonzero}, grad_norm={gnorm:.3e}). "
+                               f"Check get_student_logits, requires_grad, AMP settings, and data pipeline.")
+                 else:
+                     zero_grad_streak = 0
+
+             if cfg.log_every and (i + 1) % cfg.log_every == 0:
+                 print(f"Step {i+1}/{len(train_loader)} (ep {ep+1}/{cfg.epochs}): "
+                       f"running loss = {running / max(1, seen):.4f}")
+
+             if on_step is not None:
+                 on_step(global_step, float(loss.detach()))
+
+             # free ASAP
+             del s, s32, t, loss
+
+         # ---- validation
+         if val_loader is not None:
+             _ = recalibrate_bn_stats(student, train_loader, max_batches=1000, device=cfg.device)
+             student.eval()
+             val_loss, vseen = 0.0, 0
+             with torch.no_grad():
+                 for vbatch in val_loader:
+                     vx = _images_from_batch(vbatch)
+                     if not torch.is_tensor(vx):
+                         raise ValueError("Val loader must yield tensors or (tensor, target) tuples.")
+                     vx = vx.to(dev, non_blocking=True)
+
+                     vt = get_teacher_logits(teacher, vx).float()
+                     with autocast_ctx:
+                         vs = get_student_logits(student, vx)
+
+                     vs32 = vs.float()
+                     vmse = cfg.mse_weight * mse_reg(vs32, vt, kd_conf.temperature)
+                     vloss = kd_loss(vs32, vt, kd_conf) + vmse
+                     val_loss += float(vloss.detach()) * vx.size(0)
+                     vseen += vx.size(0)
+
+             mean_val = val_loss / max(1, vseen)
+             print("\n------------------------------------------------")
+             print(f"Epoch {ep+1}/{cfg.epochs}: T={kd_conf.temperature:.2f}, train={running / max(1, seen):.6f}, "
+                   f"val={mean_val:.6f}")
+
+             if save_best and (mean_val < best_val):
+                 best_val = mean_val
+                 best_state = copy.deepcopy(student.state_dict())
+
+             print("------------------------------------------------\n")
+
+         else:
+             print(f"Epoch {ep+1}/{cfg.epochs}: train={running / max(1, seen):.6f}")
+
+     if save_best and val_loader is not None and best_state is not None:
+         student.load_state_dict(best_state)
+
+     student.eval()
+     return student
+
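The AMP mode selection in `_autocast_and_scaler` reduces to a small decision table (BF16 needs no GradScaler because its exponent range matches FP32; FP16 does). A pure-Python mirror of that table (`amp_mode` is a sketch name; `bf16_supported` stands in for `torch.cuda.is_bf16_supported()`):

```python
def amp_mode(amp: bool, amp_dtype: str, bf16_supported: bool) -> str:
    # Mirrors the mode selection in _autocast_and_scaler, without torch objects
    if not amp or amp_dtype == "off":
        return "OFF"
    if amp_dtype == "auto":
        return "BF16" if bf16_supported else "FP16"
    if amp_dtype == "bf16":
        return "BF16"
    if amp_dtype == "fp16":
        return "FP16"
    raise ValueError(f"Unknown amp_dtype={amp_dtype!r}")

print(amp_mode(True, "auto", bf16_supported=True))   # BF16
print(amp_mode(True, "auto", bf16_supported=False))  # FP16
print(amp_mode(False, "auto", bf16_supported=True))  # OFF
```

Only the FP16 branch pairs the autocast context with a GradScaler; the other two modes run with unscaled gradients.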
core/gates.py ADDED
@@ -0,0 +1,389 @@
+ """Core gating primitives for hardware-aware model optimization.
+
+ This module defines:
+   - Base `Gate` interface (nn.Module) with a small, consistent API
+   - Concrete gates: HeadGate, GroupGate, LayerGate
+   - Straight-Through (ST) relaxed Bernoulli with Gumbel noise
+   - Penalties/regularizers commonly used during training
+   - Constraint projection helpers
+
+ Design goals:
+   - TorchScript-friendly where possible
+   - Minimal assumptions about model family (ViT, ResNet, LLM)
+   - Gates operate on *groups* of units; group_size controls expansion
+   - No direct knowledge of attention/FFN/etc. — adapters wire masks
+
+ Typical usage (adapter side):
+     >>> gate = GroupGate(num_groups=H, group_size=Dh, tau=1.5, init_logit=3.0)
+     >>> m = gate.mask(training=self.training)   # [H * Dh]
+     >>> tensor = tensor * m.view(1, H, 1, Dh)   # example broadcast
+
+ Penalties scan the module tree for objects exposing `.logits` and `.tau`.
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Iterable, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # -----------------------------------------------------------------------------
+ # Utilities
+ # -----------------------------------------------------------------------------
+
+ def _as_like(x: torch.Tensor, val) -> torch.Tensor:
+     return torch.as_tensor(val, device=x.device, dtype=x.dtype)
+
+
+ def _gumbel_like(x: torch.Tensor) -> torch.Tensor:
+     # Logistic noise (difference of two Gumbels): log(u) - log(1 - u),
+     # with Uniform(0,1) samples clamped for numerical stability.
+     u = torch.rand_like(x).clamp_(1e-6, 1 - 1e-6)
+     return u.log() - (1 - u).log()
+
+
+ # -----------------------------------------------------------------------------
+ # Base Gate
+ # -----------------------------------------------------------------------------
+
+ class Gate(nn.Module):
+     """Abstract gate over *groups*.
+
+     A gate controls `num_groups` binary decisions, typically expanded by
+     `group_size` when applied to tensors. For example, gating ViT MLP hidden
+     units in groups of 16: `num_groups = hidden // 16`, `group_size = 16`.
+
+     Subclasses may override `sample_mask` for custom relaxations.
+     """
+
+     def __init__(
+         self,
+         num_groups: int,
+         *,
+         group_size: int = 1,
+         tau: float = 1.5,
+         init_logit: float = 3.0,
+         hard_during_eval: bool = True,
+     ) -> None:
+         super().__init__()
+         assert num_groups > 0 and group_size > 0
+         self.num_groups = int(num_groups)
+         self.group_size = int(group_size)
+         self.tau = float(tau)
+         self.hard_during_eval = bool(hard_during_eval)
+         self.logits = nn.Parameter(torch.full((self.num_groups,), float(init_logit)))
+
+     # ----- probabilities & stats ------------------------------------------------
+     def probs(self) -> torch.Tensor:
+         """Return per-group keep probabilities (sigmoid(logit / tau))."""
+         # Using /tau here makes `tau` affect both train and eval statistics
+         return torch.sigmoid(self.logits / self.tau)
+
+     def expected_kept(self) -> torch.Tensor:
+         """Expected *elements* kept (groups × group_size)."""
+         return self.probs().sum() * _as_like(self.logits, self.group_size)
+
+     # ----- masks ----------------------------------------------------------------
+     def _hard_mask(self) -> torch.Tensor:
+         m = (self.logits > 0).to(self.logits.dtype)
+         return m.repeat_interleave(self.group_size)
+
+     def _soft_st_mask(self) -> torch.Tensor:
+         # Straight-through relaxed Bernoulli via Gumbel-sigmoid
+         s = _gumbel_like(self.logits)
+         y = torch.sigmoid((self.logits + s) / self.tau)
+         y_hard = (y > 0.5).to(y.dtype)
+         m = (y_hard - y).detach() + y
+         return m.repeat_interleave(self.group_size)
+
+     def mask(self, training: Optional[bool] = None) -> torch.Tensor:
+         """Return a 1D mask of length `num_groups * group_size`.
+
+         - Training: straight-through relaxed mask
+         - Eval: hard (thresholded) mask if `hard_during_eval`, else probs expanded
+         """
+         if training is None:
+             training = self.training
+         if training:
+             return self._soft_st_mask()
+         if self.hard_during_eval:
+             return self._hard_mask()
+         p = self.probs()
+         return p.repeat_interleave(self.group_size)
+
+     # ----- export helpers -------------------------------------------------------
+     @torch.no_grad()
+     def topk_indices(self, k: int) -> torch.Tensor:
+         k = int(max(1, min(k, self.num_groups)))
+         return torch.topk(self.logits, k, largest=True).indices.sort().values
+
+     @torch.no_grad()
+     def threshold_count(self) -> int:
+         # Rounds to the nearest integer expectation, then clamps
+         p = self.probs()
+         k = int(torch.round(p.sum()).item())
+         return max(1, min(k, self.num_groups))
+
+
+ # -----------------------------------------------------------------------------
+ # Concrete gates
+ # -----------------------------------------------------------------------------
+
+ class HeadGate(Gate):
+     """Per-head gate. Often used with attention where group_size=head_dim."""
+
+     def __init__(self, num_heads: int, *, head_dim: int = 1, **kw):
+         super().__init__(num_groups=num_heads, group_size=head_dim, **kw)
+
+
+ class GroupGate(Gate):
+     """Generic group gate (e.g., MLP hidden grouped by `group_size`)."""
+
+     pass
+
+
+ class LayerGate(Gate):
+     """One bit per layer (group_size=1)."""
+
+     def __init__(self, num_layers: int, **kw):
+         super().__init__(num_groups=num_layers, group_size=1, **kw)
+
+ # -----------------------------------------------------------------------------
154
+ # Penalties / Regularizers
155
+ # -----------------------------------------------------------------------------
156
+
157
+ @dataclass
158
+ class PenaltyWeights:
159
+ """Scalars to blend regularization terms.
160
+
161
+ Attributes
162
+ ----------
163
+ l0 : float
164
+ Weight for the L0-like sparsity term (sum of keep probs).
165
+ keep_floor_ratio : float
166
+ Soft constraint: expected kept groups >= floor_ratio * groups.
167
+ bimodality : float
168
+ Encourages probabilities away from 0.5.
169
+ """
170
+
171
+ l0: float = 0.0
172
+ keep_floor_ratio: float = 0.0
173
+ bimodality: float = 0.0
174
+
175
+
176
+ def iter_gates(module: nn.Module) -> Iterable[Gate]:
177
+ for m in module.modules():
178
+ if isinstance(m, Gate):
179
+ yield m
180
+ else:
181
+ # Duck-typing compatibility: any module with `.logits` and `.tau`
182
+ if hasattr(m, "logits") and hasattr(m, "tau"):
183
+ logits = getattr(m, "logits")
184
+ if isinstance(logits, torch.Tensor) and logits.dim() == 1:
185
+ # Wrap view: expose basic API via adapter shim
186
+ g = _TensorBackedGateShim(m)
187
+ yield g
188
+
189
+
190
+ class _TensorBackedGateShim:
191
+ """Lightweight adapter exposing .logits, .tau, .group_size, .num_groups.
192
+
193
+ It is intentionally NOT an nn.Module and NOT a Gate subclass to avoid
194
+ ctor/signature constraints and registration side-effects. It's only used
195
+ by projection/regularization utilities that read/update .logits.
196
+ """
197
+ __slots__ = ("host", "logits", "tau", "group_size", "num_groups")
198
+
199
+ def __init__(self, host):
200
+ self.host = host
201
+ # logits must be a Tensor/Parameter on the host
202
+ self.logits = getattr(host, "logits")
203
+ # default tau=1.5 if not present
204
+ self.tau = float(getattr(host, "tau", 1.5))
205
+ # support either group_size or group attribute names
206
+ self.group_size = int(getattr(host, "group_size", getattr(host, "group", 1)))
207
+ self.num_groups = int(self.logits.numel())
208
+
209
+ def forward(self, *args, **kwargs): # pragma: no cover - shim is not used as a layer
210
+ raise RuntimeError("Gate shim is not a callable layer")
211
+
212
+
213
+ def l0_like_sparsity(module: nn.Module) -> torch.Tensor:
214
+ """Sum of keep probabilities across all gates (acts like L0/L1)."""
215
+ val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
216
+ out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
217
+ for g in iter_gates(module):
218
+ out = out + g.probs().sum()
219
+ return out
220
+
221
+
222
+ def keep_floor(module: nn.Module, floor_ratio: float) -> torch.Tensor:
223
+ """Soft penalty if expected-kept falls below a fraction per gate.
224
+
225
+ For each gate with G groups, penalize relu(floor*G - sum(p)).
226
+ """
227
+ if floor_ratio <= 0:
228
+ return torch.tensor(0.0, device=next(module.parameters(), torch.tensor(0.0)).device)
229
+ floor_ratio = float(floor_ratio)
230
+ val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
231
+ out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
232
+ for g in iter_gates(module):
233
+ G = _as_like(val, g.num_groups)
234
+ floor_groups = _as_like(val, max(1.0, floor_ratio * float(g.num_groups)))
235
+ out = out + F.relu(floor_groups - g.probs().sum())
236
+ return out
237
+
238
+
239
+ def bimodality(module: nn.Module) -> torch.Tensor:
240
+ """Sum over p*(1-p) to push probs away from 0.5 (minimum at 0 or 1)."""
241
+ val = _as_like(next(module.parameters(), torch.tensor(0.0, device="cpu")), 0.0)
242
+ out = torch.as_tensor(0.0, device=val.device, dtype=val.dtype)
243
+ for g in iter_gates(module):
244
+ p = g.probs()
245
+ out = out + (p * (1.0 - p)).sum()
246
+ return out
247
+
248
+
249
+ def combined_penalty(
250
+ module: nn.Module,
251
+ weights: PenaltyWeights,
252
+ ) -> torch.Tensor:
253
+ out = torch.tensor(0.0, device=next(module.parameters(), torch.tensor(0.0)).device)
254
+ if weights.l0:
255
+ out = out + weights.l0 * l0_like_sparsity(module)
256
+ if weights.keep_floor_ratio:
257
+ out = out + keep_floor(module, weights.keep_floor_ratio)
258
+ if weights.bimodality:
259
+ out = out + weights.bimodality * bimodality(module)
260
+ return out
261
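
The penalty blend is plain arithmetic on keep probabilities; a dependency-free sketch for a single two-group gate (values are illustrative, not taken from the repo):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

# One gate, two groups: logits [3.0, -3.0] with tau=1.5 give keep
# probabilities sigmoid(2) ~= 0.881 and sigmoid(-2) ~= 0.119.
logits, tau = [3.0, -3.0], 1.5
p = [sigmoid(l / tau) for l in logits]

l0 = sum(p)                               # l0_like_sparsity: expected kept groups
floor_pen = max(0.0, 0.75 * len(p) - l0)  # keep_floor with floor_ratio=0.75
bimod = sum(q * (1.0 - q) for q in p)     # bimodality: zero once probs saturate

# combined_penalty with weights l0=1e-3, keep_floor_ratio on, bimodality=1e-2
total = 1e-3 * l0 + floor_pen + 1e-2 * bimod
```

Here the L0 term says one group is expected to survive, so a 75% keep floor over two groups is violated by 0.5 and pushes logits back up.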
+
+
+# -----------------------------------------------------------------------------
+# Constraint projection
+# -----------------------------------------------------------------------------
+
+@dataclass
+class Constraints:
+    """High-level feasibility constraints.
+
+    * min_keep_ratio: per-gate minimum fraction of groups to keep (soft cap via
+      projection onto [min_k, G]).
+    * min_groups: absolute lower bound per gate (after rounding).
+    * max_groups_drop: optional ceiling on groups dropped per gate.
+    """
+
+    min_keep_ratio: float = 0.0
+    min_groups: int = 1
+    max_groups_drop: Optional[int] = None
+
+
+@torch.no_grad()
+def project_gates_into_constraints(module: nn.Module, cons: Constraints) -> None:
+    """Project gate logits so that expected kept groups respect constraints.
+
+    We shift all logits by an additive bias until the expected sum of keep
+    probabilities satisfies the violated lower/upper bound. This light-touch
+    projection keeps the relative ordering of logits intact.
+    """
+    for g in iter_gates(module):
+        p = torch.sigmoid(g.logits / g.tau)
+        G = p.numel()
+        # Lower bound
+        min_keep = max(cons.min_groups, int(cons.min_keep_ratio * G))
+        if p.sum().item() < min_keep:
+            # Additive bias to increase sum(p), doubled until satisfied
+            bias = torch.tensor(2.0, device=p.device, dtype=p.dtype)
+            for _ in range(6):
+                p = torch.sigmoid((g.logits + bias) / g.tau)
+                if p.sum().item() >= min_keep:
+                    break
+                bias = bias * 2
+            g.logits.add_(bias)
+        # Optional upper bound on drops
+        if cons.max_groups_drop is not None:
+            max_drop = int(cons.max_groups_drop)
+            max_keep = max(1, G - max_drop)
+            if p.sum().item() > max_keep:
+                bias = torch.tensor(-2.0, device=p.device, dtype=p.dtype)
+                for _ in range(6):
+                    p = torch.sigmoid((g.logits + bias) / g.tau)
+                    if p.sum().item() <= max_keep:
+                        break
+                    bias = bias * 2
+                g.logits.add_(bias)
+
+
+# -----------------------------------------------------------------------------
+# Export helpers (indices from gates)
+# -----------------------------------------------------------------------------
+
+@torch.no_grad()
+def topk_group_indices(g: Gate, keep_k: Optional[int] = None) -> torch.Tensor:
+    """Return sorted group indices to KEEP based on logits/probs.
+
+    If `keep_k` is None, use the nearest integer of the expected kept count.
+    """
+    if keep_k is None:
+        keep_k = g.threshold_count()
+    idx = torch.topk(g.logits, int(keep_k), largest=True).indices
+    return idx.sort().values
+
+
+@torch.no_grad()
+def expand_group_indices(idx: torch.Tensor, group_size: int) -> torch.Tensor:
+    """Expand group indices into element indices by `group_size` blocks."""
+    if group_size == 1:
+        return idx.clone()
+    starts = idx * group_size
+    parts = [torch.arange(s, s + group_size, device=idx.device) for s in starts]
+    return torch.cat(parts, dim=0).long()
+
+
+# -----------------------------------------------------------------------------
+# Parameter utilities
+# -----------------------------------------------------------------------------
+
+def collect_gate_params(module: nn.Module) -> List[nn.Parameter]:
+    return [g.logits for g in iter_gates(module) if isinstance(g.logits, torch.Tensor)]
+
+
+def collect_param_groups(
+    module: nn.Module,
+    *,
+    lr_gate: float = 1e-2,
+    lr_linear: float = 1e-4,
+    lr_affine: float = 3e-4,
+    wd_linear: float = 1e-4,
+) -> List[dict]:
+    """Convenience grouping matching common training setups.
+
+    Group 0: gate logits (no weight decay)
+    Group 1: linear weights (with weight decay)
+    Group 2: linear biases (no decay)
+    Group 3: norm affine params (no decay)
+    """
+    gates, ln_affine, linear_w, linear_b = [], [], [], []
+    for n, p in module.named_parameters():
+        if not p.requires_grad:
+            continue
+        if n.endswith((".logits", ".head_gate", ".channel_gate")):
+            gates.append(p)
+            continue
+        is_linear_path = (".weight" in n or ".bias" in n) and (
+            ".dense" in n or ".query" in n or ".key" in n or ".value" in n or ".proj" in n
+        )
+        if n.endswith(".weight") and is_linear_path:
+            linear_w.append(p)
+        elif n.endswith(".bias") and is_linear_path:
+            linear_b.append(p)
+        elif "layernorm" in n.lower() or "layer_norm" in n.lower():
+            ln_affine.append(p)
+    return [
+        {"params": gates, "lr": lr_gate, "weight_decay": 0.0},
+        {"params": linear_w, "lr": lr_linear, "weight_decay": wd_linear},
+        {"params": linear_b, "lr": lr_linear, "weight_decay": 0.0},
+        {"params": ln_affine, "lr": lr_affine, "weight_decay": 0.0},
+    ]
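
As a sanity check on the export path, the keep-count and index-expansion arithmetic used by `threshold_count` and `expand_group_indices` can be reproduced without torch (a minimal sketch; numbers are illustrative):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

# Mirror of Gate.threshold_count(): round the expected kept-group count.
def threshold_count(logits, tau=1.5):
    expected = sum(sigmoid(l / tau) for l in logits)
    return max(1, min(round(expected), len(logits)))

# Mirror of expand_group_indices(): group ids -> element ids in contiguous blocks.
def expand_group_indices(idx, group_size):
    return [g * group_size + j for g in idx for j in range(group_size)]

logits = [3.0, -2.0, 0.5, -4.0]   # keep probs ~= 0.88, 0.21, 0.58, 0.07
k = threshold_count(logits)        # expected kept ~= 1.74 -> rounds to 2
elems = expand_group_indices([0, 2], group_size=4)
```

Keeping groups 0 and 2 with `group_size=4` selects elements 0-3 and 8-11, which is exactly the index set an exported (hard-pruned) layer would slice out.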
core/profiler.py ADDED
@@ -0,0 +1,236 @@
+"""Simple, robust latency measurement utilities.
+
+This module provides GPU-friendly profilers with warmup, multiple repeats,
+median/percentile reporting, and optional outlier rejection via MAD
+(median absolute deviation).
+
+Design goals:
+- Family-agnostic: take a callable `forward(model, x)` or rely on HF `.forward`
+- Deterministic when desired; avoids autograd by default
+- Works with CUDA or CPU; uses `torch.cuda.Event` for accurate GPU timing
+
+Key APIs:
+- measure_latency_ms(model, input_shape | input_tensor, ...)
+- profile(model, sample, settings) -> {mean, p50, p90, p95, p99}
+- LatencyProfiler(settings).measure(...)
+- profile_many_shapes(model, shapes, settings)
+"""
+from __future__ import annotations
+
+import contextlib
+import math
+import time
+from dataclasses import dataclass
+from statistics import median
+from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+
+
+# -----------------------------------------------------------------------------
+# Settings
+# -----------------------------------------------------------------------------
+
+@dataclass
+class ProfileSettings:
+    warmup: int = 10
+    iters: int = 50
+    percentile: Sequence[int] = (50, 90, 95, 99)
+    sync_each_iter: bool = True
+    use_inference_mode: bool = True
+    cuda_graph: bool = False  # advanced users can enable with static shapes
+    reject_outliers_mad: float = 0.0  # e.g., 3.5 to drop extreme spikes
+    cudnn_benchmark: bool = True
+    deterministic: bool = False  # sets cudnn.deterministic
+
+
+# -----------------------------------------------------------------------------
+# Context helpers
+# -----------------------------------------------------------------------------
+
+@contextlib.contextmanager
+def _torch_backend_ctx(settings: ProfileSettings):
+    prev_bench = torch.backends.cudnn.benchmark
+    prev_det = torch.backends.cudnn.deterministic
+    try:
+        torch.backends.cudnn.benchmark = bool(settings.cudnn_benchmark)
+        torch.backends.cudnn.deterministic = bool(settings.deterministic)
+        yield
+    finally:
+        torch.backends.cudnn.benchmark = prev_bench
+        torch.backends.cudnn.deterministic = prev_det
+
+
+def _percentiles(sorted_vals: Sequence[float], qs: Sequence[int]) -> Dict[int, float]:
+    """Linear-interpolation percentiles over an already-sorted sequence."""
+    n = len(sorted_vals)
+    if n == 0:
+        return {q: float("nan") for q in qs}
+    out = {}
+    for q in qs:
+        if n == 1:
+            out[q] = sorted_vals[0]
+            continue
+        k = (q / 100.0) * (n - 1)
+        f = math.floor(k)
+        c = min(n - 1, f + 1)
+        if f == c:
+            out[q] = sorted_vals[int(k)]
+        else:
+            d0 = sorted_vals[f] * (c - k)
+            d1 = sorted_vals[c] * (k - f)
+            out[q] = d0 + d1
+    return out
+
+
+def _apply_mad_filter(vals: Sequence[float], thresh: float) -> Sequence[float]:
+    if thresh <= 0 or len(vals) < 5:
+        return vals
+    med = median(vals)
+    dev = [abs(v - med) for v in vals]
+    mad = median(dev) or 1e-12
+    keep = [v for v, d in zip(vals, dev) if (d / mad) <= thresh]
+    return keep if keep else vals
+
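
For intuition, the MAD filter can be exercised standalone; a pure-Python replica of `_apply_mad_filter` with an illustrative latency spike:

```python
from statistics import median

# Replica of _apply_mad_filter: drop samples whose absolute deviation from the
# median exceeds `thresh` median-absolute-deviations; never drop everything.
def mad_filter(vals, thresh):
    if thresh <= 0 or len(vals) < 5:
        return vals
    med = median(vals)
    dev = [abs(v - med) for v in vals]
    mad = median(dev) or 1e-12
    keep = [v for v, d in zip(vals, dev) if (d / mad) <= thresh]
    return keep if keep else vals

times = [10.0, 10.5, 11.0, 10.2, 10.8, 95.0]  # one GC/clock-spike iteration
filtered = mad_filter(times, thresh=3.5)       # the 95 ms spike is rejected
```

With fewer than five samples the filter is a no-op, which keeps tiny measurement runs honest rather than silently trimmed.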
+
+# -----------------------------------------------------------------------------
+# Core measurement
+# -----------------------------------------------------------------------------
+
+@torch.inference_mode()
+def measure_latency_ms(
+    model: nn.Module,
+    sample: torch.Tensor | Tuple[int, ...],
+    *,
+    settings: Optional[ProfileSettings] = None,
+    device: str = "cuda",
+    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
+) -> Tuple[float, float]:
+    """Return (mean_ms, p95_ms) over `iters` measurements.
+
+    If `sample` is a shape tuple, a random tensor is created on-device.
+    The default forward tries `model(pixel_values=x)` and falls back to `model(x)`.
+    """
+    cfg = settings or ProfileSettings()
+
+    with _torch_backend_ctx(cfg):
+        m = model.to(device).eval()
+        if isinstance(sample, torch.Tensor):
+            x = sample.to(device)
+        else:
+            x = torch.randn(*sample, device=device)
+
+        # Default forward: HF vision kwarg first, positional call as fallback
+        def _fwd(mod, inp):
+            try:
+                return mod(pixel_values=inp)
+            except TypeError:
+                return mod(inp)
+
+        fn = forward_fn or _fwd
+        use_cuda = torch.cuda.is_available() and device.startswith("cuda")
+
+        # Warmup
+        for _ in range(cfg.warmup):
+            _ = fn(m, x)
+        if use_cuda:
+            torch.cuda.synchronize()
+
+        times: list[float] = []
+        if use_cuda:
+            for _ in range(cfg.iters):
+                t0 = torch.cuda.Event(enable_timing=True)
+                t1 = torch.cuda.Event(enable_timing=True)
+                t0.record()
+                _ = fn(m, x)
+                t1.record()
+                if cfg.sync_each_iter:
+                    torch.cuda.synchronize()
+                times.append(t0.elapsed_time(t1))  # milliseconds
+        else:
+            for _ in range(cfg.iters):
+                t0 = time.perf_counter()
+                _ = fn(m, x)
+                t1 = time.perf_counter()
+                times.append((t1 - t0) * 1000.0)
+
+    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
+    mean_ms = sum(times) / max(1, len(times))
+    p = _percentiles(times, cfg.percentile)
+    p95 = p.get(95, times[int(0.95 * (len(times) - 1))] if times else float("nan"))
+    return mean_ms, p95
+
+
+# Higher-level wrapper returning multiple percentiles
+@torch.inference_mode()
+def profile(
+    model: nn.Module,
+    sample: torch.Tensor | Tuple[int, ...],
+    *,
+    settings: Optional[ProfileSettings] = None,
+    device: str = "cuda",
+    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
+) -> Dict[str, float]:
+    cfg = settings or ProfileSettings()
+
+    with _torch_backend_ctx(cfg):
+        m = model.to(device).eval()
+        if isinstance(sample, torch.Tensor):
+            x = sample.to(device)
+        else:
+            x = torch.randn(*sample, device=device)
+
+        def _fwd(mod, inp):
+            try:
+                return mod(pixel_values=inp)
+            except TypeError:
+                return mod(inp)
+
+        fn = forward_fn or _fwd
+        use_cuda = torch.cuda.is_available() and device.startswith("cuda")
+
+        for _ in range(cfg.warmup):
+            _ = fn(m, x)
+        if use_cuda:
+            torch.cuda.synchronize()
+
+        times = []
+        if use_cuda:
+            for _ in range(cfg.iters):
+                t0 = torch.cuda.Event(enable_timing=True)
+                t1 = torch.cuda.Event(enable_timing=True)
+                t0.record()
+                _ = fn(m, x)
+                t1.record()
+                if cfg.sync_each_iter:
+                    torch.cuda.synchronize()
+                times.append(t0.elapsed_time(t1))
+        else:
+            for _ in range(cfg.iters):
+                t0 = time.perf_counter()
+                _ = fn(m, x)
+                t1 = time.perf_counter()
+                times.append((t1 - t0) * 1000.0)
+
+    times = sorted(_apply_mad_filter(times, cfg.reject_outliers_mad))
+    percs = _percentiles(times, cfg.percentile)
+    out = {"mean": sum(times) / max(1, len(times))}
+    out.update({f"p{q}": v for q, v in percs.items()})
+    return out
+
+
+class LatencyProfiler:
+    """Reusable profiler with fixed settings."""
+
+    def __init__(self, settings: Optional[ProfileSettings] = None, device: str = "cuda"):
+        self.settings = settings or ProfileSettings()
+        self.device = device
+
+    def measure(
+        self,
+        model: nn.Module,
+        sample: torch.Tensor | Tuple[int, ...],
+        *,
+        forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
+    ) -> Tuple[float, float]:
+        return measure_latency_ms(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)
+
+    def profile(
+        self,
+        model: nn.Module,
+        sample: torch.Tensor | Tuple[int, ...],
+        *,
+        forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
+    ) -> Dict[str, float]:
+        return profile(model, sample, settings=self.settings, device=self.device, forward_fn=forward_fn)
+
+
+@torch.inference_mode()
+def profile_many_shapes(
+    model: nn.Module,
+    shapes: Iterable[Tuple[int, ...]],
+    *,
+    settings: Optional[ProfileSettings] = None,
+    device: str = "cuda",
+    forward_fn: Optional[Callable[[nn.Module, torch.Tensor], torch.Tensor]] = None,
+) -> Dict[Tuple[int, ...], Dict[str, float]]:
+    out: Dict[Tuple[int, ...], Dict[str, float]] = {}
+    for shp in shapes:
+        out[tuple(shp)] = profile(model, shp, settings=settings, device=device, forward_fn=forward_fn)
+    return out
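
The interpolation convention used by `_percentiles` is easy to verify by hand; a dependency-free replica with illustrative timings (same linear-interpolation rule between order statistics):

```python
import math

# Replica of _percentiles: linear interpolation over an already-sorted sample,
# matching the common "linear" percentile convention.
def percentiles(sorted_vals, qs):
    n = len(sorted_vals)
    if n == 0:
        return {q: float("nan") for q in qs}
    out = {}
    for q in qs:
        if n == 1:
            out[q] = sorted_vals[0]
            continue
        k = (q / 100.0) * (n - 1)       # fractional rank
        f = math.floor(k)
        c = min(n - 1, f + 1)
        if f == c:
            out[q] = sorted_vals[int(k)]
        else:
            out[q] = sorted_vals[f] * (c - k) + sorted_vals[c] * (k - f)
    return out

times = [10.0, 11.0, 12.0, 13.0, 100.0]  # sorted; one straggler iteration
p = percentiles(times, (50, 90))
# p50 is the median (12.0); p90 has rank k = 0.9 * 4 = 3.6, i.e.
# 0.4 * 13.0 + 0.6 * 100.0 = 65.2
```

This is why the profiler reports p95 alongside the mean: a single straggler barely moves the median but dominates the mean and the tail percentiles.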
core/proxy_cost.py ADDED
@@ -0,0 +1,771 @@
1
+ # core/proxy_cost.py
2
+ """Latency proxy models and a tiny LUT for hardware correction.
3
+
4
+ This file defines a family-agnostic interface plus concrete proxies (ViT, ResNet, LLM)
5
+ that estimate latency from *soft structure* (gates) and input size. All proxies accept
6
+ the trainer's `(model, batch) -> ms` call signature directly (batches may be dict/tuple/tensor).
7
+ A small, in-memory LUT can be populated from real measurements during training to correct
8
+ analytic estimates.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Any, Dict, Optional, Tuple, Union, List
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from .gates import iter_gates, _as_like # _as_like is used by ViT proxy
19
+
20
+
21
+ # -----------------------------------------------------------------------------
22
+ # Small batch helpers (shared)
23
+ # -----------------------------------------------------------------------------
24
+
25
+ TensorOrBatch = Union[torch.Tensor, Tuple, List, Dict[str, Any]]
26
+
27
+ def _first_tensor(batch: TensorOrBatch) -> torch.Tensor:
28
+ """Find the first tensor inside a batch-like structure."""
29
+ if torch.is_tensor(batch):
30
+ return batch
31
+ if isinstance(batch, dict):
32
+ # Common keys across tasks
33
+ for k in ("input_ids", "pixel_values", "images", "x"):
34
+ v = batch.get(k, None)
35
+ if torch.is_tensor(v):
36
+ return v
37
+ # fallback: first tensor value
38
+ for v in batch.values():
39
+ if torch.is_tensor(v):
40
+ return v
41
+ raise ValueError("Batch dict has no tensor field I recognize.")
42
+ if isinstance(batch, (list, tuple)):
43
+ for v in batch:
44
+ if torch.is_tensor(v):
45
+ return v
46
+ # torchvision pattern: ([aug1, aug2], label)
47
+ if len(batch) and isinstance(batch[0], (list, tuple)):
48
+ for v in batch[0]:
49
+ if torch.is_tensor(v):
50
+ return v
51
+ raise ValueError("Cannot find a tensor in the provided batch.")
52
+
53
+ def _ids_from_batch(batch: TensorOrBatch) -> torch.Tensor:
54
+ """Return a 2D [B,S] tensor representing token ids for LLMs."""
55
+ if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
56
+ return batch["input_ids"]
57
+ t = _first_tensor(batch)
58
+ if t.dim() >= 2:
59
+ return t
60
+ raise ValueError("Cannot infer [B,S] from batch; need 'input_ids' or a 2D tensor.")
61
+
62
+ def _nchw_from_batch(batch: TensorOrBatch) -> Tuple[int, int, int, int]:
63
+ """Return NCHW shape from a batch or an explicit (N,C,H,W) tuple/list/tensor."""
64
+ if isinstance(batch, (tuple, list)) and len(batch) == 4 and all(isinstance(x, int) for x in batch):
65
+ return tuple(batch) # type: ignore[return-value]
66
+ x = _first_tensor(batch)
67
+ if x.dim() != 4:
68
+ raise ValueError(f"Expected NCHW tensor for CNN proxy; got tensor with shape {tuple(x.shape)}")
69
+ N, C, H, W = map(int, x.shape)
70
+ return (N, C, H, W)
71
+
72
+
73
+ # -----------------------------------------------------------------------------
74
+ # Base proxy + LUT
75
+ # -----------------------------------------------------------------------------
76
+
77
+ class LatencyProxy(nn.Module):
78
+ """Abstract proxy producing a scalar latency-like value (ms).
79
+
80
+ Subclasses implement `_predict_raw` and may define `_signature` keys used by
81
+ a LUT to refine estimates with real measurements. Proxies accept either a
82
+ batch-like object (dict/tuple/tensor) or an explicit shape tuple.
83
+ """
84
+
85
+ def __init__(self):
86
+ super().__init__()
87
+
88
+ def predict(
89
+ self,
90
+ model: nn.Module,
91
+ sample: TensorOrBatch,
92
+ *,
93
+ policy=None,
94
+ step: Optional[int] = None,
95
+ **kwargs,
96
+ ) -> torch.Tensor:
97
+ """Batch-friendly entry point. `sample` may be a batch or explicit shape."""
98
+ return self._predict_raw(model, sample, policy=policy, step=step, **kwargs)
99
+
100
+ def _predict_raw(
101
+ self,
102
+ model: nn.Module,
103
+ sample: TensorOrBatch,
104
+ *,
105
+ policy=None,
106
+ step: Optional[int] = None,
107
+ **kwargs,
108
+ ) -> torch.Tensor: # pragma: no cover - abstract
109
+ raise NotImplementedError
110
+
111
+ def signature(
112
+ self,
113
+ model: nn.Module,
114
+ sample: TensorOrBatch,
115
+ *,
116
+ policy=None,
117
+ step: Optional[int] = None
118
+ ) -> Tuple:
119
+ """Return a hashable signature describing the workload shape."""
120
+ if torch.is_tensor(sample):
121
+ shp = tuple(sample.shape)
122
+ elif isinstance(sample, (tuple, list)):
123
+ shp = tuple(sample)
124
+ elif isinstance(sample, dict):
125
+ # summarize the shapes of any tensors in dict
126
+ shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
127
+ else:
128
+ shp = (str(type(sample)),)
129
+ return (type(self).__name__, shp)
130
+
131
+
132
+ class LatencyLUT:
133
+ """Tiny LUT mapping `(signature) -> measured_ms`."""
134
+
135
+ def __init__(self):
136
+ self._table: Dict[Tuple[Any, ...], float] = {}
137
+
138
+ def update(self, signature: Tuple[Any, ...], measured_ms: float) -> None:
139
+ self._table[signature] = float(measured_ms)
140
+
141
+ def get(self, signature: Tuple[Any, ...]) -> Optional[float]:
142
+ return self._table.get(signature)
143
+
144
+ def blend(self, raw_estimate: torch.Tensor, signature: Tuple[Any, ...]) -> torch.Tensor:
145
+ val = self.get(signature)
146
+ if val is None:
147
+ return raw_estimate
148
+ # Put on same device/dtype as raw_estimate
149
+ return _as_like(raw_estimate, val)
150
+
151
+
152
+ # -----------------------------------------------------------------------------
153
+ # ViT proxy (analytic + gates), with scale and per-term weights
154
+ # -----------------------------------------------------------------------------
155
+
156
+ @dataclass
157
+ class ViTProxyConfig:
158
+ scale_ms: float = 1.0
159
+ alpha_qkv: float = 1.0
160
+ alpha_scores: float = 1.0
161
+ alpha_out: float = 1.0
162
+ alpha_mlp: float = 1.0
163
+
164
+ def _vit_layers(m):
165
+ enc = getattr(m, "encoder", None)
166
+ if enc is not None and hasattr(enc, "layer"):
167
+ return enc.layer
168
+ vit = getattr(m, "vit", None)
169
+ if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
170
+ return vit.encoder.layer
171
+ raise TypeError("Expected a HF ViT with *.encoder.layer (ViTModel or ViTForImageClassification).")
172
+
173
+
174
+ class ViTLatencyProxy(LatencyProxy):
175
+ """Latency proxy for ViT models. Accepts batches or (N,C,H,W) tuples."""
176
+
177
+ def __init__(self, cfg: Optional[ViTProxyConfig] = None, lut: Optional[LatencyLUT] = None):
178
+ super().__init__()
179
+ self.cfg = cfg or ViTProxyConfig()
180
+ self.lut = lut or LatencyLUT()
181
+
182
+ # ---- helpers -------------------------------------------------------------
183
+ @staticmethod
184
+ def _input_spec(sample: TensorOrBatch) -> Tuple[int, int, int]:
185
+ if isinstance(sample, (tuple, list)) and len(sample) == 4 and all(isinstance(x, int) for x in sample):
186
+ B, C, H, W = sample
187
+ return int(B), int(H), int(W)
188
+ x = _first_tensor(sample)
189
+ if x.dim() != 4:
190
+ raise ValueError("ViTLatencyProxy expects a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
191
+ B, C, H, W = x.shape
192
+ return int(B), int(H), int(W)
193
+
194
+ @staticmethod
195
+ def _patch_hw(cfg) -> Tuple[int, int]:
196
+ patch = getattr(cfg, "patch_size", 16)
197
+ if isinstance(patch, (tuple, list)):
198
+ return int(patch[0]), int(patch[1])
199
+ return int(patch), int(patch)
200
+
201
+ @staticmethod
202
+ def _soft_heads_from_block(blk) -> Optional[torch.Tensor]:
203
+ # Prefer a nested attention with kept_heads_soft()
204
+ attn = getattr(getattr(blk, "attention", None), "attention", None)
205
+ if attn is not None and hasattr(attn, "kept_heads_soft"):
206
+ return attn.kept_heads_soft()
207
+ return None
208
+
209
+ @staticmethod
210
+ def _find_ffn_gate(blk):
211
+ inter = getattr(blk, "intermediate", None)
212
+ if inter is None:
213
+ return None
214
+ # Common attribute names
215
+ for nm in ("neuron_gate", "gate", "ffn_gate"):
216
+ g = getattr(inter, nm, None)
217
+ if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
218
+ return g
219
+ # Last resort: scan children
220
+ for m in blk.modules():
221
+ if hasattr(m, "logits") and hasattr(m, "tau"):
222
+ return m
223
+ return None
224
+
225
+ # ---- proxy ---------------------------------------------------------------
226
+ def _predict_raw(
227
+ self,
228
+ model: nn.Module,
229
+ sample: TensorOrBatch,
230
+ *,
231
+ policy=None,
232
+ step: Optional[int] = None
233
+ ) -> torch.Tensor:
234
+ anchor = next((p for p in model.parameters()), torch.tensor(0.0))
235
+
236
+ B, H_img, W_img = self._input_spec(sample)
237
+ cfg = getattr(model, "config", None)
238
+ if cfg is None:
239
+ raise ValueError("Model must expose a HuggingFace-like .config for ViT proxy")
240
+ ph, pw = self._patch_hw(cfg)
241
+
242
+ S = _as_like(anchor, 1 + (H_img // ph) * (W_img // pw))
243
+ D = _as_like(anchor, int(getattr(cfg, "hidden_size", 768)))
244
+ Hh = _as_like(anchor, int(getattr(cfg, "num_attention_heads", 12)))
245
+ Dh = D // Hh
246
+
247
+ warm = False
248
+ if policy is not None and step is not None:
249
+ warm = (step < int(getattr(policy, "warmup_steps", 0)))
250
+
251
+ total_qkv = _as_like(anchor, 0.0)
252
+ total_scores = _as_like(anchor, 0.0)
253
+ total_out = _as_like(anchor, 0.0)
254
+ total_mlp = _as_like(anchor, 0.0)
255
+
256
+ default_hidden = _as_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))
257
+
258
+ layers = _vit_layers(model)
259
+ for blk in layers:
260
+ heads_soft = Hh if warm else (self._soft_heads_from_block(blk) or Hh)
261
+
262
+ # FFN hidden expectation
263
+ if warm:
264
+ hidden_soft = default_hidden
265
+ else:
266
+ g = self._find_ffn_gate(blk)
267
+ if g is None:
268
+ hidden_soft = default_hidden
269
+ else:
270
+ probs = torch.sigmoid(g.logits / g.tau)
271
+ group = int(getattr(g, "group", getattr(g, "group_size", 16)))
272
+ hidden_soft = probs.sum() * _as_like(anchor, group)
273
+
274
+ D_kept = heads_soft * Dh
275
+
276
+ total_qkv += 3 * S * D * D_kept
277
+ total_scores += (S * S) * heads_soft * Dh
278
+ total_out += S * D_kept * D
279
+ total_mlp += 2 * S * D * hidden_soft
280
+
281
+ raw = (
282
+ self.cfg.alpha_qkv * total_qkv
283
+ + self.cfg.alpha_scores * total_scores
284
+ + self.cfg.alpha_out * total_out
285
+ + self.cfg.alpha_mlp * total_mlp
286
+ )
287
+ raw_ms = raw * _as_like(anchor, float(self.cfg.scale_ms))
288
+
289
+ # optional LUT correction
290
+ sig = self.signature(model, sample, policy=policy, step=step)
291
+ return self.lut.blend(raw_ms, sig)
292
+
293
+ # A reasonable default signature for ViT workloads
294
+ def signature(self, model: nn.Module, sample, *, policy=None, step: Optional[int] = None) -> Tuple:
295
+ if torch.is_tensor(sample):
296
+ shp = tuple(sample.shape)
297
+ elif isinstance(sample, (tuple, list)):
298
+ shp = tuple(sample)
299
+ elif isinstance(sample, dict):
300
+ shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
301
+ else:
302
+ shp = (str(type(sample)),)
303
+ cfg = getattr(model, "config", None)
304
+ heads = int(getattr(cfg, "num_attention_heads", 12))
305
+ hidden = int(getattr(cfg, "hidden_size", 768))
306
+ inter = int(getattr(cfg, "intermediate_size", 3072))
307
+ return ("ViT", shp, heads, hidden, inter)
308
+
309
+ @torch.no_grad()
310
+ def calibrate(self, model: nn.Module, shape: tuple, measure_fn, *, device: str = "cuda") -> float:
311
+ """Set proxy scale so that keep-all student matches measured ms.
312
+
313
+ `measure_fn(model, shape, device=...)` should return `(mean_ms, p95_ms)`, matching the call below.
314
+ """
315
+
316
+ sample_t = torch.randn(shape, device=device)
319
+ model = model.to(device).eval()
320
+ mean_ms, _ = measure_fn(model, shape, device=device)
321
+ soft_ms = self.predict(model, sample_t).item()
322
+ self.cfg.scale_ms = float(self.cfg.scale_ms * mean_ms / max(soft_ms, 1e-9)) # predict() already applies the old scale
323
+ return self.cfg.scale_ms
324
+
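The rescaling rule behind `calibrate` can be exercised standalone; `DummyProxy` below is a hypothetical stand-in whose `predict` is just `flops_like * scale_ms`, not the class above. Multiplying by the previous scale keeps repeated calibration idempotent:

```python
class DummyProxy:
    """Toy stand-in for the proxy: predict() = flops_like * scale_ms."""
    def __init__(self) -> None:
        self.scale_ms = 1.0

    def predict(self, flops_like: float) -> float:
        return flops_like * self.scale_ms


def calibrate(proxy: DummyProxy, flops_like: float, measured_ms: float) -> float:
    # rescale so predict() reproduces the measured latency; multiplying by the
    # old scale keeps repeated calibration idempotent
    soft_ms = proxy.predict(flops_like)
    proxy.scale_ms = proxy.scale_ms * measured_ms / max(soft_ms, 1e-9)
    return proxy.scale_ms


proxy = DummyProxy()
calibrate(proxy, flops_like=2.0e9, measured_ms=14.0)
print(round(proxy.predict(2.0e9), 6))  # 14.0
```

After calibration, the proxy's soft estimate reproduces the measured keep-all latency for that shape; gated models then predict proportionally smaller values.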
325
+ # ------------------------------ ResNet Proxy ------------------------------
326
+
327
+ @dataclass
328
+ class ResNetProxyConfig:
329
+ scale_ms: float = 1.0
330
+ alpha_conv: float = 1.0 # weight for conv FLOPs term
331
+
332
+
333
+ def _as_const_like_resnet(x_like: torch.Tensor, val):
334
+ return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)
335
+
336
+
337
+ def _find_anchor_param(model: nn.Module) -> torch.Tensor:
338
+ # Prefer any gate-like parameter; otherwise any parameter; else cpu scalar
339
+ for m in model.modules():
340
+ for nm in ("logits", "head_gate"):
341
+ t = getattr(m, nm, None)
342
+ if isinstance(t, torch.Tensor):
343
+ return t
344
+ for p in model.parameters():
345
+ return p
346
+ return torch.tensor(0.0)
347
+
348
+
349
+ def _kept_from_gate(module, anchor: torch.Tensor) -> Optional[torch.Tensor]:
350
+ """Return expected kept channels for a BN gate: probs.sum() * group_size.
351
+ If no gate is found, return None.
352
+ """
353
+ g = None
354
+ for nm in ("gate", "neuron_gate", "channel_gate", "bn_gate"):
355
+ if hasattr(module, nm):
356
+ g = getattr(module, nm)
357
+ break
358
+ if g is None and hasattr(module, "logits") and hasattr(module, "tau"):
359
+ g = module
360
+
361
+ if g is None or not hasattr(g, "logits"):
362
+ return None
363
+ logits = g.logits
364
+ tau = float(getattr(g, "tau", 1.5))
365
+ group = int(getattr(g, "group", getattr(g, "group_size", 1)))
366
+ if group <= 0: group = 1
367
+ probs = torch.sigmoid(logits / tau)
368
+ return probs.sum() * _as_const_like_resnet(anchor, group)
369
+
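`_kept_from_gate` returns a differentiable expectation rather than a hard count. Its arithmetic can be checked in isolation (plain-Python sigmoid, no torch):

```python
import math

def expected_kept(logits, tau: float, group: int) -> float:
    # E[kept channels] = sum(sigmoid(logit / tau)) * group_size,
    # mirroring probs.sum() * group in _kept_from_gate
    probs = [1.0 / (1.0 + math.exp(-l / tau)) for l in logits]
    return sum(probs) * group

# two groups of 8 channels: one strongly kept, one strongly dropped
print(round(expected_kept([4.0, -4.0], tau=1.0, group=8), 3))  # 8.0
```

Undecided gates (logits near 0) contribute roughly half their group, which is what lets the latency penalty push them smoothly toward keep or drop.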
370
+
371
+ class ResNetLatencyProxy(LatencyProxy):
372
+ """Latency proxy for ResNet-like backbones with BN gates.
373
+
374
+ Approximates latency with a FLOPs-style sum over convs, using the *expected*
375
+ kept channels after each BN gate (probs.sum()*group_size). Falls back to the
376
+ full channel count when a gate is not found.
377
+
378
+ Accepts a batch or an explicit (N,C,H,W) shape.
379
+ """
380
+
381
+ def __init__(self, cfg: Optional[ResNetProxyConfig] = None):
382
+ super().__init__()
383
+ self.cfg = cfg or ResNetProxyConfig()
384
+
385
+ def _add_cost(self, cost_like: torch.Tensor, oc, ic, k, stride, H, W):
386
+ alpha = _as_const_like_resnet(cost_like, self.cfg.alpha_conv)
387
+ # update spatial dims with conv stride (roughly, ignoring padding effects)
388
+ H = (H + stride - 1) // stride
389
+ W = (W + stride - 1) // stride
390
+ flops = _as_const_like_resnet(cost_like, oc) * _as_const_like_resnet(cost_like, ic) * (k * k) * _as_const_like_resnet(cost_like, H) * _as_const_like_resnet(cost_like, W)
391
+ return cost_like + alpha * flops, H, W
392
+
393
+ def _predict_raw(self, model: nn.Module, sample: TensorOrBatch, **_) -> torch.Tensor:
394
+ N, C_in, H0, W0 = _nchw_from_batch(sample)
395
+ anchor = _find_anchor_param(model)
396
+ cost = _as_const_like_resnet(anchor, 0.0)
397
+ H = _as_const_like_resnet(anchor, int(H0))
398
+ W = _as_const_like_resnet(anchor, int(W0))
399
+
400
+ # Stem
401
+ conv1 = getattr(model, "conv1")
402
+ bn1 = getattr(model, "bn1", None)
403
+ k = conv1.kernel_size[0]
404
+ s = conv1.stride[0]
405
+ kept_out = None
406
+ if bn1 is not None:
407
+ kept = _kept_from_gate(bn1, anchor)
408
+ if kept is not None:
409
+ kept_out = kept
410
+ oc_eff = kept_out if kept_out is not None else _as_const_like_resnet(anchor, conv1.out_channels)
411
+ cost, H, W = self._add_cost(cost, oc_eff, _as_const_like_resnet(anchor, C_in), k, s, H, W)
412
+ in_ch = oc_eff
413
+
414
+ def _block_cost(block, in_ch, H, W, cost):
415
+ # conv1 -> bn1
416
+ c1 = block.conv1
417
+ b1 = block.bn1 if hasattr(block, "bn1") else None
418
+ k1, s1 = c1.kernel_size[0], c1.stride[0]
419
+ kept1 = _kept_from_gate(b1, anchor) # explicit None-check: `or` on a 0-dim tensor falls back when the expectation is exactly 0
+ oc1_eff = kept1 if kept1 is not None else _as_const_like_resnet(anchor, c1.out_channels)
420
+ cost, H, W = self._add_cost(cost, oc1_eff, in_ch, k1, s1, H, W)
421
+
422
+ # conv2 -> bn2
423
+ c2 = block.conv2
424
+ b2 = block.bn2 if hasattr(block, "bn2") else None
425
+ k2, s2 = c2.kernel_size[0], c2.stride[0]
426
+ kept2 = _kept_from_gate(b2, anchor)
+ oc2_eff = kept2 if kept2 is not None else _as_const_like_resnet(anchor, c2.out_channels)
427
+ cost, H, W = self._add_cost(cost, oc2_eff, oc1_eff, k2, s2, H, W)
428
+
429
+ return oc2_eff, H, W, cost
430
+
431
+ # Layers
432
+ for lname in ("layer1", "layer2", "layer3", "layer4"):
433
+ layer = getattr(model, lname, None)
434
+ if layer is None:
435
+ continue
436
+ for blk in layer:
437
+ in_ch, H, W, cost = _block_cost(blk, in_ch, H, W, cost)
438
+
439
+ scale = _as_const_like_resnet(anchor, self.cfg.scale_ms)
440
+ return cost * scale
441
+
442
+ @torch.no_grad()
443
+ def calibrate(self, model: nn.Module, keepall_export_fn, profiler_fn, sample: TensorOrBatch, device: str = "cuda") -> float:
444
+ """Calibrate `scale_ms` so proxy(model_keepall) ~= real latency in ms."""
445
+ keep = keepall_export_fn(model)
446
+ sample_shape = _nchw_from_batch(sample)
447
+ mean_ms, _ = profiler_fn(keep, sample_shape, device=device)
448
+ soft = float(self.predict(model, sample).detach().cpu())
449
+ self.cfg.scale_ms = self.cfg.scale_ms * mean_ms / max(soft, 1e-9) # soft already includes the old scale
450
+ return mean_ms
451
+
452
+
453
+ # -----------------------------------------------------------------------------
454
+ # LLM proxy
455
+ # -----------------------------------------------------------------------------
456
+
457
+ """
458
+ LatencyProxyLLM
459
+ ---------------
460
+ A lightweight latency proxy for decoder-only HF LLMs (LLaMA/Mistral style).
461
+
462
+ - Estimates end-to-end latency (ms-like scalar) for a given (B, S, T):
463
+ * Prefill on S tokens (build KV cache)
464
+ * Cached decode for T steps
465
+ - Uses soft gate expectations:
466
+ * Attention heads (HeadGate on GatedSelfAttentionLLM)
467
+ * FFN hidden (SwiGLUWidthGate via .mlp.neuron_gate)
468
+ - Calibrate .scale_ms so proxy ≈ real latency of a keep-all model.
469
+
470
+ Public API
471
+ ----------
472
+ - LatencyProxyLLM(...).predict(model, batch_or_shape) # trainer entry
473
+ - LatencyProxyLLM(...).predict(model, B=?, S=?, T=?) # explicit entry
474
+ - LatencyProxyLLM(...).debug_layer_view(...)
475
+ - calibrate_proxy_llm(...), calibrate_proxy_llm_from_batch(...)
476
+ """
477
+
478
+ # ------------------------------------------------------------
479
+ # Shared tiny utils (device/dtype-safe constants)
480
+ # ------------------------------------------------------------
481
+ def _find_gate_param_or_fallback(model: nn.Module) -> torch.Tensor:
482
+ """
483
+ Return a tensor to anchor device/dtype for proxy constants.
484
+ Prefer gate logits; else any parameter; else CPU fp32 scalar.
485
+ """
486
+ for m in model.modules():
487
+ if hasattr(m, "head_gate") and hasattr(getattr(m, "head_gate"), "logits"):
488
+ return m.head_gate.logits
489
+ if hasattr(m, "neuron_gate") and hasattr(m.neuron_gate, "logits"):
490
+ return m.neuron_gate.logits
491
+ if hasattr(m, "logits") and isinstance(getattr(m, "logits"), torch.Tensor):
492
+ return m.logits
493
+ for p in model.parameters():
494
+ return p
495
+ return torch.tensor(0.0)
496
+
497
+ def _as_const_like(x_like: torch.Tensor, val):
498
+ return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype)
499
+
500
+
501
+ # ------------------------------------------------------------
502
+ # Proxy
503
+ # ------------------------------------------------------------
504
+ @dataclass
505
+ class _WarmupOnlyPolicy:
506
+ """Tiny policy shim so you can pass warmup_steps to .predict()."""
507
+ warmup_steps: int = 0
508
+
509
+ class LatencyProxyLLM(LatencyProxy):
510
+ """
511
+ LLM latency proxy (ms ~ weighted FLOPs/bandwidth terms) for prefill + cached decode.
512
+ Accepts either a batch or explicit B,S,T.
513
+ """
514
+
515
+ def __init__(
516
+ self,
517
+ *,
518
+ scale_ms: float = 1.0,
519
+ alpha_qkv: float = 1.0,
520
+ alpha_scores: float = 1.0,
521
+ alpha_out: float = 1.0,
522
+ alpha_mlp: float = 1.0,
523
+ gate_kv_in_proxy: bool = False,
524
+ default_T: int = 128,
525
+ ):
526
+ super().__init__()
527
+ self.scale_ms = float(scale_ms)
528
+ self.alpha_qkv = float(alpha_qkv)
529
+ self.alpha_scores = float(alpha_scores)
530
+ self.alpha_out = float(alpha_out)
531
+ self.alpha_mlp = float(alpha_mlp)
532
+ self.gate_kv_in_proxy = bool(gate_kv_in_proxy)
533
+ self.default_T = int(default_T)
534
+
535
+ # ---------- gate discovery ----------
536
+ @staticmethod
537
+ def _soft_heads_from_block_llm(blk) -> Optional[torch.Tensor]:
538
+ attn = getattr(blk, "self_attn", None)
539
+ if attn is None:
540
+ return None
541
+ if hasattr(attn, "kept_heads_soft") and callable(attn.kept_heads_soft):
542
+ return attn.kept_heads_soft()
543
+ logits, tau = None, None
544
+ if hasattr(attn, "head_gate") and hasattr(attn.head_gate, "logits"):
545
+ logits = attn.head_gate.logits
546
+ tau = float(getattr(attn.head_gate, "tau", getattr(attn, "tau", 1.5)))
547
+ elif hasattr(attn, "logits"):
548
+ logits = attn.logits
549
+ tau = float(getattr(attn, "tau", 1.5))
550
+ if logits is None:
551
+ return None
552
+ return torch.sigmoid(logits / tau).sum()
553
+
554
+ @staticmethod
555
+ def _find_ffn_gate_llm(blk):
556
+ mlp = getattr(blk, "mlp", None)
557
+ g = getattr(mlp, "neuron_gate", None) if mlp is not None else None
558
+ if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
559
+ return g
560
+ return None
561
+
562
+ def _soft_hidden_from_block_llm(self, blk, default_hidden, anchor, warm=False):
563
+ if warm:
564
+ return default_hidden
565
+ g = self._find_ffn_gate_llm(blk)
566
+ if g is None:
567
+ return default_hidden
568
+ probs = torch.sigmoid(g.logits / float(g.tau)) # [#groups]
569
+ group = int(getattr(g, "group", getattr(g, "group_size", 128)))
570
+ kept_hidden = probs.sum() * _as_const_like(anchor, group)
571
+ return kept_hidden
572
+
573
+ # ---------- main ----------
574
+ def predict( # trainer entry and explicit-shape entry unified
575
+ self,
576
+ model: nn.Module,
577
+ sample: Optional[TensorOrBatch] = None,
578
+ *,
579
+ B: Optional[int] = None,
580
+ S: Optional[int] = None,
581
+ T: Optional[int] = None,
582
+ policy: Optional[object] = None,
583
+ step: Optional[int] = None,
584
+ return_terms: bool = False,
585
+ ):
586
+ # Allow explicit B,S,(T) path
587
+ if B is not None and S is not None:
588
+ ids_B, ids_S = int(B), int(S)
589
+ ids_T = int(T) if T is not None else int(self.default_T)
590
+ else:
591
+ if sample is None:
592
+ raise ValueError("LatencyProxyLLM.predict needs either a batch sample or explicit B,S.")
593
+ if isinstance(sample, (tuple, list)) and len(sample) in (2, 3) and all(isinstance(x, int) for x in sample):
594
+ # explicit (B,S) or (B,S,T)
595
+ ids_B, ids_S = int(sample[0]), int(sample[1])
596
+ ids_T = int(sample[2]) if len(sample) == 3 else int(self.default_T)
597
+ else:
598
+ ids = _ids_from_batch(sample)
599
+ ids_B, ids_S = int(ids.size(0)), int(ids.size(1))
600
+ ids_T = int(self.default_T) if T is None else int(T)
601
+
602
+ anchor = _find_gate_param_or_fallback(model)
603
+
604
+ # scalar tensors (same device/dtype)
605
+ B_t = _as_const_like(anchor, ids_B)
606
+ S_t = _as_const_like(anchor, ids_S)
607
+ T_t = _as_const_like(anchor, ids_T)
608
+
609
+ cfg = model.config
610
+ D = _as_const_like(anchor, int(cfg.hidden_size))
611
+ Hh = _as_const_like(anchor, int(cfg.num_attention_heads))
612
+ Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hh))))
613
+ Dh = D // Hh
614
+
615
+ warmup_steps = int(getattr(policy, "warmup_steps", 0)) if policy is not None else 0
616
+ warm = bool(step is not None and step < warmup_steps)
617
+
618
+ total_qkv = anchor.new_zeros(())
619
+ total_scores = anchor.new_zeros(())
620
+ total_out = anchor.new_zeros(())
621
+ total_mlp = anchor.new_zeros(())
622
+
623
+ default_hidden = _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))
624
+
625
+ layers = getattr(getattr(model, "model", model), "layers", [])
626
+ for blk in layers:
627
+ heads_soft = Hh if warm else (self._soft_heads_from_block_llm(blk) or Hh)
628
+ Dq = heads_soft * Dh
629
+ # K/V effective width
630
+ if self.gate_kv_in_proxy:
631
+ Dkv = heads_soft * Dh
632
+ else:
633
+ Dkv = Hkv * Dh
634
+ hidden_soft = self._soft_hidden_from_block_llm(blk, default_hidden, anchor, warm=warm)
635
+
636
+ # Prefill + decode (simplified aggregation)
637
+ Seff = S_t + T_t
638
+
639
+ # q/k/v linear FLOP-like terms
640
+ total_qkv = total_qkv + (
641
+ # q
642
+ B_t * Seff * D * Dq +
643
+ # k + v
644
+ 2 * B_t * Seff * D * Dkv
645
+ )
646
+ # attention scores (prefill SxS + decode triangular)
647
+ total_scores = total_scores + (
648
+ B_t * (S_t * S_t) * heads_soft * Dh +
649
+ B_t * heads_soft * Dh * (T_t * S_t + (T_t * (T_t + 1)) // 2)
650
+ )
651
+ # out proj
652
+ total_out = total_out + B_t * Seff * Dq * D
653
+ # mlp
654
+ total_mlp = total_mlp + B_t * Seff * 2 * D * hidden_soft
655
+
656
+ flops_like = (
657
+ self.alpha_qkv * total_qkv
658
+ + self.alpha_scores * total_scores
659
+ + self.alpha_out * total_out
660
+ + self.alpha_mlp * total_mlp
661
+ )
662
+
663
+ ms = flops_like * _as_const_like(anchor, self.scale_ms)
664
+ if return_terms:
665
+ return ms, {
666
+ "qkv": float((self.alpha_qkv * total_qkv).detach().cpu()),
667
+ "scores": float((self.alpha_scores * total_scores).detach().cpu()),
668
+ "out": float((self.alpha_out * total_out).detach().cpu()),
669
+ "mlp": float((self.alpha_mlp * total_mlp).detach().cpu()),
670
+ }
671
+ return ms
672
+
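The triangular term in `total_scores` aggregates per-step decode attention: at decode step t the query attends to the S cached prompt tokens plus the t tokens generated so far, so summing over T steps gives T·S + T(T+1)/2. A quick standalone check of that closed form:

```python
def decode_kv_positions(S: int, T: int) -> int:
    # closed form used in the scores term above: T*S + T*(T+1)/2
    return T * S + (T * (T + 1)) // 2

def decode_kv_positions_loop(S: int, T: int) -> int:
    # decode step t (1-indexed) attends to S prompt tokens plus the t
    # tokens generated so far, including the current one
    return sum(S + t for t in range(1, T + 1))

print(decode_kv_positions(16, 4), decode_kv_positions_loop(16, 4))  # 74 74
```

Per layer, the proxy multiplies this token-position count by `heads_soft * Dh`, so pruning heads shrinks the decode-attention term linearly.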
673
+ # ---------- per-layer debug ----------
674
+ @torch.no_grad()
675
+ def debug_layer_view(
676
+ self,
677
+ model: nn.Module,
678
+ *,
679
+ B: int,
680
+ S: int,
681
+ T: int,
682
+ policy: Optional[object] = None,
683
+ step: Optional[int] = None,
684
+ ) -> list:
685
+ anchor = _find_gate_param_or_fallback(model)
686
+ cfg = getattr(model, "config", None)
687
+ D = _as_const_like(anchor, int(getattr(cfg, "hidden_size", 0)))
688
+ Hq = _as_const_like(anchor, int(getattr(cfg, "num_attention_heads", 0)))
689
+ Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hq))))
690
+ Dh = D // Hq
691
+
692
+ warm = False
693
+ if policy is not None and step is not None:
694
+ warm = (int(step) < int(getattr(policy, "warmup_steps", 0)))
695
+
696
+ rows = []
697
+ layers = getattr(getattr(model, "model", model), "layers", None) or []
698
+ for i, blk in enumerate(layers):
699
+ heads_soft = Hq if warm else (self._soft_heads_from_block_llm(blk) or Hq)
700
+ Dq = heads_soft * Dh
701
+ Dkv = (heads_soft * Dh) if self.gate_kv_in_proxy else (Hkv * Dh)
702
+ hidden_soft = self._soft_hidden_from_block_llm(
703
+ blk, _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D)))), anchor, warm=warm
704
+ )
705
+ rows.append({
706
+ "layer": i,
707
+ "heads_soft": float(heads_soft.detach().cpu()),
708
+ "Dq≈heads*Dh": float(Dq.detach().cpu()),
709
+ "Dkv_used": float(Dkv.detach().cpu()),
710
+ "ffn_hidden_soft": float(hidden_soft.detach().cpu()),
711
+ })
712
+ return rows
713
+
714
+
715
+ # ------------------------------------------------------------
716
+ # Calibration helpers for LLM
717
+ # ------------------------------------------------------------
718
+ @torch.inference_mode()
719
+ def calibrate_proxy_llm(
720
+ proxy: LatencyProxyLLM,
721
+ model: nn.Module,
722
+ *,
723
+ B: int,
724
+ S: int,
725
+ T: int,
726
+ export_keepall_fn,
727
+ device: str = "cuda",
728
+ warmup: int = 10,
729
+ iters: int = 30,
730
+ ) -> float:
731
+ """
732
+ Calibrate proxy.scale_ms so proxy.predict(...) matches real keep-all latency for (B,S,T).
733
+ Returns the measured real mean latency in ms.
734
+ """
735
+ keepall = export_keepall_fn(model).to(device).eval()
736
+
737
+ # Measure real latency (prefill + decode)
738
+ from core.measure import measure_latency_text_ms as _measure # adjust if your path differs
739
+ real_ms, _ = _measure(keepall, B=B, S=S, T=T, warmup=warmup, iters=iters, device=device)
740
+
741
+ # Soft/proxy latency on *gated* model
742
+ ms_like = proxy.predict(model, B=B, S=S, T=T)
743
+ soft_ms = float(ms_like.detach().item()) if torch.is_tensor(ms_like) else float(ms_like)
744
+
745
+ proxy.scale_ms = float(proxy.scale_ms * real_ms / max(soft_ms, 1e-9)) # predict() already applies the old scale
746
+ return real_ms
747
+
748
+
749
+ @torch.inference_mode()
750
+ def calibrate_proxy_llm_from_batch(
751
+ proxy: LatencyProxyLLM,
752
+ model: nn.Module,
753
+ batch: Dict[str, torch.Tensor],
754
+ *,
755
+ T: int,
756
+ export_keepall_fn,
757
+ device: str = "cuda",
758
+ warmup: int = 10,
759
+ iters: int = 30,
760
+ ) -> Tuple[int, int, int, float]:
761
+ """
762
+ Infers (B,S) from a batch like {'input_ids': [B,S], ...},
763
+ calibrates for (B,S,T), and returns (B,S,T, real_ms).
764
+ """
765
+ input_ids = batch["input_ids"]
766
+ B, S = int(input_ids.size(0)), int(input_ids.size(1))
767
+ ms = calibrate_proxy_llm(
768
+ proxy, model, B=B, S=S, T=T, export_keepall_fn=export_keepall_fn,
769
+ device=device, warmup=warmup, iters=iters
770
+ )
771
+ return B, S, T, ms
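`calibrate_proxy_llm_from_batch` infers (B, S) and forwards explicit integers into `predict`. The shape normalization that `predict` applies to explicit-int samples can be sketched standalone (`normalize_sample` is a hypothetical helper, not part of the module):

```python
def normalize_sample(sample, default_T: int = 128):
    # mirrors the explicit-int branch of LatencyProxyLLM.predict:
    # a (B, S) pair gets the proxy's default decode length, (B, S, T) is kept
    if (isinstance(sample, (tuple, list)) and len(sample) in (2, 3)
            and all(isinstance(x, int) for x in sample)):
        B, S = sample[0], sample[1]
        T = sample[2] if len(sample) == 3 else default_T
        return B, S, T
    raise ValueError("expected explicit (B, S) or (B, S, T)")

print(normalize_sample((4, 512)))      # (4, 512, 128)
print(normalize_sample((4, 512, 64)))  # (4, 512, 64)
```

Anything that is not an all-int tuple is treated as a real batch and routed through `_ids_from_batch` instead.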
core/search_export.py ADDED
@@ -0,0 +1,76 @@
1
+ """Export-parameter search (hardware-aware).
2
+
3
+ This module performs a small grid search over export rounding/multiple knobs and
4
+ picks the configuration that minimizes *measured* latency for the target batch
5
+ shape. It is family-agnostic; adapters provide the export function.
6
+
7
+ For ViT, see `vit_search_best_export` which scans per-head multiples and FFN
8
+ snap group sizes, mirroring kernel-friendly widths.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Callable, Iterable, List, Optional, Sequence, Tuple
14
+
15
+ import copy
16
+ import itertools
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .export import ExportPolicy as CoreExportPolicy, Rounding as CoreRounding
22
+ from .profiler import measure_latency_ms, ProfileSettings
23
+
24
+
25
+ # Type alias: adapter export function
26
+ ExportFn = Callable[[nn.Module, object, int], nn.Module]
27
+
28
+
29
+ @dataclass
30
+ class SearchResult:
31
+ best_model: nn.Module
32
+ best_params: dict
33
+ trials: List[dict]
34
+
35
+
36
+ def grid_search_latency(
37
+ model_with_gates: nn.Module,
38
+ export_fn: ExportFn,
39
+ *,
40
+ head_multiples: Sequence[int],
41
+ ffn_snaps: Sequence[int],
42
+ step: int,
43
+ batch_shape: Tuple[int, int, int, int], # (B,C,H,W)
44
+ measure_settings: Optional[ProfileSettings] = None,
45
+ device: str = "cuda",
46
+ make_policy: Optional[Callable[[int, int], object]] = None,
47
+ ) -> SearchResult:
48
+ """Generic grid search over (head_multiple, ffn_snap_groups).
49
+
50
+ - `make_policy(h_mult, ffn_snap)` must return an adapter-acceptable export policy.
51
+ If not provided, falls back to a single-rounding `CoreExportPolicy` using
52
+ `multiple_groups=head_multiple` for both heads and FFN.
53
+ """
54
+ trials: List[dict] = []
55
+ best = None
56
+
57
+ to_try = list(itertools.product(head_multiples, ffn_snaps)) # materialize: the len() below must not exhaust the iterator mid-loop
58
+ for i, (hm, fs) in enumerate(to_try, 1):
59
+ policy = make_policy(hm, fs) if make_policy is not None else CoreExportPolicy(
60
+ warmup_steps=0,
61
+ rounding=CoreRounding(floor_groups=1, multiple_groups=int(hm), min_keep_ratio=0.0),
62
+ )
63
+ slim = export_fn(model_with_gates, policy, step)
64
+ mean_ms, p95_ms = measure_latency_ms(slim, batch_shape, settings=measure_settings, device=device)
65
+ rec = {"head_multiple": int(hm), "ffn_snap": int(fs), "mean_ms": float(mean_ms), "p95_ms": float(p95_ms)}
66
+ print(f"[{i}/{len(to_try)}] head_multiple {int(hm)} | ffn_snap {int(fs)} | mean_ms = {float(mean_ms):.3f}")
67
+ trials.append(rec)
68
+ if best is None or mean_ms < best[0]:
69
+ best = (mean_ms, hm, fs, slim)
70
+
71
+ assert best is not None
72
+ _, hm_best, fs_best, slim_best = best
73
+ return SearchResult(best_model=slim_best, best_params={"head_multiple": int(hm_best), "ffn_snap": int(fs_best)}, trials=trials)
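The selection loop above reduces to: enumerate every (head_multiple, ffn_snap) pair, measure each exported candidate, keep the fastest. A standalone sketch with a toy cost function standing in for `measure_latency_ms`:

```python
import itertools

def pick_best(head_multiples, ffn_snaps, measure):
    # mirrors grid_search_latency's selection loop: try every (hm, fs)
    # pair and keep the pair with the lowest measured latency
    trials, best = [], None
    for hm, fs in itertools.product(head_multiples, ffn_snaps):
        mean_ms = measure(hm, fs)
        trials.append({"head_multiple": hm, "ffn_snap": fs, "mean_ms": mean_ms})
        if best is None or mean_ms < best[0]:
            best = (mean_ms, hm, fs)
    return best, trials

# toy cost model standing in for a real latency measurement
best, trials = pick_best([8, 16], [32, 64], lambda hm, fs: hm * 10 + fs)
print(best)         # (112, 8, 32)
print(len(trials))  # 4
```

In the real search the grid is small (a handful of kernel-friendly multiples), so exhaustive measurement is cheap relative to training.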
74
+
75
+
76
+
core/train.py ADDED
@@ -0,0 +1,327 @@
1
+ """Generic Lagrangian trainer (family-agnostic).
2
+
3
+ This module provides a light framework to optimize *gated* students against
4
+ teachers with a latency target enforced via a proxy + optional real probes.
5
+
6
+ It does not assume ViT/ResNet/LLM specifics; adapters provide tiny callables.
7
+
8
+ Key ingredients:
9
+ - Two-phase update per step: (A) weights w.r.t. KD/task, (B) gates w.r.t. KD +
10
+ sparsity + latency penalty with a dual variable λ.
11
+ - Optional periodic export + real-latency probe to correct λ.
12
+ - Constraint projection for gates after each step.
13
+
14
+ Adapters must provide:
15
+ - get_student_logits(model, x) -> Tensor
16
+ - get_teacher_logits(model, x) -> Tensor
17
+ - export_keepall(model) -> nn.Module (clean copy without gates)
18
+ - export_pruned(model, policy, step) -> nn.Module (transient copy for profiling)
19
+
20
+ Core modules used:
21
+ - `distill.KDConfig`, `distill.kd_loss`
22
+ - `gates.combined_penalty`, `gates.PenaltyWeights`, `gates.project_gates_into_constraints`
23
+ - `proxy_cost.LatencyProxy`
24
+ - `profiler.measure_latency_ms`
25
+ """
26
+ from __future__ import annotations
27
+
28
+ from dataclasses import dataclass
29
+ from typing import Callable, Optional
30
+ import gc
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+
35
+ from .distill import KDConfig, kd_loss, mse_reg
36
+ from .gates import PenaltyWeights, Constraints, combined_penalty, project_gates_into_constraints, collect_param_groups
37
+ from .proxy_cost import LatencyProxy
38
+ from .profiler import measure_latency_ms
39
+
40
+ # -----------------------------------------------------------------------------
41
+ # Config
42
+ # -----------------------------------------------------------------------------
43
+
44
+ @dataclass
45
+ class DualConfig:
46
+ lr: float = 0.05 # step for λ update
47
+ ema_beta: float = 0.5 # blend proxy-driven λ and real probe λ
48
+ clip: float = 10.0
49
+
50
+
51
+ @dataclass
52
+ class TrainerConfig:
53
+ kd: Optional[KDConfig] = None # defaults built in __post_init__ (mutable dataclass defaults are rejected on Python 3.11+)
54
+ penalties: Optional[PenaltyWeights] = None
55
+ constraints: Optional[Constraints] = None
56
+
57
+ latency_target_ms: float = 30.0
58
+ real_probe_every: int = 0 # steps; 0 disables real probes
59
+ probe_batch_override: Optional[int] = None
60
+ gate_warmup_steps: int = 0 # Freeze gates for early steps
61
+ mse_weight: float = 0.0
62
+
63
+ early_stopping_patience: int = 0
64
+ early_stopping_lambda: float = 1e-4
65
+
66
+ amp: bool = True
67
+ device: str = "cuda"
68
+
69
+ # Optimizers
70
+ lr_gate: float = 1e-2
71
+ lr_linear: float = 1e-4
72
+ lr_affine: float = 3e-4
73
+ wd_linear: float = 1e-4
74
+
75
+ # Mixed precision scaler
76
+ use_grad_scaler: bool = True
77
+
78
+ # Dual update
79
+ dual: Optional[DualConfig] = None
+
+ def __post_init__(self):
+ if self.kd is None: self.kd = KDConfig()
+ if self.penalties is None: self.penalties = PenaltyWeights(l0=0.0, keep_floor_ratio=0.0, bimodality=0.0)
+ if self.constraints is None: self.constraints = Constraints(min_keep_ratio=0.0, min_groups=1, max_groups_drop=None)
+ if self.dual is None: self.dual = DualConfig()
80
+
81
+
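`DualConfig` parameterizes a projected dual-ascent update on λ: the multiplier grows while measured/proxied latency exceeds `latency_target_ms` and shrinks toward zero once under it. A standalone sketch of that update rule (applying the `clip` bound here is an assumption about intent):

```python
def dual_update(lmbda: float, latency_ms: float, target_ms: float,
                lr: float = 0.05, clip: float = 10.0) -> float:
    # projected dual ascent: raise lambda while latency exceeds the target,
    # lower it (never below 0, never above clip) once the constraint holds
    return min(max(0.0, lmbda + lr * (latency_ms - target_ms)), clip)

lam = 0.0
for latency in (40.0, 36.0, 31.0, 29.0):  # latency drifting down toward 30 ms
    lam = dual_update(lam, latency, target_ms=30.0)
print(round(lam, 3))  # 0.8
```

The trainer runs this update twice per step, once with the proxy latency and once with the latest real probe, and blends the two λ values.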
82
+ # -----------------------------------------------------------------------------
83
+ # Trainer
84
+ # -----------------------------------------------------------------------------
85
+
86
+ class LagrangeTrainer:
87
+ def __init__(
88
+ self,
89
+ student: nn.Module,
90
+ teacher: nn.Module,
91
+ proxy: LatencyProxy,
92
+ *,
93
+ adapter_get_student_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
94
+ adapter_get_teacher_logits: Callable[[nn.Module, torch.Tensor], torch.Tensor],
95
+ adapter_export_keepall: Callable[[nn.Module], nn.Module],
96
+ adapter_export_pruned: Callable[[nn.Module, object, int], nn.Module],
97
+ export_policy: object,
98
+ cfg: TrainerConfig,
99
+ ) -> None:
100
+ self.student = student
101
+ self.teacher = teacher.eval()
102
+ for p in self.teacher.parameters():
103
+ p.requires_grad_(False)
104
+ self.proxy = proxy
105
+ self.get_s = adapter_get_student_logits
106
+ self.get_t = adapter_get_teacher_logits
107
+ self.export_keepall = adapter_export_keepall
108
+ self.export_pruned = adapter_export_pruned
109
+ self.export_policy = export_policy
110
+ self.cfg = cfg
111
+
112
+ # Build optimizers (grouped)
113
+ param_groups = collect_param_groups(
114
+ student,
115
+ lr_gate=cfg.lr_gate,
116
+ lr_linear=cfg.lr_linear,
117
+ lr_affine=cfg.lr_affine,
118
+ wd_linear=cfg.wd_linear,
119
+ )
120
+ # gates-only optimizer uses first group
121
+ self.opt_g = torch.optim.Adam([param_groups[0]], lr=param_groups[0]["lr"]) # type: ignore[arg-type]
122
+ # weights optimizer for the rest
123
+ self.opt_w = torch.optim.Adam(param_groups[1:])
124
+
125
+ self.scaler = torch.amp.GradScaler('cuda', enabled=(cfg.amp and cfg.use_grad_scaler))
126
+ self.lambda_: float = 0.0
127
+ self.mse_weight = cfg.mse_weight
128
+
129
+ # ---- internal helpers -----------------------------------------------------
130
+ def _zero_grads(self, params):
131
+ for p in params:
132
+ if p.grad is not None:
133
+ p.grad = None
134
+
135
+ def _has_grad(self, params) -> bool:
136
+ for p in params:
137
+ if p.grad is not None:
138
+ return True
139
+ return False
140
+
141
+ # ---- training -------------------------------------------------------------
142
+ def train_epoch(self, loader, *, real_policy=None, verbose_every: int = 50):
143
+ device = self.cfg.device
144
+ self.student.train().to(device)
145
+ self.teacher.to(device).eval()
146
+
147
+ running = 0.0
148
+ seen = 0
149
+ lam_real = self.lambda_
+ mean_ms = p95_ms = float("nan") # updated by the periodic real-latency probe; NaN until the first probe
150
+
154
+ for step, batch in enumerate(loader, 1):
155
+ # Move batch to device in a type-safe way
156
+ batch = _move_batch_to_device(batch, device)
157
+
158
+ # with torch.inference_mode():
159
+ with torch.no_grad():
160
+ t_logits = self.get_t(self.teacher, batch) # [B,1,V]
161
+ # match AMP compute dtype to avoid upcasting later
162
+ if self.cfg.amp:
163
+ # infer autocast dtype from student params (bf16 or fp16)
164
+ sparam = next(self.student.parameters())
165
+ t_logits = t_logits.to(dtype=sparam.dtype, non_blocking=True)
166
+
167
+
168
+ # -------- Pass A: WEIGHTS (KD only) --------
169
+ self.opt_w.zero_grad(set_to_none=True)
170
+
171
+ with torch.amp.autocast('cuda', enabled=self.cfg.amp):
172
+ # Adapters receive the batch object (dict/tuple/tensor)
173
+ s_logits = self.get_s(self.student, batch)
174
+ # with torch.no_grad():
175
+ # t_logits = self.get_t(self.teacher, batch)
176
+ mse = self.mse_weight*mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
177
+ loss_w = kd_loss(s_logits, t_logits, self.cfg.kd) + mse
178
+
179
+ self.scaler.scale(loss_w).backward()
180
+ # Prevent gate params from changing in pass A
181
+ gate_params = self.opt_g.param_groups[0]["params"]
182
+ self._zero_grads(gate_params)
183
+
184
+ if any(p.grad is not None for pg in self.opt_w.param_groups for p in pg["params"]):
185
+ self.scaler.step(self.opt_w)
186
+ self.scaler.update()
187
+ else:
188
+ self.opt_w.zero_grad(set_to_none=True)
189
+
190
+ del s_logits
191
+ gc.collect()
192
+ torch.cuda.empty_cache()
193
+
194
+ if step > int(self.cfg.gate_warmup_steps):
195
+
196
+ # -------- Pass B: GATES (KD + sparsity + λ * gap) --------
197
+ self.opt_g.zero_grad(set_to_none=True)
198
+ with torch.amp.autocast('cuda', enabled=self.cfg.amp):
199
+ s_logits = self.get_s(self.student, batch)
200
+ # with torch.no_grad():
201
+ # t_logits = self.get_t(self.teacher, batch)
202
+ kd_g = kd_loss(s_logits, t_logits, self.cfg.kd)
203
+
204
+ # Proxy gets the batch object too; family-specific proxy can read (B,S) etc.
205
+ o1_ms = self.proxy.predict(self.student, batch)
206
+ gap = torch.relu(o1_ms - float(self.cfg.latency_target_ms))
207
+ reg = combined_penalty(self.student, self.cfg.penalties)
208
+ mse = self.mse_weight*mse_reg(s_logits, t_logits, self.cfg.kd.temperature)
209
+ loss_g = kd_g + _to_tensor(self.lambda_, o1_ms) * gap + reg + mse
210
+
211
+ self.scaler.scale(loss_g).backward()
212
+ # Prevent non-gate params from changing in pass B
213
+ for pg in self.opt_w.param_groups:
214
+ self._zero_grads(pg["params"])
215
+
216
+ if self._has_grad(self.opt_g.param_groups[0]["params"]):
217
+ self.scaler.step(self.opt_g)
218
+ self.scaler.update()
219
+ else:
220
+ self.opt_g.zero_grad(set_to_none=True)
221
+ else:
222
+ o1_ms = self.proxy.predict(self.student, batch)
223
+ s_logits = loss_g = kd_g = reg = torch.tensor(0.0, device=device)
224
+
225
+ # -------- Dual (λ) update using proxy --------
226
+ with torch.no_grad():
227
+ lam_proxy = max(0.0, self.lambda_ + self.cfg.dual.lr * (float(o1_ms.detach()) - self.cfg.latency_target_ms))
228
+ beta = self.cfg.dual.ema_beta # documented blend of real-probe and proxy-driven λ
+ self.lambda_ = min(beta * lam_real + (1.0 - beta) * lam_proxy, self.cfg.dual.clip)
229
+
230
+ # -------- Constraint projection, optional real probe --------
231
+ project_gates_into_constraints(self.student, self.cfg.constraints)
232
+
233
+
234
+ if self.cfg.real_probe_every and (step % int(self.cfg.real_probe_every) == 0):
235
+ # Build a probe shape for latency func if needed
236
+ try:
237
+ from core.measure import measure_latency_text_ms # text-friendly
238
+ if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
239
+ B, S = int(batch["input_ids"].size(0)), int(batch["input_ids"].size(1))
240
+ else:
241
+ # Fallback: try tensor-like batch
242
+ x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
243
+ B = int(x0.size(0)); S = int(x0.size(1))
244
+ slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
245
+ mean_ms, p95_ms = measure_latency_text_ms(slim, B=B, S=S, T=128, device=device)
246
+ except Exception:
247
+ # Fall back to the project's own profiler to retain compatibility:
248
+ from .profiler import measure_latency_ms
249
+ x0 = batch["input_ids"] if isinstance(batch, dict) else (batch[0] if isinstance(batch, (tuple, list)) else batch)
250
+ shape = (int(x0.size(0)), *list(x0.shape[1:]))
251
+ slim = self.export_pruned(self.student, real_policy or self.export_policy, step)
252
+ mean_ms, p95_ms = measure_latency_ms(slim, shape, device=device)
253
+
254
+ with torch.no_grad():
255
+ lam_real = max(0.0, self.lambda_ + self.cfg.dual.lr * (mean_ms - self.cfg.latency_target_ms))
256
+
257
+ # scale_correction = mean_ms / max(1e-9, o1_ms.detach())
258
+ # self.proxy.cfg.scale_ms = 0.9 * self.proxy.cfg.scale_ms + 0.1 * scale_correction * self.proxy.cfg.scale_ms
259
+
260
+
261
+ if (step % verbose_every) == 0:
262
+ print(
263
+ f"Step {step}/{len(loader)} | KL={float(loss_w.item()):.6f} | MSE={float(mse.item()):.6f} | "
264
+ f"Gate={float(loss_g.item()):.6f} | "
265
+ f"proxy={float(o1_ms.detach()):.3f}ms | real_mean={mean_ms:.3f}ms p95={p95_ms:.3f}ms | λ={self.lambda_:.6f}"
266
+ )
267
+
268
+ running += float(loss_g.detach())
269
+ seen += _batch_size(batch)
270
+
271
+ del s_logits, t_logits, o1_ms, kd_g, reg, loss_g, loss_w
272
+ torch.cuda.empty_cache()
273
+ gc.collect()
274
+
275
+ print(f"Epoch loss {running / max(1, seen):.6f}")
276
+ return self.lambda_
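The λ update in this loop is projected dual (gradient) ascent on the latency constraint: λ grows while the measured latency exceeds the target and is clipped at zero otherwise. A torch-free sketch with hypothetical numbers (helper name and learning rate are illustrative, not from this codebase):

```python
def dual_update(lam, latency_ms, target_ms, lr=0.01):
    # Projected gradient ascent: step lambda along the constraint violation,
    # then project back onto lambda >= 0.
    return max(0.0, lam + lr * (latency_ms - target_ms))

lam = 0.0
lam = dual_update(lam, 12.0, 10.0)  # over budget: lambda increases
lam = dual_update(lam, 4.0, 10.0)   # under budget: lambda decays, floored at 0
```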
277
+
278
+
279
+ # -----------------------------------------------------------------------------
280
+ # Helpers
281
+ # -----------------------------------------------------------------------------
282
+
283
+ def _to_tensor(val: float, like: torch.Tensor) -> torch.Tensor:
284
+ return torch.as_tensor(val, device=like.device, dtype=like.dtype)
285
+
286
+ def _move_batch_to_device(batch, device: str):
287
+ """
288
+ Supports:
289
+ - dict with keys 'input_ids' and optional 'attention_mask'
290
+ - (x,) or (x, y) tuples/lists -> move each tensor-like to device
291
+ - single Tensor
292
+ Converts attention_mask to bool (preferred by HF SDPA).
293
+ """
294
+ if isinstance(batch, dict):
295
+ out = {}
296
+ for k, v in batch.items():
297
+ if torch.is_tensor(v):
298
+ v = v.to(device, non_blocking=True)
299
+ if k == "attention_mask" and v.dtype != torch.bool:
300
+ v = v.to(torch.bool)
301
+ out[k] = v
302
+ return out
303
+
304
+ if isinstance(batch, (tuple, list)):
305
+ moved = []
306
+ for v in batch:
307
+ if torch.is_tensor(v):
308
+ v = v.to(device, non_blocking=True)
309
+ moved.append(v)
310
+ return type(batch)(moved)
311
+
312
+ if torch.is_tensor(batch):
313
+ return batch.to(device, non_blocking=True)
314
+
315
+ # Unknown type: return as-is (adapters/proxy should handle it)
316
+ return batch
317
+
318
+
319
+ def _batch_size(batch) -> int:
320
+ """Best-effort batch size for logging/averages."""
321
+ if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]):
322
+ return int(batch["input_ids"].size(0))
323
+ if torch.is_tensor(batch):
324
+ return int(batch.size(0))
325
+ if isinstance(batch, (tuple, list)) and len(batch) and torch.is_tensor(batch[0]):
326
+ return int(batch[0].size(0))
327
+ return 1
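`_batch_size` above is best-effort by design; the same duck-typed fallback order (dict with `input_ids`, then sequence, then a default of 1) can be sketched without torch (inputs are hypothetical):

```python
def batch_size_of(batch):
    # Mirrors _batch_size: dict -> leading dim of "input_ids",
    # tuple/list -> length of the first element, anything else -> 1.
    if isinstance(batch, dict) and "input_ids" in batch:
        return len(batch["input_ids"])
    if isinstance(batch, (tuple, list)) and batch:
        return len(batch[0])
    return 1
```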
core/utils.py ADDED
@@ -0,0 +1,190 @@
1
+ """Shared utilities used across core and adapters.
2
+
3
+ Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
4
+ parameter grouping, model copying, etc.). Keep this file dependency-light.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
10
+
11
+ import copy
12
+ import random
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Device / dtype helpers
21
+ # -----------------------------------------------------------------------------
22
+
23
+ def as_like(x: torch.Tensor, val) -> torch.Tensor:
24
+ """Create a scalar/tensor constant on same device/dtype as `x`."""
25
+ return torch.as_tensor(val, device=x.device, dtype=x.dtype)
26
+
27
+
28
+ def first_param(module: nn.Module) -> torch.Tensor:
29
+ for p in module.parameters(recurse=True):
30
+ return p
31
+ return torch.tensor(0.0)
32
+
33
+
34
+ def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
35
+ return x.to(device=ref.device, dtype=ref.dtype)
36
+
37
+
38
+ # -----------------------------------------------------------------------------
39
+ # Seeding & determinism
40
+ # -----------------------------------------------------------------------------
41
+
42
+ def set_seed(seed: int = 42, deterministic: bool = False) -> None:
43
+ random.seed(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed)
47
+ if deterministic:
48
+ torch.backends.cudnn.deterministic = True
49
+ torch.backends.cudnn.benchmark = False
50
+
51
+
52
+ # -----------------------------------------------------------------------------
53
+ # Model parameter helpers
54
+ # -----------------------------------------------------------------------------
55
+
56
+ def freeze(module: nn.Module) -> None:
57
+ for p in module.parameters():
58
+ p.requires_grad_(False)
59
+
60
+
61
+ def unfreeze(module: nn.Module) -> None:
62
+ for p in module.parameters():
63
+ p.requires_grad_(True)
64
+
65
+
66
+ def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
67
+ if trainable_only:
68
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
69
+ return sum(p.numel() for p in module.parameters())
70
+
71
+
72
+ # -----------------------------------------------------------------------------
73
+ # Shape/signature helpers
74
+ # -----------------------------------------------------------------------------
75
+
76
+ def input_spec_vision(sample) -> Tuple[int, int, int]:
77
+ """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
78
+ if isinstance(sample, torch.Tensor):
79
+ B, C, H, W = sample.shape
80
+ return int(B), int(H), int(W)
81
+ if isinstance(sample, (tuple, list)) and len(sample) == 4:
82
+ B, C, H, W = sample
83
+ return int(B), int(H), int(W)
84
+ raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
85
+
86
+
87
+ # -----------------------------------------------------------------------------
88
+ # Rounding / multiples
89
+ # -----------------------------------------------------------------------------
90
+
91
+ def round_down_multiple(n: int, m: int) -> int:
92
+ if m is None or m <= 1:
93
+ return max(1, int(n))
94
+ n = int(n)
95
+ return max(m, (n // m) * m)
96
+
97
+
98
+ def clamp_int(v: int, lo: int, hi: int) -> int:
99
+ return max(lo, min(int(v), hi))
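Note that `round_down_multiple` never returns less than one full multiple. A torch-free copy showing the edge cases:

```python
def round_down_multiple(n, m):
    # Same contract as above: floor n to a multiple of m, but keep at least
    # one full multiple; m <= 1 (or None) disables snapping.
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    return max(m, (n // m) * m)
```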
100
+
101
+
102
+ # -----------------------------------------------------------------------------
103
+ # Slicing helpers
104
+ # -----------------------------------------------------------------------------
105
+
106
+ @torch.no_grad()
107
+ def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
108
+ W = mat.weight.detach()
109
+ b = mat.bias.detach() if mat.bias is not None else None
110
+ if keep_out is not None:
111
+ idx_out = torch.as_tensor(keep_out, device=W.device)
112
+ W = W.index_select(0, idx_out)
113
+ if b is not None:
114
+ b = b.index_select(0, idx_out)
115
+ if keep_in is not None:
116
+ idx_in = torch.as_tensor(keep_in, device=W.device)
117
+ W = W.index_select(1, idx_in)
118
+ out_f, in_f = W.shape
119
+ new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
120
+ new.weight.copy_(W)
121
+ if b is not None:
122
+ new.bias.copy_(b)
123
+ return new
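Callers typically build the `keep_in`/`keep_out` indices for `slice_linear` by expanding kept attention-head indices into contiguous channel ranges; a plain-Python sketch of that expansion (helper name hypothetical):

```python
def head_channels(keep_heads, head_dim):
    # Head h owns channels [h*head_dim, (h+1)*head_dim); concatenating the
    # kept ranges yields keep_out rows for q_proj / keep_in cols for o_proj.
    idx = []
    for h in keep_heads:
        idx.extend(range(h * head_dim, (h + 1) * head_dim))
    return idx
```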
124
+
125
+
126
+ # -----------------------------------------------------------------------------
127
+ # Copying & detaching models
128
+ # -----------------------------------------------------------------------------
129
+
130
+ def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
131
+ m = copy.deepcopy(module).cpu().eval()
132
+ return m
133
+
134
+
135
+ # -----------------------------------------------------------------------------
136
+ # Gradient utilities
137
+ # -----------------------------------------------------------------------------
138
+
139
+ def zero_if_any(params: Iterable[torch.Tensor]) -> None:
140
+ for p in params:
141
+ if p.grad is not None:
142
+ p.grad = None
143
+
144
+
145
+ def any_grad(params: Iterable[torch.Tensor]) -> bool:
146
+ for p in params:
147
+ if p.grad is not None:
148
+ return True
149
+ return False
150
+
151
+ # -----------------------------------------------------------------------------
152
+ # For fine-tuning
153
+ # -----------------------------------------------------------------------------
154
+
155
+ def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
156
+ """
157
+ Rebuild all parameters as fresh nn.Parameter tensors (detach+clone),
158
+ which drops any 'inference tensor' tag and re-enables autograd.
159
+ """
160
+ for mod in module.modules():
161
+ for name, p in list(mod._parameters.items()):
162
+ if p is None:
163
+ continue
164
+ new_p = nn.Parameter(p.detach().clone(), requires_grad=requires_grad)
165
+ setattr(mod, name, new_p)
166
+ return module
167
+
168
+
169
+ # -----------------------------------------------------------------------------
170
+ # Misc
171
+ # -----------------------------------------------------------------------------
172
+
173
+ @dataclass
174
+ class ExportRounding:
175
+ head_floor_post: int = 1
176
+ head_multiple_post: int = 1
177
+ ffn_min_keep_ratio_post: float = 0.0
178
+ ffn_snap_groups_post: int = 1
179
+
180
+
181
+ def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
182
+ B, C, H, W = sample_shape
183
+ return (
184
+ "ViT",
185
+ sample_shape,
186
+ int(getattr(cfg, "num_attention_heads", 12)),
187
+ int(getattr(cfg, "hidden_size", 768)),
188
+ int(getattr(cfg, "intermediate_size", 3072)),
189
+ int(getattr(cfg, "patch_size", 16)) if not isinstance(getattr(cfg, "patch_size", 16), (tuple, list)) else tuple(getattr(cfg, "patch_size", (16, 16))),
190
+ )
custom_code.py ADDED
@@ -0,0 +1 @@
1
+ # Marker file so Hub shows 'custom code' banner.
huggingface/.ipynb_checkpoints/llama-checkpoint.py ADDED
@@ -0,0 +1,607 @@
1
+ """HuggingFace LLaMA/Mistral adapter
2
+
3
+ Bridges the family-agnostic core (gates/export/proxy/train) to HF causal LMs
4
+ (LlamaForCausalLM / MistralForCausalLM, etc.).
5
+
6
+ Responsibilities
7
+ ----------------
8
+ - Attach gates to attention Q heads (and optional KV) + grouped MLP (SwiGLU)
9
+ - Provide a logits getter (student/teacher)
10
+ - Exporters:
11
+ * keep-all (unwrap gates, restore clean HF modules)
12
+ * pruned (slice q_proj/o_proj and SwiGLU up/gate/down; update HF metadata)
13
+ - Grid-search wrapper for post-export rounding/snap params
14
+
15
+ This adapter intentionally keeps the core unaware of LLaMA internals.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ # Ensure repo root on sys.path for absolute imports (core, adapters, data)
20
+ import sys, pathlib
21
+ sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
22
+
23
+ from dataclasses import dataclass
24
+ from typing import Optional, Sequence, Callable, Tuple
25
+
26
+ import copy
27
+ import math
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ # Core (absolute imports so running `-m examples.run_llama_optimize` works)
33
+ from core.gates import HeadGate, GroupGate
34
+ from core.export import (
35
+ ExportPolicy as CoreExportPolicy,
36
+ Rounding as CoreRounding,
37
+ keep_group_indices_from_gate,
38
+ slice_linear,
39
+ )
40
+ from core.utils import deepcopy_eval_cpu
41
+ from core.search_export import grid_search_latency
42
+
43
+ # -------------------------------------------------------------------------
44
+ # Configs
45
+ # -------------------------------------------------------------------------
46
+
47
+ @dataclass
48
+ class LlamaGatingConfig:
49
+ tau: float = 1.5
50
+ init_logit: float = 3.0
51
+ head_gating: bool = True
52
+ gate_kv: bool = False # optional: gate KV along with Q
53
+ ffn_group: int = 128 # SwiGLU groups
54
+ ffn_gating: bool = True
55
+ hard_eval: bool = True # use hard gates in eval forward
56
+
57
+
58
+ # -------------------------------------------------------------------------
59
+ # Helpers (GQA, rotary, cache-safe)
60
+ # -------------------------------------------------------------------------
61
+
62
+
63
+ def _last_nonpad_index(attn_mask: Optional[torch.Tensor], seq_len: int, device) -> torch.Tensor:
64
+ if attn_mask is None:
65
+ return torch.full((1,), seq_len - 1, device=device, dtype=torch.long) # will be expanded per-batch later
66
+ # attn_mask: [B, S] in {0,1}; index = token count - 1, so this assumes right padding
67
+ return (attn_mask.sum(dim=1) - 1).clamp(min=0).long()
68
+
69
+ def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
70
+ if n_rep == 1:
71
+ return x
72
+ B, Hkv, T, Dh = x.shape
73
+ return x.unsqueeze(2).expand(B, Hkv, n_rep, T, Dh).reshape(B, Hkv * n_rep, T, Dh)
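`_repeat_kv` duplicates each KV head `n_rep` times back-to-back (unsqueeze, expand, reshape), so K/V line up with the query heads. The resulting head order can be sketched in plain Python:

```python
def repeat_kv_heads(kv_heads, n_rep):
    # Each KV head is repeated n_rep times consecutively, matching the
    # [B, Hkv*n_rep, T, Dh] layout produced by expand + reshape.
    out = []
    for h in kv_heads:
        out.extend([h] * n_rep)
    return out
```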
74
+
75
+ try:
76
+ from transformers.cache_utils import Cache
77
+ except Exception:
78
+ class Cache: # type: ignore
79
+ pass
80
+
81
+
82
+ # -------------------------------------------------------------------------
83
+ # Gated attention wrapper (Llama/Mistral ready)
84
+ # -------------------------------------------------------------------------
85
+
86
+ class GatedSelfAttentionLLM(nn.Module):
87
+ """
88
+ Thin wrapper around HF Llama/Mistral attention module.
89
+
90
+ - Uses the base module's q_proj/k_proj/v_proj/o_proj
91
+ - Applies per-Q-head gates (and optional KV gates)
92
+ - Handles rotary and cache (tuple or HF Cache)
93
+ - Runs SDPA directly, then o_proj
94
+ """
95
+ def __init__(self, attn_container: nn.Module,
96
+ num_q_heads: int, num_kv_heads: int, head_dim: int,
97
+ cfg: LlamaGatingConfig, layer_idx: int):
98
+ super().__init__()
99
+ self.base_attn = attn_container
100
+ self.q_proj = attn_container.q_proj
101
+ self.k_proj = attn_container.k_proj
102
+ self.v_proj = attn_container.v_proj
103
+ self.o_proj = getattr(attn_container, "o_proj", getattr(attn_container, "out_proj", None))
104
+
105
+ self.num_q_heads = int(num_q_heads)
106
+ self.num_kv_heads = int(num_kv_heads)
107
+ self.head_dim = int(head_dim)
108
+ self.gate_kv = bool(cfg.gate_kv)
109
+ self.drop_p = float(getattr(attn_container, "attention_dropout",
110
+ getattr(attn_container, "attn_dropout",
111
+ getattr(attn_container, "dropout", 0.0))))
112
+ self.head_gate = HeadGate(num_heads=self.num_q_heads,
113
+ head_dim=self.head_dim,
114
+ tau=cfg.tau, init_logit=cfg.init_logit,
115
+ hard_during_eval=cfg.hard_eval)
116
+
117
+ # rotary helpers if present on base
118
+ self.rotary_emb = getattr(attn_container, "rotary_emb", None)
119
+ self.apply_rotary_pos_emb = getattr(attn_container, "apply_rotary_pos_emb", None)
120
+ self.layer_idx = int(layer_idx)
121
+
122
+ @property
123
+ def logits(self) -> torch.Tensor:
124
+ return self.head_gate.logits
125
+
126
+ def kept_heads_soft(self) -> torch.Tensor:
127
+ p = self.head_gate.probs().detach().float().view(-1)
128
+ if p.numel() == self.num_q_heads * self.head_dim:
129
+ p = p.view(self.num_q_heads, self.head_dim).mean(dim=1)
130
+ return p.sum()
131
+
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states: torch.Tensor, # [B,T,D]
136
+ attention_mask: Optional[torch.Tensor] = None, # additive mask [B,1,Tq,Tk] or None
137
+ position_ids: Optional[torch.Tensor] = None,
138
+ past_key_value = None, # tuple, list, Cache or None
139
+ output_attentions: bool = False,
140
+ use_cache: bool = False,
141
+ cache_position: Optional[torch.Tensor] = None,
142
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (cos, sin) from the rotary embedding
143
+ **kwargs,
144
+ ):
145
+ B, T, D = hidden_states.shape
146
+ Hq, Hkv, Dh = self.num_q_heads, self.num_kv_heads, self.head_dim
147
+ assert Hq * Dh == D, "hidden_size must equal num_heads * head_dim"
148
+ n_rep = max(1, Hq // Hkv)
149
+
150
+ # qkv projections
151
+ q = self.q_proj(hidden_states).view(B, T, Hq, Dh).transpose(1, 2) # [B,Hq,T,Dh]
152
+ k = self.k_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
153
+ v = self.v_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
154
+
155
+ # rotary
156
+ if (self.rotary_emb is not None) and (self.apply_rotary_pos_emb is not None):
157
+ Tpast = 0
158
+ if isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 2:
159
+ Tpast = int(past_key_value[0].size(2))
160
+ elif isinstance(past_key_value, Cache):
161
+ Tpast = int(cache_position.max().item() if cache_position is not None else 0)
162
+ seq_len = Tpast + T
163
+ try:
164
+ cos, sin = self.rotary_emb(v, seq_len=seq_len)
165
+ except TypeError:
166
+ cos, sin = self.rotary_emb(q, seq_len=seq_len)
167
+ # try rich signature first
168
+ try:
169
+ q, k = self.apply_rotary_pos_emb(
170
+ q, k, cos, sin,
171
+ position_ids=position_ids,
172
+ cache_position=cache_position,
173
+ position_embeddings=position_embeddings
174
+ )
175
+ except TypeError:
176
+ try:
177
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids)
178
+ except TypeError:
179
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin)
180
+
181
+ # cache merge
182
+ present = None
183
+ if past_key_value is None or (isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 0):
184
+ pass
185
+ elif isinstance(past_key_value, (tuple, list)):
186
+ pk, pv = past_key_value # [B,Hkv,Tpast,Dh]
187
+ k = torch.cat([pk, k], dim=2)
188
+ v = torch.cat([pv, v], dim=2)
189
+ present = (k, v) if use_cache else None
190
+ elif isinstance(past_key_value, Cache):
191
+ k, v = past_key_value.update(k, v, self.layer_idx, cache_position)
192
+ present = past_key_value
193
+
194
+ # gates
195
+ # g = self.head_gate.mask(self.training).view(1, Hq, 1, 1)
196
+ # ---- gates (supports per-head OR per-channel HeadGate) ----
197
+ m = self.head_gate.mask(self.training) # 1D tensor
198
+ m = m.detach() if not self.training else m
199
+ if m.numel() == Hq:
200
+ # per-head gating
201
+ gH = m.view(1, Hq, 1, 1) # [1,Hq,1,1]
202
+ q = q * gH
203
+ if self.gate_kv:
204
+ if n_rep == 1:
205
+ k = k * gH; v = v * gH
206
+ else:
207
+ g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
208
+ k = k * g_kv; v = v * g_kv
209
+ elif m.numel() == Hq * Dh:
210
+ # per-channel gating
211
+ gHD = m.view(1, Hq, 1, Dh) # [1,Hq,1,Dh]
212
+ q = q * gHD
213
+ if self.gate_kv:
214
+ # collapse to per-head for KV, then map to Hkv via amax over replicas
215
+ gH = gHD.amax(dim=-1, keepdim=True) # [1,Hq,1,1]
216
+ if n_rep == 1:
217
+ k = k * gH; v = v * gH
218
+ else:
219
+ g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
220
+ k = k * g_kv; v = v * g_kv
221
+ else:
222
+ raise RuntimeError(
223
+ f"HeadGate mask has {m.numel()} elems; expected {Hq} or {Hq*Dh}"
224
+ )
225
+
226
+
227
+ # GQA replicate KV to Q count
228
+ k = _repeat_kv(k, n_rep)
229
+ v = _repeat_kv(v, n_rep)
230
+
231
+ attn = F.scaled_dot_product_attention(
232
+ q, k, v,
233
+ attn_mask=attention_mask,
234
+ dropout_p=self.drop_p if self.training else 0.0,
235
+ is_causal=attention_mask is None  # SDPA rejects an explicit attn_mask combined with is_causal=True
236
+ )
237
+ out = attn.transpose(1, 2).contiguous().view(B, T, Hq * Dh)
238
+ out = self.o_proj(out)
239
+
240
+ attn_weights = None
241
+ # HF expects (attn_output, attn_weights, present_key_value) always
242
+ if output_attentions:
243
+ return (out, attn_weights, present)
244
+ else:
245
+ return (out, None, present)
246
+
247
+
248
+
249
+ # -------------------------------------------------------------------------
250
+ # Adapter
251
+ # -------------------------------------------------------------------------
252
+
253
+ class LlamaAdapter:
254
+ def __init__(self, model: nn.Module):
255
+ self.model = model
256
+ core = getattr(model, "model", model)
257
+ if not hasattr(core, "layers"):
258
+ raise ValueError("Provided model does not look like HF LLaMA/Mistral (missing .model.layers or .layers)")
259
+
260
+ # ---------- Gating attachment ----------
261
+ def attach_gates(self, cfg: LlamaGatingConfig) -> nn.Module:
262
+ m = self.model
263
+ core = getattr(m, "model", m)
264
+ layers = core.layers
265
+
266
+ Hq = int(core.config.num_attention_heads)
267
+ Hkv = int(getattr(core.config, "num_key_value_heads", Hq))
268
+ Dh = int(core.config.hidden_size // Hq)
269
+
270
+ for li, layer in enumerate(layers):
271
+ # Attention heads
272
+ if cfg.head_gating:
273
+ base = layer.self_attn
274
+ if not isinstance(base, GatedSelfAttentionLLM):
275
+ gated = GatedSelfAttentionLLM(
276
+ attn_container=base,
277
+ num_q_heads=Hq,
278
+ num_kv_heads=Hkv,
279
+ head_dim=Dh,
280
+ cfg=cfg,
281
+ layer_idx=li,
282
+ )
283
+ layer.self_attn = gated # route via our wrapper
284
+
285
+ # MLP grouped gating (SwiGLU)
286
+ if cfg.ffn_gating:
287
+ mlp = layer.mlp
288
+ I = int(mlp.up_proj.out_features)
289
+ assert I % cfg.ffn_group == 0, f"SwiGLU size {I} not divisible by group {cfg.ffn_group}"
290
+ if not hasattr(mlp, "neuron_gate"):
291
+ mlp.neuron_gate = GroupGate(
292
+ num_groups=I // cfg.ffn_group,
293
+ group_size=cfg.ffn_group,
294
+ tau=cfg.tau, init_logit=cfg.init_logit,
295
+ hard_during_eval=cfg.hard_eval,
296
+ )
297
+ if not hasattr(mlp, "_orig_forward"):
298
+ mlp._orig_forward = mlp.forward
299
+
300
+ def _gated_mlp_forward(this, x):
301
+ # LLaMA SwiGLU: out = down(silu(gate(x)) * up(x)); mask m scales kept groups
302
+ u = this.up_proj(x)
303
+ g = this.gate_proj(x)
304
+ m = this.neuron_gate.mask(this.training).view(1, 1, -1)
305
+ z = torch.nn.functional.silu(g) * (u * m)
306
+ return this.down_proj(z)
307
+
308
+ mlp.forward = _gated_mlp_forward.__get__(mlp, mlp.__class__)
309
+ return m
310
+
311
+ # ---------- Logits helper ----------
312
+ @staticmethod
313
+ def _last_token_index(attn_mask: torch.Tensor) -> torch.Tensor:
314
+ # attn_mask: [B, S] with 1 for tokens, 0 for padding
315
+ # returns [B] indices of last non-pad
316
+ # works for both bool and int masks
317
+ if attn_mask is None:
318
+ # no mask → use last position S-1
319
+ return None
320
+ if attn_mask.dtype != torch.long:
321
+ attn_mask = attn_mask.to(torch.long)
322
+ # idx = lengths - 1
323
+ return (attn_mask.sum(dim=-1) - 1).clamp_min(0)
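The index arithmetic in `_last_token_index` is just `token count - 1` on a 0/1 mask, clamped at zero, which assumes right padding. A plain-Python sketch:

```python
def last_nonpad_index(mask_row):
    # sum of ones = token count; last index = count - 1, clamped at 0.
    # Valid for right padding; a left-padded row would need S - 1 instead.
    return max(0, sum(mask_row) - 1)
```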
324
+
325
+ @staticmethod
326
+ def get_logits(model: nn.Module,
327
+ input_ids: torch.Tensor,
328
+ attention_mask: Optional[torch.Tensor] = None,
329
+ last_only: bool = True,
330
+ **forward_kwargs) -> torch.Tensor:
331
+ """
332
+ Returns logits. If last_only=True, computes ONLY the last-token logits by:
333
+ 1) getting hidden states from the base decoder,
334
+ 2) selecting last non-pad position per sample,
335
+ 3) projecting through lm_head on that 1 position.
336
+ This avoids allocating [B,S,V].
337
+ """
338
+ # (1) run base decoder, not the full CausalLM head
339
+ core = getattr(model, "model", None)
340
+ if core is None:
341
+ # fallback if the model is already a bare decoder (rare)
342
+ core = model
343
+
344
+ # We only need last_hidden_state; no cache; avoid building logits for all S
345
+ # return_dict=False to grab tuple and avoid extra allocations
346
+ outputs = core(
347
+ input_ids=input_ids,
348
+ attention_mask=attention_mask,
349
+ use_cache=False,
350
+ return_dict=False,
351
+ **forward_kwargs
352
+ )
353
+ hidden = outputs[0] # [B, S, D]
354
+
355
+ if not last_only:
356
+ # If someone explicitly wants all logits, fine:
357
+ return model.lm_head(hidden) # [B,S,V] (expensive!)
358
+
359
+ # (2) select last token per sample
360
+ B, S, D = hidden.shape
361
+ if attention_mask is None:
362
+ # simple "last index"
363
+ idx = torch.full((B,), S - 1, device=hidden.device, dtype=torch.long)
364
+ else:
365
+ idx = LlamaAdapter._last_token_index(attention_mask)
366
+
367
+ # gather last hidden: [B, D]
368
+ last_h = hidden[torch.arange(B, device=hidden.device), idx] # [B, D]
369
+ # (3) project to logits for that 1 position
370
+ last_logits = model.lm_head(last_h).unsqueeze(1) # [B,1,V]
371
+ return last_logits
372
+
373
+ # ---------- Exporters ----------
374
+ @staticmethod
375
+ @torch.no_grad()
376
+ def export_keepall(model_with_gates: nn.Module) -> nn.Module:
377
+ """
378
+ Unwrap attention wrappers; restore original MLP.forward; drop gates.
379
+ """
380
+ slim = deepcopy_eval_cpu(model_with_gates)
381
+ core = getattr(slim, "model", slim)
382
+ if not hasattr(core, "layers"):
383
+ return slim
384
+
385
+ for layer in core.layers:
386
+ # attention
387
+ attn = layer.self_attn
388
+ if isinstance(attn, GatedSelfAttentionLLM):
389
+ gat = attn
390
+ new_attn = copy.deepcopy(gat.base_attn)
391
+ # keep metadata consistent
392
+ if hasattr(new_attn, "num_heads"):
393
+ new_attn.num_heads = int(gat.num_q_heads)
394
+ if hasattr(new_attn, "num_key_value_heads"):
395
+ new_attn.num_key_value_heads = int(gat.num_kv_heads)
396
+ if hasattr(new_attn, "head_dim"):
397
+ new_attn.head_dim = int(gat.head_dim)
398
+ layer.self_attn = new_attn
399
+
400
+ # mlp
401
+ mlp = layer.mlp
402
+ if hasattr(mlp, "_orig_forward"):
403
+ mlp.forward = mlp._orig_forward
404
+ delattr(mlp, "_orig_forward")
405
+ if hasattr(mlp, "neuron_gate"):
406
+ delattr(mlp, "neuron_gate")
407
+
408
+ return slim
409
+
410
+ @staticmethod
411
+ @torch.no_grad()
412
+ def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
413
+ """
414
+ Produce a clean CPU eval model:
415
+ - Read gates to choose Q heads; slice q_proj rows and o_proj cols
416
+ - Snap kept heads to an LCM of (policy multiple, Hkv)
417
+ - Slice SwiGLU up/gate/down by groups
418
+ - Unwrap back to plain HF modules; update metadata
419
+ """
420
+ # Accept either CoreExportPolicy with per-axis rounding or family policy
421
+ if isinstance(policy, LlamaExportPolicy):
422
+ head_rounding = policy.head_rounding
423
+ ffn_rounding = policy.ffn_rounding
424
+ warmup_steps = policy.warmup_steps
425
+ else:
426
+ head_rounding = getattr(policy, "rounding", None)
427
+ ffn_rounding = getattr(policy, "rounding", None)
428
+ warmup_steps = int(getattr(policy, "warmup_steps", 0))
429
+
430
+ slim = deepcopy_eval_cpu(model_with_gates)
431
+ core = getattr(slim, "model", slim)
432
+ layers = getattr(core, "layers", None)
433
+ if layers is None:
434
+ return slim
435
+
436
+ warm = (step < warmup_steps)
437
+
438
+ def _lcm(a: int, b: int) -> int:
439
+ return abs(a * b) // math.gcd(max(a, 1), max(b, 1)) if a > 0 and b > 0 else max(a, b, 1)
440
+
441
+ for li, layer in enumerate(layers):
442
+ # ---- Attention (Q heads) ----
443
+ attn = layer.self_attn
444
+ if isinstance(attn, GatedSelfAttentionLLM):
445
+ gat = attn
446
+ base = gat.base_attn
447
+
448
+ Hq = int(gat.num_q_heads)
449
+ Hkv = int(gat.num_kv_heads)
450
+ Dh = int(gat.head_dim)
451
+
452
+ if warm:
453
+ keep_idx = torch.arange(Hq)
454
+ else:
455
+ # Build a "per-head" proxy gate if base gate is per-channel.
456
+ base_logits = gat.head_gate.logits.detach().float().view(-1)
457
+ tau = float(getattr(gat.head_gate, "tau", 1.0))
458
+
459
+ if base_logits.numel() == Hq:
460
+ # Native per-head gate: use as-is
461
+ proxy_gate = gat.head_gate
462
+ keep_idx = keep_group_indices_from_gate(
463
+ proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
464
+ )
465
+ elif base_logits.numel() == Hq * Dh:
466
+ # Collapse per-channel → per-head (mean; or use .amax for stricter)
467
+ per_head_logits = base_logits.view(Hq, Dh).mean(dim=1)
468
+
469
+ class _PerHeadProxyGate:
470
+ def __init__(self, logits, tau):
471
+ self.logits = logits
472
+ self.tau = tau
473
+ self.num_groups = logits.numel()
474
+ self.group_size = 1
475
+
476
+ proxy_gate = _PerHeadProxyGate(per_head_logits, tau)
477
+ keep_idx = keep_group_indices_from_gate(
478
+ proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
479
+ )
480
+ else:
481
+ raise RuntimeError(
482
+ f"Unexpected HeadGate logits len {base_logits.numel()} vs H={Hq} or H*Dh={Hq*Dh}"
483
+ )
484
+
485
+ # Enforce LCM with GQA (Hkv) via truncation to floor-multiple
489
+
490
+                 pol_mult = getattr(head_rounding, "multiple_groups", 1)
+                 snap = _lcm(int(pol_mult), max(1, Hkv))
+                 if keep_idx.numel() % snap != 0:
+                     k = (keep_idx.numel() // snap) * snap
+                     k = max(snap, min(Hq, k))
+                     # recompute top-k with the same per-head criterion used above
+                     if base_logits.numel() == Hq * Dh:
+                         scores = per_head_logits
+                     else:
+                         scores = base_logits
+                     keep_idx = torch.topk(scores, k=k, largest=True).indices.sort().values
+
+                 H_keep = int(keep_idx.numel())
+                 # channels for q/o slicing
+                 ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in keep_idx]).long()
+
+                 # slice wrapper linears
+                 gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
+                 gat.o_proj = slice_linear(gat.o_proj, keep_in=ch_idx)
+
+                 # transplant into a clean HF attention
+                 new_attn = copy.deepcopy(base)
+                 if hasattr(new_attn, "q_proj"):
+                     new_attn.q_proj = gat.q_proj
+                 if hasattr(new_attn, "o_proj"):
+                     new_attn.o_proj = gat.o_proj
+                 elif hasattr(new_attn, "out_proj"):
+                     new_attn.out_proj = gat.o_proj
+
+                 # update metadata
+                 if hasattr(new_attn, "num_heads"):
+                     new_attn.num_heads = int(H_keep)
+                 if hasattr(new_attn, "num_key_value_heads"):
+                     new_attn.num_key_value_heads = int(Hkv)
+                 if hasattr(new_attn, "head_dim"):
+                     new_attn.head_dim = int(Dh)
+                 if hasattr(core.config, "hidden_size"):
+                     core.config.hidden_size = int(H_keep * Dh)
+
+                 layer.self_attn = new_attn  # unwrap
+
+             # ---- MLP (SwiGLU grouped) ----
+             mlp = layer.mlp
+             g = getattr(mlp, "neuron_gate", None)
+             if g is not None:
+                 grp_idx = keep_group_indices_from_gate(
+                     g, policy=policy, step=step, custom_rounding=ffn_rounding,
+                 )
+                 group = int(g.group_size)  # GroupGate exposes group_size
+                 keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
+
+                 mlp.up_proj = slice_linear(mlp.up_proj, keep_out=keep_exp)
+                 mlp.gate_proj = slice_linear(mlp.gate_proj, keep_out=keep_exp)
+                 mlp.down_proj = slice_linear(mlp.down_proj, keep_in=keep_exp)
+
+                 # restore clean forward & drop gate
+                 if hasattr(mlp, "_orig_forward"):
+                     mlp.forward = mlp._orig_forward
+                     delattr(mlp, "_orig_forward")
+                 if hasattr(mlp, "neuron_gate"):
+                     delattr(mlp, "neuron_gate")
+
+         return slim
+
+
+ # -------------------------------------------------------------------------
+ # Export policy (allows different rounding for heads vs FFN)
+ # -------------------------------------------------------------------------
+
+ from dataclasses import field  # default_factory avoids sharing one mutable Rounding default across instances
+
+ @dataclass
+ class LlamaExportPolicy:
+     warmup_steps: int = 0
+     head_rounding: CoreRounding = field(default_factory=CoreRounding)  # e.g., CoreRounding(floor=8, multiple=8)
+     ffn_rounding: CoreRounding = field(default_factory=CoreRounding)   # e.g., CoreRounding(min_keep_ratio=0.8, multiple=32)
+
+
+ # -------------------------------------------------------------------------
+ # Grid-search convenience
+ # -------------------------------------------------------------------------
+
+ @dataclass
+ class LlamaGrid:
+     head_multiple_grid: Optional[Sequence[int]] = (1, 2, 4, 8)
+     ffn_snap_grid: Sequence[int] = (1, 32, 64, 128)
+
+
+ def llama_search_best_export(
+     model_with_gates: nn.Module,
+     *,
+     export_fn: Callable[[nn.Module, CoreExportPolicy, int], nn.Module],
+     num_q_heads: int,
+     num_kv_heads: int,
+     step: int,
+     batch_shape: Tuple[int, int],  # (B, S) for text
+     grid: Optional[LlamaGrid] = None,
+     device: str = "cuda",
+     measure_settings=None,
+     make_policy: Optional[Callable[[int, int], object]] = None,
+ ):
+     """Convenience wrapper for LLaMA-style search.
+
+     Uses the same `grid_search_latency` as ViT; we just feed head/FFN grids.
+     """
+     g = grid or LlamaGrid()
+     head_grid = g.head_multiple_grid or [1]
+     ffn_grid = list(g.ffn_snap_grid)
+
+     return grid_search_latency(
+         model_with_gates,
+         export_fn,
+         head_multiples=head_grid,
+         ffn_snaps=ffn_grid,
+         step=step,
+         batch_shape=batch_shape,  # adapter's runner should interpret this as (B, S)
+         measure_settings=measure_settings,
+         device=device,
+         make_policy=make_policy,
+     )
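The snapping logic above rounds the kept Q-head count down to a multiple of lcm(head-rounding multiple, Hkv) so GQA groups stay intact, then clamps so at least one snap group survives. A self-contained sketch of just that arithmetic (`snap_head_count` is a hypothetical name, not part of the repo):

```python
import math

def snap_head_count(n_keep: int, policy_multiple: int, num_kv_heads: int, num_q_heads: int) -> int:
    """Floor n_keep to a multiple of lcm(policy_multiple, num_kv_heads), clamped to [snap, num_q_heads]."""
    snap = math.lcm(int(policy_multiple), max(1, num_kv_heads))
    if n_keep % snap == 0:
        return n_keep
    k = (n_keep // snap) * snap       # floor to the snap multiple
    return max(snap, min(num_q_heads, k))  # never drop below one snap group
```

For example, keeping 13 of 32 Q heads with 4 KV heads snaps down to 12, while keeping only 3 heads clamps up to one full snap group of lcm(2, 4) = 4.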
huggingface/.ipynb_checkpoints/vit-checkpoint.py ADDED
@@ -0,0 +1,383 @@
+ """HuggingFace ViT adapter
+
+ Bridges the family-agnostic core (gates/export/proxy/train) to ViT-like models
+ from Hugging Face (`ViTModel`, `ViTForImageClassification`, DeiT, etc.).
+
+ Responsibilities
+ ----------------
+ - Attach gates to attention heads and MLP hidden units in groups
+ - Provide logits getters for student/teacher
+ - Export helpers: keep-all (remove gates), and pruned (slice weights + metadata)
+
+ This adapter intentionally keeps the core unaware of ViT internals.
+ """
+ from __future__ import annotations
+
+ # Ensure repo root on sys.path for absolute imports (core, adapters, data)
+ import sys, pathlib
+ sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
+
+ from dataclasses import dataclass, field
+ from typing import Callable, Optional, Sequence, Tuple
+
+ import copy
+ import torch
+ import torch.nn as nn
+
+ # NOTE: absolute imports so running `-m examples.run_vit_optimize` works without package install
+ from core.gates import HeadGate, GroupGate
+ from core.export import (
+     ExportPolicy as CoreExportPolicy,
+     Rounding as CoreRounding,
+     keep_group_indices_from_gate,
+     keep_element_indices_from_gate,
+     slice_linear,
+ )
+ from core.utils import deepcopy_eval_cpu
+ from core.search_export import grid_search_latency
+
+ # -----------------------------------------------------------------------------
+ # Config
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ViTGatingConfig:
+     tau: float = 1.5
+     init_logit: float = 3.0
+     head_gating: bool = True
+     ffn_group: int = 16
+     ffn_gating: bool = True
+     hard_eval: bool = True  # use hard masks in eval mode during forward
+
+
+ def _encoder_layers(m: nn.Module):
+     """Return the sequence of Transformer blocks for an HF ViT.
+
+     Supports:
+       - ViTModel: m.encoder.layer
+       - ViTForImageClassification: m.vit.encoder.layer
+     """
+     # ViTModel path
+     enc = getattr(m, "encoder", None)
+     if enc is not None and hasattr(enc, "layer"):
+         return enc.layer
+
+     # ViTForImageClassification path
+     vit = getattr(m, "vit", None)
+     if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
+         return vit.encoder.layer
+
+     raise ValueError("Provided model does not look like a HF ViT (missing *.encoder.layer)")
+
+
+ # -----------------------------------------------------------------------------
+ # Gated attention wrapper
+ # -----------------------------------------------------------------------------
+
+ class GatedSelfAttentionHF(nn.Module):
+     """A thin wrapper around HF ViT self-attention that multiplies in per-head gates.
+
+     It keeps references to the underlying query/key/value `nn.Linear` layers and
+     the output projection, while exposing a `HeadGate` in `head_gate`.
+     """
+
+     def __init__(self, attn_container: nn.Module, num_heads: int, head_dim: int, cfg: ViTGatingConfig):
+         super().__init__()
+         base_attn = attn_container.attention  # ViTSdpaSelfAttention or ViTSelfAttention
+         out_proj = attn_container.output.dense
+
+         self.base_attn = base_attn
+         self.out_proj = out_proj
+
+         self.q_proj = base_attn.query
+         self.k_proj = base_attn.key
+         self.v_proj = base_attn.value
+
+         self.num_heads = int(num_heads)
+         self.head_dim = int(head_dim)
+         self.drop_p = getattr(base_attn, "dropout", nn.Dropout(0.0)).p
+
+         self.head_gate = HeadGate(num_heads=self.num_heads, head_dim=self.head_dim, tau=cfg.tau,
+                                   init_logit=cfg.init_logit, hard_during_eval=cfg.hard_eval)
+
+     @property
+     def logits(self) -> torch.Tensor:
+         return self.head_gate.logits
+
+     def kept_heads_soft(self) -> torch.Tensor:
+         return self.head_gate.probs().sum()
+
+     def forward(self, hidden_states, head_mask=None):
+         B, N, _ = hidden_states.shape
+         H, Dh = self.num_heads, self.head_dim
+
+         wdev = self.q_proj.weight.device
+         if hidden_states.device != wdev:
+             hidden_states = hidden_states.to(wdev, non_blocking=True)
+
+         q_lin = self.q_proj(hidden_states)
+         k_lin = self.k_proj(hidden_states)
+         v_lin = self.v_proj(hidden_states)
+
+         q = q_lin.view(B, N, H, Dh).transpose(1, 2)
+         k = k_lin.view(B, N, H, Dh).transpose(1, 2)
+         v = v_lin.view(B, N, H, Dh).transpose(1, 2)
+
+         logits = self.head_gate.logits
+         tau = float(self.head_gate.tau)
+         if self.training:
+             # Gumbel/logistic-sigmoid sample with a straight-through hard threshold
+             u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
+             s = u.log() - (1 - u).log()  # logistic noise
+             y = torch.sigmoid((logits + s) / tau)
+             g_head = ((y > 0.5).to(y.dtype) - y).detach() + y
+         else:
+             if getattr(self.head_gate, "hard_during_eval", True):
+                 g_head = (logits > 0).to(logits.dtype)
+             else:
+                 g_head = torch.sigmoid(logits / tau)
+         g = g_head.view(1, H, 1, 1)
+
+         q = q * g; k = k * g; v = v * g
+
+         attn_out = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, dropout_p=self.drop_p if self.training else 0.0
+         )  # [B, H, N, Dh]
+
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, H * Dh)
+         attn_out = self.out_proj(attn_out)
+         return attn_out, None
+
+
+ # -----------------------------------------------------------------------------
+ # Adapter
+ # -----------------------------------------------------------------------------
+
+ class ViTAdapter:
+     def __init__(self, model: nn.Module):
+         self.model = model
+         _ = _encoder_layers(model)  # validate the model shape early
+
+     # ---------- Gating attachment ----------
+     def attach_gates(self, cfg: ViTGatingConfig) -> nn.Module:
+         m = self.model
+         H = int(getattr(m.config, "num_attention_heads", 12))
+         D = int(getattr(m.config, "hidden_size", 768))
+         Dh = D // H
+
+         for layer in _encoder_layers(m):
+             # Attention heads
+             if cfg.head_gating:
+                 attn_container = layer.attention
+                 if not isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
+                     gated = GatedSelfAttentionHF(attn_container, H, Dh, cfg)
+                     attn_container.attention = gated
+
+             # FFN hidden (grouped)
+             if cfg.ffn_gating:
+                 inter = layer.intermediate
+                 d_ff = int(inter.dense.out_features)
+                 assert d_ff % cfg.ffn_group == 0, f"FFN size {d_ff} not divisible by group {cfg.ffn_group}"
+                 if not hasattr(inter, "neuron_gate"):
+                     inter.neuron_gate = GroupGate(num_groups=d_ff // cfg.ffn_group, group_size=cfg.ffn_group,
+                                                   tau=cfg.tau, init_logit=cfg.init_logit,
+                                                   hard_during_eval=cfg.hard_eval)
+                 # Monkey-patch forward to apply the mask after activation (keeps HF shapes)
+                 if not hasattr(inter, "_orig_forward"):
+                     inter._orig_forward = inter.forward
+
+                     def _gated_forward(this, x):
+                         h = this.dense(x)
+                         h = this.intermediate_act_fn(h)
+                         msk = this.neuron_gate.mask(this.training).view(1, 1, -1)
+                         return h * msk
+
+                     inter.forward = _gated_forward.__get__(inter, inter.__class__)
+         return m
+
+     # ---------- Logits helpers ----------
+     @staticmethod
+     def get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
+         out = model(pixel_values=x)
+         if hasattr(out, "logits"):
+             return out.logits  # ViTForImageClassification path
+         if hasattr(out, "last_hidden_state"):  # ViTModel path (needs external head)
+             if head is None:
+                 raise ValueError("Provide a classification head when using ViTModel without logits.")
+             cls_tok = out.last_hidden_state[:, 0, :]
+             if next(head.parameters(), torch.tensor([], device=cls_tok.device)).device != cls_tok.device:
+                 head = head.to(cls_tok.device)
+             return head(cls_tok)
+         raise ValueError("Model output lacks logits and last_hidden_state.")
+
+     # ---------- Exporters ----------
+     @staticmethod
+     @torch.no_grad()
+     def export_keepall(model_with_gates: nn.Module) -> nn.Module:
+         slim = deepcopy_eval_cpu(model_with_gates)
+         for layer in _encoder_layers(slim):
+             # Attention: unwrap gate
+             attn_container = layer.attention
+             if isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
+                 gat = attn_container.attention
+                 new_attn = copy.deepcopy(gat.base_attn)
+                 # restore HF metadata if present
+                 if hasattr(new_attn, "num_attention_heads"):
+                     new_attn.num_attention_heads = int(gat.num_heads)
+                 if hasattr(new_attn, "attention_head_size"):
+                     new_attn.attention_head_size = int(gat.head_dim)
+                 if hasattr(new_attn, "all_head_size"):
+                     new_attn.all_head_size = int(gat.num_heads * gat.head_dim)
+                 attn_container.attention = new_attn
+             # FFN: restore original forward and drop gate
+             inter = layer.intermediate
+             if hasattr(inter, "_orig_forward"):
+                 inter.forward = inter._orig_forward
+                 delattr(inter, "_orig_forward")
+             if hasattr(inter, "neuron_gate"):
+                 delattr(inter, "neuron_gate")
+         return slim
+
+     @staticmethod
+     @torch.no_grad()
+     def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
+         # Support both CoreExportPolicy (single rounding) and ViTExportPolicy (per-axis)
+         if isinstance(policy, ViTExportPolicy):
+             head_rounding = policy.head_rounding
+             ffn_rounding = policy.ffn_rounding
+             warmup_steps = policy.warmup_steps
+         else:
+             # fall back to a single rounding for both axes
+             head_rounding = getattr(policy, "rounding", None)
+             ffn_rounding = getattr(policy, "rounding", None)
+             warmup_steps = int(getattr(policy, "warmup_steps", 0))
+
+         slim = deepcopy_eval_cpu(model_with_gates)
+         warm = (step < warmup_steps)
+
+         for layer in _encoder_layers(slim):
+             # --- Attention heads ---
+             attn_container = layer.attention
+             gat = getattr(attn_container, "attention", None)
+             if isinstance(gat, GatedSelfAttentionHF):
+                 # decide head indices via our helper; warmup is honored via `step`
+                 grp_idx = keep_group_indices_from_gate(
+                     gat.head_gate,
+                     policy=policy,
+                     step=step,
+                     custom_rounding=head_rounding,
+                 )
+                 H_keep = int(grp_idx.numel())
+                 Dh = int(gat.head_dim)
+
+                 ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in grp_idx]).long()
+                 gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
+                 gat.k_proj = slice_linear(gat.k_proj, keep_out=ch_idx)
+                 gat.v_proj = slice_linear(gat.v_proj, keep_out=ch_idx)
+                 attn_container.output.dense = slice_linear(attn_container.output.dense, keep_in=ch_idx)
+
+                 new_attn = copy.deepcopy(gat.base_attn)
+                 new_attn.query = gat.q_proj
+                 new_attn.key = gat.k_proj
+                 new_attn.value = gat.v_proj
+                 if hasattr(new_attn, "num_attention_heads"):
+                     new_attn.num_attention_heads = H_keep
+                 if hasattr(new_attn, "attention_head_size"):
+                     new_attn.attention_head_size = Dh
+                 if hasattr(new_attn, "all_head_size"):
+                     new_attn.all_head_size = H_keep * Dh
+                 attn_container.attention = new_attn
+
+             # --- FFN groups ---
+             inter, out = layer.intermediate, layer.output
+             g = getattr(inter, "neuron_gate", None)
+             if g is not None:
+                 grp_idx = keep_group_indices_from_gate(
+                     g,
+                     policy=policy,
+                     step=step,
+                     custom_rounding=ffn_rounding,
+                 )
+                 group = int(g.group_size)
+                 keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
+                 inter.dense = slice_linear(inter.dense, keep_out=keep_exp)
+                 out.dense = slice_linear(out.dense, keep_in=keep_exp)
+
+                 # restore the class-level forward & drop gate bookkeeping
+                 inter.forward = inter.__class__.forward.__get__(inter, inter.__class__)
+                 if hasattr(inter, "neuron_gate"):
+                     delattr(inter, "neuron_gate")
+                 if hasattr(inter, "_orig_forward"):
+                     delattr(inter, "_orig_forward")
+
+         return slim
+
+
+ # -----------------------------------------------------------------------------
+ # Export policy
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ViTExportPolicy:
+     """ViT-specific export policy that allows different rounding for heads vs FFN."""
+     warmup_steps: int = 0
+     # default_factory avoids sharing one mutable Rounding instance across policies
+     head_rounding: CoreRounding = field(default_factory=CoreRounding)
+     ffn_rounding: CoreRounding = field(default_factory=CoreRounding)
+
+
+ @dataclass
+ class ViTGrid:
+     head_multiple_grid: Optional[Sequence[int]] = (2, 4, 8)
+     ffn_snap_grid: Sequence[int] = (1, 8)
+     # head_multiple_grid: Optional[Sequence[int]] = None  # default --> 1..num_heads
+     # ffn_snap_grid: Sequence[int] = (1, 2, 4, 8, 16)
+
+
+ def vit_search_best_export(
+     model_with_gates: nn.Module,
+     *,
+     export_fn: ExportFn,
+     num_heads: int,
+     step: int,
+     batch_shape: Tuple[int, int, int, int],
+     grid: Optional[ViTGrid] = None,
+     device: str = "cuda",
+     measure_settings: Optional[ProfileSettings] = None,
+     make_policy: Optional[Callable[[int, int], object]] = None,
+ ) -> SearchResult:
+     """Convenience wrapper for ViT-style search.
+
+     If `make_policy` is not provided, the caller's adapter should accept a
+     policy with separate head/FFN rounding; see `adapters.huggingface.vit.ViTExportPolicy`.
+     """
+     g = grid or ViTGrid()
+     head_grid = g.head_multiple_grid or list(range(1, int(num_heads) + 1))
+     ffn_grid = list(g.ffn_snap_grid)
+
+     return grid_search_latency(
+         model_with_gates,
+         export_fn,
+         head_multiples=head_grid,
+         ffn_snaps=ffn_grid,
+         step=step,
+         batch_shape=batch_shape,
+         measure_settings=measure_settings,
+         device=device,
+         make_policy=make_policy,
+     )
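The training branch of `GatedSelfAttentionHF.forward` samples gates with a Gumbel/logistic-sigmoid relaxation and a straight-through hard threshold. A minimal sketch of just the sampling arithmetic in plain Python (gradient bookkeeping omitted since torch handles it; `gumbel_sigmoid_hard` is a hypothetical name, not part of the repo):

```python
import math
import random

def gumbel_sigmoid_hard(logit: float, tau: float, rng: random.Random) -> float:
    """Sample a hard {0,1} gate; in torch the straight-through trick keeps grads via the soft value y."""
    u = min(max(rng.random(), 1e-6), 1 - 1e-6)
    s = math.log(u) - math.log(1.0 - u)               # logistic noise
    y = 1.0 / (1.0 + math.exp(-(logit + s) / tau))    # soft gate in (0, 1)
    return 1.0 if y > 0.5 else 0.0                    # hard forward value

rng = random.Random(0)
samples = [gumbel_sigmoid_hard(3.0, 1.5, rng) for _ in range(1000)]
open_rate = sum(samples) / len(samples)
```

With `init_logit = 3.0` the gate starts almost always open (P(open) = sigmoid(3) ≈ 0.95), so pruning pressure has to push logits down before heads actually drop.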
huggingface/__init__.py ADDED
File without changes
huggingface/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
huggingface/__pycache__/vit.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
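Both adapters recover the last non-pad position from an attention mask as `(mask.sum(-1) - 1).clamp(0)`, which assumes right padding (the index equals the sequence length minus one). The same index arithmetic on plain lists (`last_nonpad_indices` is a hypothetical name, not part of the repo):

```python
def last_nonpad_indices(attention_mask):
    """attention_mask: list of 0/1 rows; returns the last non-pad index per row (right padding assumed)."""
    return [max(sum(row) - 1, 0) for row in attention_mask]

mask = [
    [1, 1, 1, 0, 0],  # length 3 -> index 2
    [1, 1, 1, 1, 1],  # full row -> index 4
    [0, 0, 0, 0, 0],  # degenerate all-pad row -> clamped to 0
]
idx = last_nonpad_indices(mask)
```

Note the right-padding assumption: for a left-padded row like `[0, 0, 1, 1, 1]` the formula returns 2, not the true last index 4.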
 
huggingface/llama.py ADDED
@@ -0,0 +1,607 @@
1
+ """HuggingFace LLaMA/Mistral adapter
2
+
3
+ Bridges the family-agnostic core (gates/export/proxy/train) to HF causal LMs
4
+ (LlamaForCausalLM / MistralForCausalLM, etc.).
5
+
6
+ Responsibilities
7
+ ----------------
8
+ - Attach gates to attention Q heads (and optional KV) + grouped MLP (SwiGLU)
9
+ - Provide a logits getter (student/teacher)
10
+ - Exporters:
11
+ * keep-all (unwrap gates, restore clean HF modules)
12
+ * pruned (slice q_proj/o_proj and SwiGLU up/gate/down; update HF metadata)
13
+ - Grid-search wrapper for post-export rounding/snap params
14
+
15
+ This adapter intentionally keeps the core unaware of LLaMA internals.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ # Ensure repo root on sys.path for absolute imports (core, adapters, data)
20
+ import sys, pathlib
21
+ sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
22
+
23
+ from dataclasses import dataclass
24
+ from typing import Optional, Sequence, Callable, Tuple
25
+
26
+ import copy
27
+ import math
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ # Core (absolute imports so running `-m examples.run_llama_optimize` works)
33
+ from core.gates import HeadGate, GroupGate
34
+ from core.export import (
35
+ ExportPolicy as CoreExportPolicy,
36
+ Rounding as CoreRounding,
37
+ keep_group_indices_from_gate,
38
+ slice_linear,
39
+ )
40
+ from core.utils import deepcopy_eval_cpu
41
+ from core.search_export import grid_search_latency
42
+
43
+ # -------------------------------------------------------------------------
44
+ # Configs
45
+ # -------------------------------------------------------------------------
46
+
47
+ @dataclass
48
+ class LlamaGatingConfig:
49
+ tau: float = 1.5
50
+ init_logit: float = 3.0
51
+ head_gating: bool = True
52
+ gate_kv: bool = False # optional: gate KV along with Q
53
+ ffn_group: int = 128 # SwiGLU groups
54
+ ffn_gating: bool = True
55
+ hard_eval: bool = True # use hard gates in eval forward
56
+
57
+
58
+ # -------------------------------------------------------------------------
59
+ # Helpers (GQA, rotary, cache-safe)
60
+ # -------------------------------------------------------------------------
61
+
62
+
63
+ def _last_nonpad_index(attn_mask: Optional[torch.Tensor], seq_len: int, device) -> torch.Tensor:
64
+ if attn_mask is None:
65
+ return torch.full((1,), seq_len - 1, device=device, dtype=torch.long) # will be expanded per-batch later
66
+ # attn_mask: [B, S] in {0,1}; works for left/right padding
67
+ return (attn_mask.sum(dim=1) - 1).clamp(min=0).long()
68
+
69
+ def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
70
+ if n_rep == 1:
71
+ return x
72
+ B, Hkv, T, Dh = x.shape
73
+ return x.unsqueeze(2).expand(B, Hkv, n_rep, T, Dh).reshape(B, Hkv * n_rep, T, Dh)
74
+
75
+ try:
76
+ from transformers.cache_utils import Cache
77
+ except Exception:
78
+ class Cache: # type: ignore
79
+ pass
80
+
81
+
82
+ # -------------------------------------------------------------------------
83
+ # Gated attention wrapper (Llama/Mistral ready)
84
+ # -------------------------------------------------------------------------
85
+
86
+ class GatedSelfAttentionLLM(nn.Module):
87
+ """
88
+ Thin wrapper around HF Llama/Mistral attention module.
89
+
90
+ - Uses the base module's q_proj/k_proj/v_proj/o_proj
91
+ - Applies per-Q-head gates (and optional KV gates)
92
+ - Handles rotary and cache (tuple or HF Cache)
93
+ - Runs SDPA directly, then o_proj
94
+ """
95
+ def __init__(self, attn_container: nn.Module,
96
+ num_q_heads: int, num_kv_heads: int, head_dim: int,
97
+ cfg: LlamaGatingConfig, layer_idx: int):
98
+ super().__init__()
99
+ self.base_attn = attn_container
100
+ self.q_proj = attn_container.q_proj
101
+ self.k_proj = attn_container.k_proj
102
+ self.v_proj = attn_container.v_proj
103
+ self.o_proj = getattr(attn_container, "o_proj", getattr(attn_container, "out_proj", None))
104
+
105
+ self.num_q_heads = int(num_q_heads)
106
+ self.num_kv_heads = int(num_kv_heads)
107
+ self.head_dim = int(head_dim)
108
+ self.gate_kv = bool(cfg.gate_kv)
109
+ self.drop_p = float(getattr(attn_container, "attention_dropout",
110
+ getattr(attn_container, "attn_dropout",
111
+ getattr(attn_container, "dropout", 0.0))))
112
+ self.head_gate = HeadGate(num_heads=self.num_q_heads,
113
+ head_dim=self.head_dim,
114
+ tau=cfg.tau, init_logit=cfg.init_logit,
115
+ hard_during_eval=cfg.hard_eval)
116
+
117
+ # rotary helpers if present on base
118
+ self.rotary_emb = getattr(attn_container, "rotary_emb", None)
119
+ self.apply_rotary_pos_emb = getattr(attn_container, "apply_rotary_pos_emb", None)
120
+ self.layer_idx = int(layer_idx)
121
+
122
+ @property
123
+ def logits(self) -> torch.Tensor:
124
+ return self.head_gate.logits
125
+
126
+ def kept_heads_soft(self) -> torch.Tensor:
127
+ p = self.head_gate.probs().detach().float().view(-1)
128
+ if p.numel() == self.num_q_heads * self.head_dim:
129
+ p = p.view(self.num_q_heads, self.head_dim).mean(dim=1)
130
+ return p.sum()
131
+
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states: torch.Tensor, # [B,T,D]
136
+ attention_mask: Optional[torch.Tensor] = None, # additive mask [B,1,Tq,Tk] or None
137
+ position_ids: Optional[torch.Tensor] = None,
138
+ past_key_value = None, # tuple, list, Cache or None
139
+ output_attentions: bool = False,
140
+ use_cache: bool = False,
141
+ cache_position: Optional[torch.Tensor] = None,
142
+ position_embeddings: Optional[torch.Tensor] = None,
143
+ **kwargs,
144
+ ):
145
+ B, T, D = hidden_states.shape
146
+ Hq, Hkv, Dh = self.num_q_heads, self.num_kv_heads, self.head_dim
147
+ assert Hq * Dh == D, "hidden_size must equal num_heads * head_dim"
148
+ n_rep = max(1, Hq // Hkv)
149
+
150
+ # qkv projections
151
+ q = self.q_proj(hidden_states).view(B, T, Hq, Dh).transpose(1, 2) # [B,Hq,T,Dh]
152
+ k = self.k_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
153
+ v = self.v_proj(hidden_states).view(B, T, Hkv, Dh).transpose(1, 2) # [B,Hkv,T,Dh]
154
+
155
+ # rotary
156
+ if (self.rotary_emb is not None) and (self.apply_rotary_pos_emb is not None):
157
+ Tpast = 0
158
+ if isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 2:
159
+ Tpast = int(past_key_value[0].size(2))
160
+ elif isinstance(past_key_value, Cache):
161
+ Tpast = int(cache_position.max().item() if cache_position is not None else 0)
162
+ seq_len = Tpast + T
163
+ try:
164
+ cos, sin = self.rotary_emb(v, seq_len=seq_len)
165
+ except TypeError:
166
+ cos, sin = self.rotary_emb(q, seq_len=seq_len)
167
+ # try rich signature first
168
+ try:
169
+ q, k = self.apply_rotary_pos_emb(
170
+ q, k, cos, sin,
171
+ position_ids=position_ids,
172
+ cache_position=cache_position,
173
+ position_embeddings=position_embeddings
174
+ )
175
+ except TypeError:
176
+ try:
177
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids)
178
+ except TypeError:
179
+ q, k = self.apply_rotary_pos_emb(q, k, cos, sin)
180
+
181
+ # cache merge
182
+ present = None
183
+ if past_key_value is None or (isinstance(past_key_value, (tuple, list)) and len(past_key_value) == 0):
184
+ pass
185
+ elif isinstance(past_key_value, (tuple, list)):
186
+ pk, pv = past_key_value # [B,Hkv,Tpast,Dh]
187
+ k = torch.cat([pk, k], dim=2)
188
+ v = torch.cat([pv, v], dim=2)
189
+ present = (k, v) if use_cache else None
190
+ elif isinstance(past_key_value, Cache):
191
+ k, v = past_key_value.update(k, v, self.layer_idx, cache_position)
192
+ present = past_key_value
193
+
194
+ # gates
195
+ # g = self.head_gate.mask(self.training).view(1, Hq, 1, 1)
196
+ # ---- gates (supports per-head OR per-channel HeadGate) ----
197
+ m = self.head_gate.mask(self.training) # 1D tensor
198
+ m = m.detach() if not self.training else m
199
+ if m.numel() == Hq:
200
+ # per-head gating
201
+ gH = m.view(1, Hq, 1, 1) # [1,Hq,1,1]
202
+ q = q * gH
203
+ if self.gate_kv:
204
+ if n_rep == 1:
205
+ k = k * gH; v = v * gH
206
+ else:
207
+ g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
208
+ k = k * g_kv; v = v * g_kv
209
+ elif m.numel() == Hq * Dh:
210
+ # per-channel gating
211
+ gHD = m.view(1, Hq, 1, Dh) # [1,Hq,1,Dh]
212
+ q = q * gHD
213
+ if self.gate_kv:
214
+ # collapse to per-head for KV, then map to Hkv via amax over replicas
215
+ gH = gHD.amax(dim=-1, keepdim=True) # [1,Hq,1,1]
216
+ if n_rep == 1:
217
+ k = k * gH; v = v * gH
218
+ else:
219
+ g_kv = gH.view(1, Hkv, n_rep, 1, 1).amax(dim=2)
220
+ k = k * g_kv; v = v * g_kv
221
+ else:
222
+ raise RuntimeError(
223
+ f"HeadGate mask has {m.numel()} elems; expected {Hq} or {Hq*Dh}"
224
+ )
225
+
226
+
227
+ # GQA replicate KV to Q count
228
+ k = _repeat_kv(k, n_rep)
229
+ v = _repeat_kv(v, n_rep)
230
+
231
+ attn = F.scaled_dot_product_attention(
232
+ q, k, v,
233
+ attn_mask=attention_mask,
234
+ dropout_p=self.drop_p if self.training else 0.0,
235
+ is_causal=True
236
+ )
237
+ out = attn.transpose(1, 2).contiguous().view(B, T, Hq * Dh)
238
+ out = self.o_proj(out)
239
+
240
+ attn_weights = None
241
+ # HF expects (attn_output, attn_weights, present_key_value) always
242
+ if output_attentions:
243
+ return (out, attn_weights, present)
244
+ else:
245
+ return (out, None, present)
246
+
247
+
248
+
249
+ # -------------------------------------------------------------------------
250
+ # Adapter
251
+ # -------------------------------------------------------------------------
252
+
253
+ class LlamaAdapter:
254
+ def __init__(self, model: nn.Module):
255
+ self.model = model
256
+ core = getattr(model, "model", model)
257
+ if not hasattr(core, "layers"):
258
+ raise ValueError("Provided model does not look like HF LLaMA/Mistral (missing .model.layers or .layers)")
259
+
260
+ # ---------- Gating attachment ----------
261
+ def attach_gates(self, cfg: LlamaGatingConfig) -> nn.Module:
262
+ m = self.model
263
+ core = getattr(m, "model", m)
264
+ layers = core.layers
265
+
266
+ Hq = int(core.config.num_attention_heads)
267
+ Hkv = int(getattr(core.config, "num_key_value_heads", Hq))
268
+ Dh = int(core.config.hidden_size // Hq)
269
+
270
+ for li, layer in enumerate(layers):
271
+ # Attention heads
272
+ if cfg.head_gating:
273
+ base = layer.self_attn
274
+ if not isinstance(base, GatedSelfAttentionLLM):
275
+ gated = GatedSelfAttentionLLM(
276
+ attn_container=base,
277
+ num_q_heads=Hq,
278
+ num_kv_heads=Hkv,
279
+ head_dim=Dh,
280
+ cfg=cfg,
281
+ layer_idx=li,
282
+ )
283
+ layer.self_attn = gated # route via our wrapper
284
+
285
+ # MLP grouped gating (SwiGLU)
286
+ if cfg.ffn_gating:
287
+ mlp = layer.mlp
288
+ I = int(mlp.up_proj.out_features)
289
+ assert I % cfg.ffn_group == 0, f"SwiGLU size {I} not divisible by group {cfg.ffn_group}"
290
+ if not hasattr(mlp, "neuron_gate"):
291
+ mlp.neuron_gate = GroupGate(
292
+ num_groups=I // cfg.ffn_group,
293
+ group_size=cfg.ffn_group,
294
+ tau=cfg.tau, init_logit=cfg.init_logit,
295
+                 hard_during_eval=cfg.hard_eval,
+             )
+             if not hasattr(mlp, "_orig_forward"):
+                 mlp._orig_forward = mlp.forward
+
+             def _gated_mlp_forward(this, x):
+                 # LLaMA: z = silu(up(x)) * (gate(x) * msk); out = down(z)
+                 u = this.up_proj(x)
+                 g = this.gate_proj(x)
+                 msk = this.neuron_gate.mask(this.training).view(1, 1, -1)
+                 z = torch.nn.functional.silu(u) * (g * msk)
+                 return this.down_proj(z)
+
+             mlp.forward = _gated_mlp_forward.__get__(mlp, mlp.__class__)
+         return m
+
+     # ---------- Logits helper ----------
+     @staticmethod
+     def _last_token_index(attn_mask: torch.Tensor) -> torch.Tensor:
+         # attn_mask: [B, S] with 1 for tokens, 0 for padding
+         # returns [B] indices of the last non-pad token
+         # works for both bool and int masks
+         if attn_mask is None:
+             # no mask -> caller falls back to position S-1
+             return None
+         if attn_mask.dtype != torch.long:
+             attn_mask = attn_mask.to(torch.long)
+         # idx = lengths - 1
+         return (attn_mask.sum(dim=-1) - 1).clamp_min(0)
+
+     @staticmethod
+     def get_logits(model: nn.Module,
+                    input_ids: torch.Tensor,
+                    attention_mask: Optional[torch.Tensor] = None,
+                    last_only: bool = True,
+                    **forward_kwargs) -> torch.Tensor:
+         """
+         Returns logits. If last_only=True, computes ONLY the last-token logits by:
+           1) getting hidden states from the base decoder,
+           2) selecting the last non-pad position per sample,
+           3) projecting that single position through lm_head.
+         This avoids allocating a [B, S, V] tensor.
+         """
+         # (1) run the base decoder, not the full CausalLM head
+         core = getattr(model, "model", None)
+         if core is None:
+             # fallback if the model is already a bare decoder (rare)
+             core = model
+
+         # We only need last_hidden_state; no cache; avoid building logits for all S.
+         # return_dict=False returns a tuple and avoids extra allocations.
+         outputs = core(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             use_cache=False,
+             return_dict=False,
+             **forward_kwargs,
+         )
+         hidden = outputs[0]  # [B, S, D]
+
+         if not last_only:
+             # If someone explicitly wants all logits, fine:
+             return model.lm_head(hidden)  # [B, S, V] (expensive!)
+
+         # (2) select the last token per sample
+         B, S, D = hidden.shape
+         if attention_mask is None:
+             # simple "last index"
+             idx = torch.full((B,), S - 1, device=hidden.device, dtype=torch.long)
+         else:
+             idx = LlamaAdapter._last_token_index(attention_mask)
+
+         # gather the last hidden state: [B, D]
+         last_h = hidden[torch.arange(B, device=hidden.device), idx]  # [B, D]
+         # (3) project that single position to logits
+         last_logits = model.lm_head(last_h).unsqueeze(1)  # [B, 1, V]
+         return last_logits
+
+     # ---------- Exporters ----------
+     @staticmethod
+     @torch.no_grad()
+     def export_keepall(model_with_gates: nn.Module) -> nn.Module:
+         """
+         Unwrap attention wrappers; restore the original MLP.forward; drop gates.
+         """
+         slim = deepcopy_eval_cpu(model_with_gates)
+         core = getattr(slim, "model", slim)
+         if not hasattr(core, "layers"):
+             return slim
+
+         for layer in core.layers:
+             # attention
+             attn = layer.self_attn
+             if isinstance(attn, GatedSelfAttentionLLM):
+                 gat = attn
+                 new_attn = copy.deepcopy(gat.base_attn)
+                 # keep metadata consistent
+                 if hasattr(new_attn, "num_heads"):
+                     new_attn.num_heads = int(gat.num_q_heads)
+                 if hasattr(new_attn, "num_key_value_heads"):
+                     new_attn.num_key_value_heads = int(gat.num_kv_heads)
+                 if hasattr(new_attn, "head_dim"):
+                     new_attn.head_dim = int(gat.head_dim)
+                 layer.self_attn = new_attn
+
+             # mlp
+             mlp = layer.mlp
+             if hasattr(mlp, "_orig_forward"):
+                 mlp.forward = mlp._orig_forward
+                 delattr(mlp, "_orig_forward")
+             if hasattr(mlp, "neuron_gate"):
+                 delattr(mlp, "neuron_gate")
+
+         return slim
+
+     @staticmethod
+     @torch.no_grad()
+     def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
+         """
+         Produce a clean CPU eval model:
+           - Read gates to choose Q heads; slice q_proj rows and o_proj cols
+           - Snap kept heads to an LCM of (policy multiple, Hkv)
+           - Slice SwiGLU up/gate/down by groups
+           - Unwrap back to plain HF modules; update metadata
+         """
+         # Accept either CoreExportPolicy with per-axis rounding or a family policy
+         if isinstance(policy, LlamaExportPolicy):
+             head_rounding = policy.head_rounding
+             ffn_rounding = policy.ffn_rounding
+             warmup_steps = policy.warmup_steps
+         else:
+             head_rounding = getattr(policy, "rounding", None)
+             ffn_rounding = getattr(policy, "rounding", None)
+             warmup_steps = int(getattr(policy, "warmup_steps", 0))
+
+         slim = deepcopy_eval_cpu(model_with_gates)
+         core = getattr(slim, "model", slim)
+         layers = getattr(core, "layers", None)
+         if layers is None:
+             return slim
+
+         warm = (step < warmup_steps)
+
+         def _lcm(a: int, b: int) -> int:
+             import math  # local import so this works even if the module header lacks it
+             return abs(a * b) // math.gcd(max(a, 1), max(b, 1)) if a > 0 and b > 0 else max(a, b, 1)
+
+         for li, layer in enumerate(layers):
+             # ---- Attention (Q heads) ----
+             attn = layer.self_attn
+             if isinstance(attn, GatedSelfAttentionLLM):
+                 gat = attn
+                 base = gat.base_attn
+
+                 Hq = int(gat.num_q_heads)
+                 Hkv = int(gat.num_kv_heads)
+                 Dh = int(gat.head_dim)
+
+                 if warm:
+                     keep_idx = torch.arange(Hq)
+                 else:
+                     # Build a "per-head" proxy gate if the base gate is per-channel.
+                     base_logits = gat.head_gate.logits.detach().float().view(-1)
+                     tau = float(getattr(gat.head_gate, "tau", 1.0))
+
+                     if base_logits.numel() == Hq:
+                         # Native per-head gate: use as-is
+                         proxy_gate = gat.head_gate
+                         keep_idx = keep_group_indices_from_gate(
+                             proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
+                         )
+                     elif base_logits.numel() == Hq * Dh:
+                         # Collapse per-channel -> per-head (mean; or use .amax for stricter)
+                         per_head_logits = base_logits.view(Hq, Dh).mean(dim=1)
+
+                         class _PerHeadProxyGate:
+                             def __init__(self, logits, tau):
+                                 self.logits = logits
+                                 self.tau = tau
+                                 self.num_groups = logits.numel()
+                                 self.group_size = 1
+
+                         proxy_gate = _PerHeadProxyGate(per_head_logits, tau)
+                         keep_idx = keep_group_indices_from_gate(
+                             proxy_gate, policy=policy, step=step, custom_rounding=head_rounding
+                         )
+                     else:
+                         raise RuntimeError(
+                             f"Unexpected HeadGate logits len {base_logits.numel()} vs H={Hq} or H*Dh={Hq*Dh}"
+                         )
+
+                     # Enforce the LCM with GQA (Hkv) by truncating to a floor multiple.
+                     # (This lives inside the non-warm branch: it relies on base_logits.)
+                     pol_mult = getattr(head_rounding, "multiple_groups", 1)
+                     snap = _lcm(int(pol_mult), max(1, Hkv))
+                     if keep_idx.numel() % snap != 0:
+                         k = (keep_idx.numel() // snap) * snap
+                         k = max(snap, min(Hq, k))
+                         # recompute top-k by per-head logits (same criterion as above)
+                         if base_logits.numel() == Hq * Dh:
+                             scores = per_head_logits
+                         else:
+                             scores = base_logits
+                         keep_idx = torch.topk(scores, k=k, largest=True).indices.sort().values
+
+                 H_keep = int(keep_idx.numel())
+                 # channels for q/o slicing
+                 ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in keep_idx]).long()
+
+                 # slice the wrapper linears
+                 gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
+                 gat.o_proj = slice_linear(gat.o_proj, keep_in=ch_idx)
+
+                 # transplant into a clean HF attention
+                 new_attn = copy.deepcopy(base)
+                 if hasattr(new_attn, "q_proj"):
+                     new_attn.q_proj = gat.q_proj
+                 if hasattr(new_attn, "o_proj"):
+                     new_attn.o_proj = gat.o_proj
+                 elif hasattr(new_attn, "out_proj"):
+                     new_attn.out_proj = gat.o_proj
+
+                 # update metadata
+                 if hasattr(new_attn, "num_heads"):
+                     new_attn.num_heads = int(H_keep)
+                 if hasattr(new_attn, "num_key_value_heads"):
+                     new_attn.num_key_value_heads = int(Hkv)
+                 if hasattr(new_attn, "head_dim"):
+                     new_attn.head_dim = int(Dh)
+                 if hasattr(core.config, "hidden_size"):
+                     core.config.hidden_size = int(H_keep * Dh)
+
+                 layer.self_attn = new_attn  # unwrap
+
+             # ---- MLP (SwiGLU grouped) ----
+             mlp = layer.mlp
+             g = getattr(mlp, "neuron_gate", None)
+             if g is not None:
+                 grp_idx = keep_group_indices_from_gate(
+                     g, policy=policy, step=step, custom_rounding=ffn_rounding,
+                 )
+                 group = int(g.group_size)  # GroupGate exposes group_size
+                 keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
+
+                 mlp.up_proj = slice_linear(mlp.up_proj, keep_out=keep_exp)
+                 mlp.gate_proj = slice_linear(mlp.gate_proj, keep_out=keep_exp)
+                 mlp.down_proj = slice_linear(mlp.down_proj, keep_in=keep_exp)
+
+             # Restore the clean forward & drop the gate
+             if hasattr(mlp, "_orig_forward"):
+                 mlp.forward = mlp._orig_forward
+                 delattr(mlp, "_orig_forward")
+             if hasattr(mlp, "neuron_gate"):
+                 delattr(mlp, "neuron_gate")
+
+         return slim
+
+
+ # -------------------------------------------------------------------------
+ # Export policy (allow different rounding for Heads vs FFN)
+ # -------------------------------------------------------------------------
+
+ @dataclass
+ class LlamaExportPolicy:
+     warmup_steps: int = 0
+     head_rounding: CoreRounding = CoreRounding()  # e.g., CoreRounding(floor=8, multiple=8)
+     ffn_rounding: CoreRounding = CoreRounding()   # e.g., CoreRounding(min_keep_ratio=0.8, multiple=32)
+
+
+ # -------------------------------------------------------------------------
+ # Grid-search convenience
+ # -------------------------------------------------------------------------
+
+ @dataclass
+ class LlamaGrid:
+     head_multiple_grid: Optional[Sequence[int]] = (1, 2, 4, 8)
+     ffn_snap_grid: Sequence[int] = (1, 32, 64, 128)
+
+
+ def llama_search_best_export(
+     model_with_gates: nn.Module,
+     *,
+     export_fn: Callable[[nn.Module, CoreExportPolicy, int], nn.Module],
+     num_q_heads: int,
+     num_kv_heads: int,
+     step: int,
+     batch_shape: Tuple[int, int],  # (B, S) for text
+     grid: Optional[LlamaGrid] = None,
+     device: str = "cuda",
+     measure_settings=None,
+     make_policy: Optional[Callable[[int, int], object]] = None,
+ ):
+     """
+     Convenience wrapper for LLaMA-style search.
+     Uses the same `grid_search_latency` as ViT; we just feed head/FFN grids.
+     """
+     g = grid or LlamaGrid()
+     head_grid = g.head_multiple_grid or [1]
+     ffn_grid = list(g.ffn_snap_grid)
+
+     return grid_search_latency(
+         model_with_gates,
+         export_fn,
+         head_multiples=head_grid,
+         ffn_snaps=ffn_grid,
+         step=step,
+         batch_shape=batch_shape,  # adapter's runner should interpret as (B, S)
+         measure_settings=measure_settings,
+         device=device,
+         make_policy=make_policy,
+     )
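
The last-token trick in `get_logits` above avoids the `[B, S, V]` logits tensor by computing a per-sample "last non-pad index" from the attention mask and gathering one hidden state per row. A minimal sketch of that indexing logic, written with plain Python lists (no torch dependency) and hypothetical helper names, makes it easy to verify:

```python
# Sketch of the last-token selection used by LlamaAdapter.get_logits.
# Helper names are illustrative, not part of the repo's API.

def last_token_index(attn_mask):
    """attn_mask: [B][S] rows of 0/1; returns the last non-pad index per row.

    Mirrors (attn_mask.sum(-1) - 1).clamp_min(0): an all-pad row maps to 0.
    """
    return [max(sum(row) - 1, 0) for row in attn_mask]

def gather_last_hidden(hidden, idx):
    """hidden: [B][S][D] nested lists; pick hidden[b][idx[b]] for each sample b."""
    return [hidden[b][i] for b, i in enumerate(idx)]

mask = [[1, 1, 1, 0],
        [1, 0, 0, 0]]
hidden = [[[0.0], [1.0], [2.0], [3.0]],
          [[4.0], [5.0], [6.0], [7.0]]]
print(gather_last_hidden(hidden, last_token_index(mask)))  # [[2.0], [4.0]]
```

Only these `[B, D]` gathered states are then projected through `lm_head`, so the vocabulary-sized projection runs on one position per sample instead of all S.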
huggingface/registry.py ADDED
File without changes
huggingface/vit.py ADDED
@@ -0,0 +1,383 @@
+ """HuggingFace ViT adapter
+
+ Bridges the family-agnostic core (gates/export/proxy/train) to ViT-like models
+ from Hugging Face (`ViTModel`, `ViTForImageClassification`, DeiT, etc.).
+
+ Responsibilities
+ ----------------
+ - Attach gates to attention heads and to MLP hidden units in groups
+ - Provide logits getters for student/teacher
+ - Export helpers: keep-all (remove gates), and pruned (slice weights + metadata)
+
+ This adapter intentionally keeps the core unaware of ViT internals.
+ """
+ from __future__ import annotations
+
+ # Ensure the repo root is on sys.path for absolute imports (core, adapters, data)
+ import sys, pathlib
+ sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
+
+ from dataclasses import dataclass
+ from typing import Callable, Optional, Sequence, Tuple
+
+ import copy
+ import torch
+ import torch.nn as nn
+
+ # NOTE: absolute imports so running `-m examples.run_vit_optimize` works without a package install
+ from core.gates import HeadGate, GroupGate
+ from core.export import (
+     ExportPolicy as CoreExportPolicy,
+     Rounding as CoreRounding,
+     keep_group_indices_from_gate,
+     keep_element_indices_from_gate,
+     slice_linear,
+ )
+
+ from core.utils import deepcopy_eval_cpu
+ from core.search_export import grid_search_latency
+
+ # -----------------------------------------------------------------------------
+ # Config
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ViTGatingConfig:
+     tau: float = 1.5
+     init_logit: float = 3.0
+     head_gating: bool = True
+     ffn_group: int = 16
+     ffn_gating: bool = True
+     hard_eval: bool = True  # use hard masks in eval mode during forward
+
+
+ def _encoder_layers(m: nn.Module):
+     """
+     Return the sequence of Transformer blocks for a HF ViT.
+     Supports:
+       - ViTModel: m.encoder.layer
+       - ViTForImageClassification: m.vit.encoder.layer
+     """
+     # ViTModel path
+     enc = getattr(m, "encoder", None)
+     if enc is not None and hasattr(enc, "layer"):
+         return enc.layer
+
+     # ViTForImageClassification path
+     vit = getattr(m, "vit", None)
+     if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"):
+         return vit.encoder.layer
+
+     raise ValueError("Provided model does not look like a HF ViT (missing *.encoder.layer)")
+
+
+ # -----------------------------------------------------------------------------
+ # Gated attention wrapper
+ # -----------------------------------------------------------------------------
+
+ class GatedSelfAttentionHF(nn.Module):
+     """A thin wrapper around HF ViT self-attention that multiplies in per-head gates.
+
+     It keeps references to the underlying query/key/value `nn.Linear` layers and
+     the output projection, while exposing a `HeadGate` in `head_gate`.
+     """
+
+     def __init__(self, attn_container: nn.Module, num_heads: int, head_dim: int, cfg: ViTGatingConfig):
+         super().__init__()
+         base_attn = attn_container.attention  # ViTSdpaSelfAttention or ViTSelfAttention
+         out_proj = attn_container.output.dense
+
+         self.base_attn = base_attn
+         self.out_proj = out_proj
+
+         self.q_proj = base_attn.query
+         self.k_proj = base_attn.key
+         self.v_proj = base_attn.value
+
+         self.num_heads = int(num_heads)
+         self.head_dim = int(head_dim)
+         self.drop_p = getattr(base_attn, "dropout", nn.Dropout(0.0)).p
+
+         self.head_gate = HeadGate(
+             num_heads=self.num_heads,
+             head_dim=self.head_dim,
+             tau=cfg.tau,
+             init_logit=cfg.init_logit,
+             hard_during_eval=cfg.hard_eval,
+         )
+
+     @property
+     def logits(self) -> torch.Tensor:
+         return self.head_gate.logits
+
+     def kept_heads_soft(self) -> torch.Tensor:
+         return self.head_gate.probs().sum()
+
+     def forward(self, hidden_states, head_mask=None):
+         B, N, _ = hidden_states.shape
+         H, Dh = self.num_heads, self.head_dim
+
+         wdev = self.q_proj.weight.device
+         if hidden_states.device != wdev:
+             hidden_states = hidden_states.to(wdev, non_blocking=True)
+
+         q_lin = self.q_proj(hidden_states)
+         k_lin = self.k_proj(hidden_states)
+         v_lin = self.v_proj(hidden_states)
+
+         q = q_lin.view(B, N, H, Dh).transpose(1, 2)
+         k = k_lin.view(B, N, H, Dh).transpose(1, 2)
+         v = v_lin.view(B, N, H, Dh).transpose(1, 2)
+
+         logits = self.head_gate.logits
+         tau = float(self.head_gate.tau)
+         if self.training:
+             # Gumbel-sigmoid sample with a straight-through hard threshold
+             u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
+             s = u.log() - (1 - u).log()
+             y = torch.sigmoid((logits + s) / tau)
+             g_head = ((y > 0.5).to(y.dtype) - y).detach() + y
+         else:
+             if getattr(self.head_gate, "hard_during_eval", True):
+                 g_head = (logits > 0).to(logits.dtype)
+             else:
+                 g_head = torch.sigmoid(logits / tau)
+         g = g_head.view(1, H, 1, 1)
+
+         q = q * g
+         k = k * g
+         v = v * g
+
+         attn_out = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, dropout_p=self.drop_p if self.training else 0.0
+         )  # [B, H, N, Dh]
+
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, N, H * Dh)
+         attn_out = self.out_proj(attn_out)
+         return attn_out, None
+
+
+ # -----------------------------------------------------------------------------
+ # Adapter
+ # -----------------------------------------------------------------------------
+
+ class ViTAdapter:
+     def __init__(self, model: nn.Module):
+         self.model = model
+         _ = _encoder_layers(model)
+
+     # ---------- Gating attachment ----------
+     def attach_gates(self, cfg: ViTGatingConfig) -> nn.Module:
+         m = self.model
+         H = int(getattr(m.config, "num_attention_heads", 12))
+         D = int(getattr(m.config, "hidden_size", 768))
+         Dh = D // H
+
+         for layer in _encoder_layers(m):
+             # Attention heads
+             if cfg.head_gating:
+                 attn_container = layer.attention
+                 if not isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
+                     gated = GatedSelfAttentionHF(attn_container, H, Dh, cfg)
+                     attn_container.attention = gated
+
+             # FFN hidden (grouped)
+             if cfg.ffn_gating:
+                 inter = layer.intermediate
+                 d_ff = int(inter.dense.out_features)
+                 assert d_ff % cfg.ffn_group == 0, f"FFN size {d_ff} not divisible by group {cfg.ffn_group}"
+                 if not hasattr(inter, "neuron_gate"):
+                     inter.neuron_gate = GroupGate(
+                         num_groups=d_ff // cfg.ffn_group,
+                         group_size=cfg.ffn_group,
+                         tau=cfg.tau,
+                         init_logit=cfg.init_logit,
+                         hard_during_eval=cfg.hard_eval,
+                     )
+                 # Monkey-patch forward to apply the mask after activation (keeps HF shapes)
+                 if not hasattr(inter, "_orig_forward"):
+                     inter._orig_forward = inter.forward
+
+                 def _gated_forward(this, x):
+                     h = this.dense(x)
+                     h = this.intermediate_act_fn(h)
+                     msk = this.neuron_gate.mask(this.training).view(1, 1, -1)
+                     return h * msk
+
+                 inter.forward = _gated_forward.__get__(inter, inter.__class__)
+         return m
+
+     # ---------- Logits helpers ----------
+     @staticmethod
+     def get_logits(model: nn.Module, x: torch.Tensor, *, head: Optional[nn.Module] = None) -> torch.Tensor:
+         out = model(pixel_values=x)
+         if hasattr(out, "logits"):
+             return out.logits  # ViTForImageClassification path
+         if hasattr(out, "last_hidden_state"):  # ViTModel path (needs an external head)
+             if head is None:
+                 raise ValueError("Provide a classification head when using ViTModel without logits.")
+             cls_tok = out.last_hidden_state[:, 0, :]
+             if next(head.parameters(), torch.tensor([], device=cls_tok.device)).device != cls_tok.device:
+                 head = head.to(cls_tok.device)
+             return head(cls_tok)
+         raise ValueError("Model output lacks logits and last_hidden_state.")
+
+     # ---------- Exporters ----------
+     @staticmethod
+     @torch.no_grad()
+     def export_keepall(model_with_gates: nn.Module) -> nn.Module:
+         slim = deepcopy_eval_cpu(model_with_gates)
+         for layer in _encoder_layers(slim):
+             # Attention: unwrap the gate
+             attn_container = layer.attention
+             if isinstance(getattr(attn_container, "attention", None), GatedSelfAttentionHF):
+                 gat = attn_container.attention
+                 new_attn = copy.deepcopy(gat.base_attn)
+                 # restore HF metadata if present
+                 if hasattr(new_attn, "num_attention_heads"):
+                     new_attn.num_attention_heads = int(gat.num_heads)
+                 if hasattr(new_attn, "attention_head_size"):
+                     new_attn.attention_head_size = int(gat.head_dim)
+                 if hasattr(new_attn, "all_head_size"):
+                     new_attn.all_head_size = int(gat.num_heads * gat.head_dim)
+                 attn_container.attention = new_attn
+             # FFN: restore the original forward and drop the gate
+             inter = layer.intermediate
+             if hasattr(inter, "_orig_forward"):
+                 inter.forward = inter._orig_forward
+                 delattr(inter, "_orig_forward")
+             if hasattr(inter, "neuron_gate"):
+                 delattr(inter, "neuron_gate")
+         return slim
+
+     @staticmethod
+     @torch.no_grad()
+     def export_pruned(model_with_gates: nn.Module, policy, step: int) -> nn.Module:
+         # Support both CoreExportPolicy (single rounding) and ViTExportPolicy (per-axis)
+         if isinstance(policy, ViTExportPolicy):
+             head_rounding = policy.head_rounding
+             ffn_rounding = policy.ffn_rounding
+             warmup_steps = policy.warmup_steps
+         else:
+             # fall back to a single rounding for both
+             head_rounding = getattr(policy, "rounding", None)
+             ffn_rounding = getattr(policy, "rounding", None)
+             warmup_steps = int(getattr(policy, "warmup_steps", 0))
+
+         slim = deepcopy_eval_cpu(model_with_gates)
+         warm = (step < warmup_steps)
+
+         for layer in _encoder_layers(slim):
+             # --- Attention heads ---
+             attn_container = layer.attention
+             gat = getattr(attn_container, "attention", None)
+             if isinstance(gat, GatedSelfAttentionHF):
+                 # decide head indices via the shared helper; warmup is honored via `step`
+                 grp_idx = keep_group_indices_from_gate(
+                     gat.head_gate,
+                     policy=policy,
+                     step=step,
+                     custom_rounding=head_rounding,
+                 )
+                 H_keep = int(grp_idx.numel())
+                 Dh = int(gat.head_dim)
+
+                 ch_idx = torch.cat([torch.arange(h * Dh, (h + 1) * Dh) for h in grp_idx]).long()
+                 gat.q_proj = slice_linear(gat.q_proj, keep_out=ch_idx)
+                 gat.k_proj = slice_linear(gat.k_proj, keep_out=ch_idx)
+                 gat.v_proj = slice_linear(gat.v_proj, keep_out=ch_idx)
+                 attn_container.output.dense = slice_linear(attn_container.output.dense, keep_in=ch_idx)
+
+                 new_attn = copy.deepcopy(gat.base_attn)
+                 new_attn.query = gat.q_proj
+                 new_attn.key = gat.k_proj
+                 new_attn.value = gat.v_proj
+                 if hasattr(new_attn, "num_attention_heads"):
+                     new_attn.num_attention_heads = H_keep
+                 if hasattr(new_attn, "attention_head_size"):
+                     new_attn.attention_head_size = Dh
+                 if hasattr(new_attn, "all_head_size"):
+                     new_attn.all_head_size = H_keep * Dh
+                 attn_container.attention = new_attn
+
+             # --- FFN groups ---
+             inter, out = layer.intermediate, layer.output
+             g = getattr(inter, "neuron_gate", None)
+             if g is not None:
+                 grp_idx = keep_group_indices_from_gate(
+                     g,
+                     policy=policy,
+                     step=step,
+                     custom_rounding=ffn_rounding,
+                 )
+                 group = int(g.group_size)
+                 keep_exp = torch.cat([torch.arange(i * group, (i + 1) * group) for i in grp_idx]).long()
+                 inter.dense = slice_linear(inter.dense, keep_out=keep_exp)
+                 out.dense = slice_linear(out.dense, keep_in=keep_exp)
+
+             # Rebind the class's clean forward and drop gate bookkeeping
+             inter.forward = inter.__class__.forward.__get__(inter, inter.__class__)
+             if hasattr(inter, "neuron_gate"):
+                 delattr(inter, "neuron_gate")
+             if hasattr(inter, "_orig_forward"):
+                 delattr(inter, "_orig_forward")
+
+         return slim
+
+
+ # -----------------------------------------------------------------------------
+ # Export policy
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ViTExportPolicy:
+     """ViT-specific export policy that allows different rounding for heads vs FFN."""
+     warmup_steps: int = 0
+     head_rounding: CoreRounding = CoreRounding()
+     ffn_rounding: CoreRounding = CoreRounding()
+
+
+ @dataclass
+ class ViTGrid:
+     head_multiple_grid: Optional[Sequence[int]] = (2, 4, 8)
+     ffn_snap_grid: Sequence[int] = (1, 8)
+     # head_multiple_grid: Optional[Sequence[int]] = None  # default -> 1..num_heads
+     # ffn_snap_grid: Sequence[int] = (1, 2, 4, 8, 16)
+
+
+ def vit_search_best_export(
+     model_with_gates: nn.Module,
+     *,
+     export_fn: ExportFn,
+     num_heads: int,
+     step: int,
+     batch_shape: Tuple[int, int, int, int],
+     grid: Optional[ViTGrid] = None,
+     device: str = "cuda",
+     measure_settings: Optional[ProfileSettings] = None,
+     make_policy: Optional[Callable[[int, int], object]] = None,
+ ) -> SearchResult:
+     """Convenience wrapper for ViT-style search.
+
+     If `make_policy` is not provided, the caller's adapter should accept a
+     policy with separate head/FFN rounding; see `adapters.huggingface.vit.ViTExportPolicy`.
+     """
+     g = grid or ViTGrid()
+     head_grid = g.head_multiple_grid or list(range(1, int(num_heads) + 1))
+     ffn_grid = list(g.ffn_snap_grid)
+
+     return grid_search_latency(
+         model_with_gates,
+         export_fn,
+         head_multiples=head_grid,
+         ffn_snaps=ffn_grid,
+         step=step,
+         batch_shape=batch_shape,
+         measure_settings=measure_settings,
+         device=device,
+         make_policy=make_policy,
+     )
+
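
Both exporters above turn kept group ids into flat channel indices before slicing the linears (the `keep_exp` / `ch_idx` construction: each kept group of size G contributes the contiguous channel range `[i*G, (i+1)*G)`). A small torch-free sketch of that expansion, with an illustrative helper name, shows the mapping:

```python
# Sketch of the grouped-index expansion used when slicing FFN/head channels.
# The function name is illustrative, not part of the repo's API.

def group_channel_indices(group_idx, group_size):
    """Expand kept group ids into flat channel indices.

    Mirrors torch.cat([torch.arange(i*G, (i+1)*G) for i in group_idx]).
    """
    return [c for i in group_idx for c in range(i * group_size, (i + 1) * group_size)]

# keeping groups 0 and 2 with group_size=4 keeps channels 0-3 and 8-11
print(group_channel_indices([0, 2], 4))  # [0, 1, 2, 3, 8, 9, 10, 11]
```

The same expansion serves both axes: for attention, `group_size` is the head dimension and the ids are kept heads; for the FFN, it is the gate's `group_size` and the ids are kept neuron groups.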
model_index.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "task": "image-classification",
+   "base_id": "google/vit-base-patch16-224",
+   "variant": "gated-student"
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eca9442ba47bd27888b3dc0b0df757113779d2c21182d5626cf1d54643fe637c
+ size 343618083