Upload rung6_moe_g4.py with huggingface_hub
Browse files- rung6_moe_g4.py +1372 -0
rung6_moe_g4.py
ADDED
|
@@ -0,0 +1,1372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
rung6_moe_g4.py — Gemma-4 E2B port of rung6_moe.py.
|
| 4 |
+
|
| 5 |
+
Same MECE MoE approach, adapted for Gemma-4's heterogeneous MLP widths:
|
| 6 |
+
- Layers 0-14: D_FFN=6144 (INTERMEDIATE)
|
| 7 |
+
- Layers 15-34: D_FFN=12288 (INTERMEDIATE_WIDE)
|
| 8 |
+
Per-layer A logits have different row counts; PRUNE_K is per-layer.
|
| 9 |
+
|
| 10 |
+
Architecture:
|
| 11 |
+
- Frozen base weights: W_gate, W_up, W_down (per-layer, variable D_FFN)
|
| 12 |
+
- Trainable per-layer:
|
| 13 |
+
* Assignment logits A ∈ R^{D_FFN_i, K}
|
| 14 |
+
* Router W_r ∈ R^{D_MODEL, K_spec} (K_spec = K - K_const)
|
| 15 |
+
- Expert k's soft mask: m_k[j] = softmax(A[j,:] / tau)[k]
|
| 16 |
+
- τ anneals 1.0 → 0.01
|
| 17 |
+
- Per-token forward:
|
| 18 |
+
1. Apply K_const always-on experts' combined soft mask to h
|
| 19 |
+
2. Route top-K_active specialist experts via W_r (+ noise)
|
| 20 |
+
3. Add selected specialist masks to combined mask (softmax-weighted within top-K)
|
| 21 |
+
4. h = gelu(gate) * up * combined_mask; y = W_down @ h
|
| 22 |
+
- Aux losses: Switch balance (α_b=0.01) + router z-loss (α_z=0.001)
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
# fix_both-style Gemma-4 launch
|
| 26 |
+
python rung6_moe_g4.py --phase g4_fixboth \
|
| 27 |
+
--K 8 --K_const 2 --K_active_spec 2 \
|
| 28 |
+
--init taylor --loss ce \
|
| 29 |
+
--int4_qat --int4_group_size 32 \
|
| 30 |
+
--calib_path 3BASiL/calibration_data/gemma4_e2b_it_bulk_50k.jsonl \
|
| 31 |
+
--eval_calib_path 3BASiL/calibration_data/gemma4_e2b_it_final_50k.jsonl \
|
| 32 |
+
--diverse_calib_path 3BASiL/calibration_data/diverse_wikitext.jsonl \
|
| 33 |
+
--kl_base_lambda 0.5 --kl_base_temp 8.0 \
|
| 34 |
+
--w_drift_lambda 1e-6 \
|
| 35 |
+
--max_steps 2000 --save_checkpoint ckpts/g4_fixboth.pt
|
| 36 |
+
|
| 37 |
+
Output:
|
| 38 |
+
logs/rung6_moe_<phase>_results.json
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
import argparse
|
| 42 |
+
import json
|
| 43 |
+
import math
|
| 44 |
+
import os
|
| 45 |
+
import time
|
| 46 |
+
import torch
|
| 47 |
+
import torch.nn as nn
|
| 48 |
+
import torch.nn.functional as F
|
| 49 |
+
from torch.optim import AdamW
|
| 50 |
+
try:
|
| 51 |
+
import bitsandbytes as bnb
|
| 52 |
+
_HAS_BNB = True
|
| 53 |
+
except ImportError:
|
| 54 |
+
_HAS_BNB = False
|
| 55 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 56 |
+
from gemma4_hf import (
|
| 57 |
+
load_gemma4 as load_model,
|
| 58 |
+
N_LAYERS,
|
| 59 |
+
HIDDEN_SIZE as D_MODEL,
|
| 60 |
+
DEVICE,
|
| 61 |
+
DTYPE,
|
| 62 |
+
INTERMEDIATE,
|
| 63 |
+
INTERMEDIATE_WIDE,
|
| 64 |
+
DOUBLE_WIDE_START,
|
| 65 |
+
)
|
| 66 |
+
from moe_recovery import (
|
| 67 |
+
recover_modules_via_generic_pipeline,
|
| 68 |
+
finetune_moe_per_layer,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Default calibration JSONL; override via --calib_path.
CALIB_DATA_PATH = "3BASiL/calibration_data/gemma4_e2b_it_final_50k.jsonl"

# Gemma-4 baselines TBD — kept at 0 so the printed diff reads as "+ppl".
BASELINE_PPL = 0.0
CLEAN_PPL = 0.0

# MAX_SEQ_LEN is the per-record padded length. Data is one-sequence-per-record,
# so every sequence starts with BOS + chat-template scaffold (no mid-document
# chunks that lose the BOS / scaffold context). 2048 covers ~70% of
# `final.jsonl` records fully; longer records are truncated (prompt + the
# response prefix that fits). This removes the eval unfairness where
# mid-document chunks lacked BOS — the base model lost context while the
# trained student had memorized the chunked positions.
MAX_SEQ_LEN = 2048
SEQ_LEN = MAX_SEQ_LEN  # back-compat alias (eval/train loops use SEQ_LEN)

BATCH = 1        # Gemma-4 E2B (4.65B) is ~17× larger than Gemma-3 (270M)
GRAD_ACCUM = 16  # 1 × 16 = 16 effective — keeps optimizer-step cadence similar
EVAL_BATCHES = 0  # 0 = no cap; eval scans every chunk in the eval split
LR = 1e-4
NOISE_SCALE = 0.020264
PRUNE_P = 0.40   # 40% kept (same per-token sparsity target as Gemma-3)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _d_ffn_at(layer_idx: int) -> int:
    """Look up the FFN intermediate width used by layer `layer_idx`.

    Layers below DOUBLE_WIDE_START use INTERMEDIATE; layers at or above it
    use INTERMEDIATE_WIDE.
    """
    if layer_idx < DOUBLE_WIDE_START:
        return INTERMEDIATE
    return INTERMEDIATE_WIDE
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _prune_k_at(layer_idx: int) -> int:
    """Target count of active neurons for layer `layer_idx`.

    This is the PRUNE_P fraction of that layer's FFN width, truncated to an int.
    """
    layer_width = _d_ffn_at(layer_idx)
    return int(layer_width * PRUNE_P)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ─────────────────────────── Int4 QAT (Phase I) ────────────────────────
|
| 101 |
+
|
| 102 |
+
def quantize_int4_groupwise_ste(w, group_size=32):
    """Fake-quantize `w` to symmetric int4, groupwise along the last dim, with STE gradient.

    Symmetric int4: integer range [-7, 7] (-8 is skipped to stay symmetric).
    Every run of `group_size` consecutive elements along the last dimension
    shares one scale, `max(|w|) / 7`. The forward value is the dequantized
    weight; the backward pass treats the op as identity w.r.t. `w`
    (straight-through estimator).

    Args:
        w: floating-point tensor with at least one dimension. The typical input
            is an nn.Linear weight of shape [out_dim, in_dim]; any leading
            dimensions are treated as independent rows.
        group_size: positive number of last-dim elements per scale group.
            When the last dim is not a multiple of `group_size` the tensor is
            zero-padded internally and the padding is dropped from the result.

    Returns:
        Tensor with the same shape and dtype as `w`.

    Raises:
        ValueError: if `group_size` is not positive or `w` is 0-dimensional.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if w.dim() == 0:
        raise ValueError("w must have at least one dimension")

    lead_shape = tuple(w.shape[:-1])
    in_dim = w.shape[-1]
    # Do quant math in fp32 to avoid bf16 precision loss in scale/round steps.
    w_fp32 = w.float()
    pad = (group_size - in_dim % group_size) % group_size
    if pad:
        w_fp32 = F.pad(w_fp32, (0, pad))
    n_groups = (in_dim + pad) // group_size
    w_g = w_fp32.reshape(*lead_shape, n_groups, group_size)
    # clamp_min keeps the scale non-zero for all-zero groups.
    max_abs = w_g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-6)
    scale = max_abs / 7.0  # [..., n_groups, 1]
    w_int = torch.round(w_g / scale).clamp(-7, 7)
    w_deq = (w_int * scale).reshape(*lead_shape, in_dim + pad)
    if pad:
        w_deq = w_deq[..., :in_dim]
    w_deq = w_deq.to(w.dtype)
    # STE: forward = w_deq, backward = identity w.r.t. w
    return w + (w_deq - w).detach()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Int4QuantLinear(nn.Linear):
    """nn.Linear whose forward pass runs on an int4 fake-quantized copy of its weight.

    Because this only subclasses nn.Linear and overrides forward, the
    state_dict keys (.weight, .bias) are the same as a plain nn.Linear, so
    checkpoints are cross-loadable. The stored weight stays in floating point;
    quantization is applied on the fly in forward, so QAT gradients update the
    fp weight.
    """
    # GGUF Q4_0 / Q4_K block size — matches deploy-time inference kernels.
    _group_size = 32

    def forward(self, x):
        """Apply the linear map using the fake-quantized weight and the stored bias."""
        fake_quant_weight = quantize_int4_groupwise_ste(self.weight, self._group_size)
        return F.linear(x, fake_quant_weight, self.bias)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def apply_int4_inplace(model, group_size=32,
                       target_substrings=("gate_proj", "up_proj", "down_proj",
                                          "q_proj", "k_proj", "v_proj", "o_proj")):
    """Snap the stored weights of matching Linear modules onto the int4 grid, once.

    wrap_int4 fake-quantizes on every forward pass; this function instead
    overwrites each targeted weight with its dequantized int4 value a single
    time (deployment simulation), so there is no per-forward quantize step
    afterwards.

    A module is targeted when it is an nn.Linear, is not already an
    Int4QuantLinear, and its qualified name contains any entry of
    `target_substrings`.

    Returns:
        Number of weights that were overwritten.
    """
    n_snapped = 0
    with torch.no_grad():
        for qualified_name, module in model.named_modules():
            is_plain_linear = (isinstance(module, nn.Linear)
                               and not isinstance(module, Int4QuantLinear))
            if not is_plain_linear:
                continue
            if all(tag not in qualified_name for tag in target_substrings):
                continue
            snapped = quantize_int4_groupwise_ste(module.weight, group_size).detach()
            module.weight.data.copy_(snapped)
            n_snapped += 1
    return n_snapped
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def apply_gaussian_noise_inplace(model, noise_scale,
                                 target_substrings=("gate_proj", "up_proj", "down_proj",
                                                    "q_proj", "k_proj", "v_proj", "o_proj"),
                                 seed=0):
    """Add N(0, noise_scale × weight.std()) noise to target Linear weights IN-PLACE.

    Gaussian proxy for quantization noise. The original author quotes an
    analytically equivalent noise_scale ≈ 0.129 for int4 group=32, from
    σ_q/σ_w ≈ √((max_abs/7)²/12)/σ_w with max_abs ≈ σ_w·√(2·ln group_size).
    NOTE(review): evaluating that formula for group_size=32 gives ≈ 0.109,
    not 0.129 — confirm which figure is intended.

    A module is targeted when it is an nn.Linear, is not an Int4QuantLinear,
    and its qualified name contains any entry of `target_substrings`.

    Args:
        model: module tree to perturb.
        noise_scale: noise std as a fraction of each weight tensor's own std.
        target_substrings: name fragments selecting which Linear modules to perturb.
        seed: seed for the noise generator(s).

    Returns:
        Number of weights that were modified.
    """
    # One seeded generator per device that actually holds a weight. A single
    # generator bound to a fixed device makes torch.randn raise as soon as a
    # weight lives elsewhere (CPU model, sharded model). When every weight is
    # on one device this yields the same noise stream as a single generator.
    generators = {}
    count = 0
    with torch.no_grad():
        for name, mod in model.named_modules():
            if not isinstance(mod, nn.Linear):
                continue
            if isinstance(mod, Int4QuantLinear):
                # Skip to avoid compounding noise with fake-quant in forward (ambiguous semantics).
                continue
            if not any(t in name for t in target_substrings):
                continue
            w = mod.weight.data
            gen = generators.get(w.device)
            if gen is None:
                gen = torch.Generator(device=w.device)
                gen.manual_seed(seed)
                generators[w.device] = gen
            std_w = w.float().std()
            # Sample in fp32, then cast to the weight dtype just before adding.
            noise = torch.randn(w.shape, generator=gen, device=w.device,
                                dtype=torch.float32) * std_w * noise_scale
            w.add_(noise.to(w.dtype))
            count += 1
    return count
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable rank-r LoRA correction.

    Wraps any nn.Linear (including Int4QuantLinear). The output is

        base(x) + (alpha / rank) * lora_b(lora_a(x))

    lora_a is Kaiming-uniform initialized and lora_b starts at zero, so a
    freshly built wrapper reproduces the base layer's output exactly.
    """

    def __init__(self, base_linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base = base_linear
        # Freeze the wrapped layer; only the adapter matrices below are trained.
        self.base.requires_grad_(False)
        # Adapters live on the same device and in the same dtype as the base weight.
        factory = {"device": base_linear.weight.device,
                   "dtype": base_linear.weight.dtype}
        self.lora_a = nn.Linear(base_linear.in_features, rank, bias=False, **factory)
        self.lora_b = nn.Linear(rank, base_linear.out_features, bias=False, **factory)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_b.weight)
        self.scale = alpha / rank

    def forward(self, x):
        """Return the base output plus the scaled low-rank update."""
        base_out = self.base(x)
        low_rank_update = self.lora_b(self.lora_a(x))
        return base_out + low_rank_update * self.scale
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def wrap_lora(model, rank: int, alpha: float,
              target_substrings=("gate_proj", "up_proj", "down_proj",
                                 "q_proj", "k_proj", "v_proj", "o_proj")):
    """Replace target Linear modules with LoRALinear. Base is frozen; only LoRA A/B train.

    Run AFTER wrap_int4 so the base inside LoRALinear is the int4-quantized Linear.

    The call is idempotent: Linear modules that already live inside a
    LoRALinear (its `base`, `lora_a`, `lora_b`) are never wrapped again, so a
    second call returns (0, 0) instead of nesting adapters inside adapters.

    Args:
        model: module tree to modify in place.
        rank: LoRA rank r.
        alpha: LoRA alpha; the adapter output is scaled by alpha / rank.
        target_substrings: name fragments selecting which Linear modules to wrap.

    Returns:
        (count, n_params): number of wrapped modules and total LoRA params added.
    """
    # Snapshot first: the tree is mutated while we walk it.
    modules = list(model.named_modules())
    # LoRALinear is an nn.Module, not an nn.Linear, so an isinstance check on
    # the candidate itself can never detect earlier wrapping. Its children,
    # however, ARE nn.Linear and their qualified names ("...q_proj.base")
    # still contain the target substrings — exclude them by name prefix.
    lora_prefixes = tuple((name + ".") if name else ""
                          for name, mod in modules if isinstance(mod, LoRALinear))
    count = 0
    n_params = 0
    for name, mod in modules:
        if not isinstance(mod, nn.Linear):
            continue
        if lora_prefixes and name.startswith(lora_prefixes):
            continue  # already owned by a LoRALinear
        if not any(t in name for t in target_substrings):
            continue
        new_mod = LoRALinear(mod, rank=rank, alpha=alpha)
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, new_mod)
        n_params += sum(p.numel() for p in new_mod.lora_a.parameters()) + \
            sum(p.numel() for p in new_mod.lora_b.parameters())
        count += 1
    return count, n_params
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def wrap_int4(model, target_substrings=("gate_proj", "up_proj", "down_proj",
                                        "q_proj", "k_proj", "v_proj", "o_proj")):
    """Swap matching nn.Linear modules for Int4QuantLinear instances.

    A module is swapped when it is an nn.Linear, is not already an
    Int4QuantLinear, and its qualified name contains any entry of
    `target_substrings`. State-dict keys are unchanged, and the replacement
    reuses the original weight/bias Parameter objects (no copy).

    Returns:
        Number of modules that were replaced.
    """
    n_wrapped = 0
    # Snapshot the module list: the tree is modified while it is traversed.
    for path, linear in list(model.named_modules()):
        eligible = (isinstance(linear, nn.Linear)
                    and not isinstance(linear, Int4QuantLinear)
                    and any(tag in path for tag in target_substrings))
        if not eligible:
            continue
        has_bias = linear.bias is not None
        replacement = Int4QuantLinear(linear.in_features, linear.out_features,
                                      bias=has_bias,
                                      device=linear.weight.device,
                                      dtype=linear.weight.dtype)
        # Reuse the same Parameter objects so optimizer state and grad flow are preserved.
        replacement.weight = linear.weight
        if has_bias:
            replacement.bias = linear.bias
        owner_path, _, leaf = path.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        setattr(owner, leaf, replacement)
        n_wrapped += 1
    return n_wrapped
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ────────────────────────────── utilities ──────────────────────────────
|
| 284 |
+
|
| 285 |
+
def corrupt_model(model, noise_scale=NOISE_SCALE, seed=42):
    """Perturb every parameter of `model` in place with seeded Gaussian noise.

    Each parameter p receives p += N(0, 1) * p.std() * noise_scale. Noise is
    sampled from a CPU generator seeded with `seed` and then moved to the
    parameter's device.

    Parameters with fewer than two elements are left unchanged: their standard
    deviation is NaN, so adding the scaled noise would overwrite them with
    NaN. Noise is still drawn for them so that the values every other
    parameter receives for a given seed do not depend on this skip.

    Args:
        model: module whose parameters are perturbed.
        noise_scale: noise std as a fraction of each parameter's own std.
        seed: seed for the noise generator.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    with torch.no_grad():
        for p in model.parameters():
            noise = torch.randn(p.shape, generator=rng, dtype=p.dtype).to(p.device)
            if p.numel() < 2:
                continue  # std undefined (NaN) — do not poison the parameter
            p.add_(noise * p.std() * noise_scale)
    print(f" Corrupted model with noise_scale={noise_scale}")
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _pad_to_max_len(ids, first_scored, pad_id):
    """Right-pad `ids` to MAX_SEQ_LEN and build the shifted next-token labels.

    labels[t] is the target for logits[t]: it equals ids[t + 1] when token
    t + 1 exists and t + 1 >= first_scored, and -100 otherwise (masked prompt
    targets, the final in-document position, and padding). The label list
    always has exactly MAX_SEQ_LEN entries.

    Expects len(ids) <= MAX_SEQ_LEN and max(first_scored, 1) <= len(ids).
    """
    n = len(ids)
    pad_len = MAX_SEQ_LEN - n
    # Token 0 is never a target: no position precedes it.
    start = max(first_scored, 1)
    labels = [-100] * (start - 1) + ids[start:n] + [-100] * (pad_len + 1)
    return {
        "input_ids": torch.tensor(ids + [pad_id] * pad_len, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }


def load_seqs(tokenizer, split="train", calib_path=None, raw_text=False):
    """Load tokenized sequences from a JSONL calibration file.

    80/20 train/eval split within the file (any `split` other than "train" or
    "all" selects the trailing 20%). Use split='all' to return all records
    (useful when train path and eval path differ — no need to withhold).
    Pass `calib_path` to override the default CALIB_DATA_PATH.

    Format: ONE sequence per record, length MAX_SEQ_LEN, padded with
    pad_token_id (0 when the tokenizer has none). Long records are NOT chunked
    into multiple pieces, because mid-document chunks lack the BOS + chat
    scaffold; records longer than MAX_SEQ_LEN are truncated instead.

    Chat mode (default): each record needs 'prompt' and 'response'. Prompt
    and scaffold targets are masked with -100 so only response tokens are
    scored. A record is skipped when the prompt alone fills the window, when
    fewer than 32 tokens remain, or when no response token survives truncation.

    raw_text=True: each record needs a 'text' (or 'content') field; no chat
    template is applied and every non-pad position is scored. Records that
    are empty or shorter than 32 tokens are skipped.

    Returns:
        list of {"input_ids": LongTensor[MAX_SEQ_LEN], "labels": LongTensor[MAX_SEQ_LEN]}.
    """
    path = calib_path or CALIB_DATA_PATH
    # Explicit UTF-8: the platform default encoding is locale-dependent.
    # Blank lines (e.g. a trailing empty line) are skipped rather than being
    # handed to json.loads, which would raise.
    with open(path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    if split != "all":
        n_train = int(len(records) * 0.8)
        records = records[:n_train] if split == "train" else records[n_train:]

    pad_id = tokenizer.pad_token_id or 0
    seqs = []

    if raw_text:
        # Pretraining-style: no chat template, no prompt mask.
        for r in records:
            text = r.get("text") or r.get("content") or ""
            if not text:
                continue
            ids = tokenizer.encode(text, add_special_tokens=True)
            if len(ids) < 32:
                continue
            seqs.append(_pad_to_max_len(ids[:MAX_SEQ_LEN], 1, pad_id))
        return seqs

    # Chat-template format: only assistant-response tokens are scored (CE
    # training and PPL eval), so the number reflects how well the model
    # produces the response given the prompt.
    for r in records:
        msgs = [{"role": "user", "content": r["prompt"]},
                {"role": "model", "content": r["response"]}]
        try:
            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            user_text = tokenizer.apply_chat_template([msgs[0]], tokenize=False, add_generation_prompt=True)
        except Exception:
            # Deliberate best-effort fallback for tokenizers without a usable
            # chat template: plain "prompt\nresponse" text.
            text = f"{r['prompt']}\n{r['response']}"
            user_text = f"{r['prompt']}\n"
        # NOTE(review): if the chat template already emits a BOS token in
        # `text`, encoding with add_special_tokens=True may prepend a second
        # one — confirm against the tokenizer in use.
        ids = tokenizer.encode(text, add_special_tokens=True)
        # n_user = number of leading tokens that are user prompt + scaffold;
        # positions [n_user, len(ids)) are the assistant response.
        n_user = len(tokenizer.encode(user_text, add_special_tokens=True))
        if n_user >= MAX_SEQ_LEN:
            continue  # prompt alone fills the window
        ids = ids[:MAX_SEQ_LEN]
        n = len(ids)
        if n < 32:
            continue
        if n <= n_user:
            continue  # no response token survived — nothing to score
        seqs.append(_pad_to_max_len(ids, n_user, pad_id))
    return seqs
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def kl_loss(s_logits, t_logits, temp=1.0, mask=None):
    """Temperature-scaled distillation loss KL(teacher || student).

    Both logit tensors are divided by `temp` before the softmax. The teacher
    distribution is the target, so the quantity is
    sum_v p_teacher * (log p_teacher - log p_student), multiplied by temp**2.

    Args:
        s_logits: student logits; the last dim is the vocabulary and the first
            dim is the batch (the sum is divided by `s_logits.shape[0]`).
        t_logits: teacher logits, same shape as `s_logits`.
        temp: softmax temperature.
        mask: optional mask over the leading (non-vocab) dims selecting the
            positions to score. Non-bool masks (e.g. 0/1 integers) are
            converted to bool, so non-zero means "score this position".

    Returns:
        Scalar tensor. Without a mask this equals
        F.kl_div(reduction='batchmean') * temp**2. With a mask, unselected
        positions contribute 0 and the sum is still divided by the batch
        size, which keeps the per-batch magnitude comparable.
    """
    s_log = F.log_softmax(s_logits / temp, dim=-1)
    t_prob = F.softmax(t_logits / temp, dim=-1)
    if mask is None:
        return F.kl_div(s_log, t_prob, reduction="batchmean") * (temp ** 2)
    # Vocab-summed KL per position; batchmean (sum / batch) over the masked subset.
    per_position = F.kl_div(s_log, t_prob, reduction="none").sum(dim=-1)
    # Indexing with an integer tensor would gather rows instead of masking.
    keep = mask.to(torch.bool)
    return per_position[keep].sum() / s_logits.shape[0] * (temp ** 2)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def ce_loss(s_logits, labels):
    """Mean cross-entropy of `s_logits` against `labels`, skipping -100 targets.

    All leading dimensions are flattened, so the mean runs over every
    position whose label is not -100.
    """
    vocab_size = s_logits.size(-1)
    flat_logits = s_logits.reshape(-1, vocab_size)
    flat_labels = labels.reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
@torch.no_grad()
def eval_ppl(model, tokenizer, calib_path=None, max_seqs=None):
    """Perplexity of `model` over the eval split of the calibration file.

    Sequences come from load_seqs(..., "eval"). A positive `max_seqs` keeps
    only the first that many sequences (load order, deterministic); None or a
    non-positive value keeps all. A non-zero EVAL_BATCHES additionally caps
    the number of evaluated batches. The model is switched to eval mode.

    Returns:
        exp(total NLL / number of scored tokens), or inf when no token is scored.
    """
    eval_seqs = load_seqs(tokenizer, "eval", calib_path=calib_path)
    if max_seqs is not None and max_seqs > 0:
        eval_seqs = eval_seqs[:max_seqs]
    batches = torch.utils.data.DataLoader(eval_seqs, batch_size=1)

    nll_sum = 0.0
    n_scored = 0
    model.eval()
    for step, batch in enumerate(batches):
        if EVAL_BATCHES and step >= EVAL_BATCHES:
            break
        input_ids = batch["input_ids"].to(DEVICE)
        # Labels are already shifted by load_seqs; drop the last column to
        # line up with logits[:, :-1].
        targets = batch["labels"][:, :-1].to(DEVICE)
        logits = model(input_ids)
        flat_logits = logits[:, :-1].reshape(-1, logits.size(-1))
        batch_nll = F.cross_entropy(flat_logits, targets.reshape(-1),
                                    ignore_index=-100, reduction="sum")
        nll_sum += batch_nll.item()
        n_scored += (targets != -100).sum().item()

    if n_scored > 0:
        return math.exp(nll_sum / n_scored)
    return float("inf")
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
# ──────────────────────── initialization helpers ────────────────────────
|
| 448 |
+
|
| 449 |
+
def compute_taylor_saliency(model, tokenizer, n_batches=8, calib_path=None):
    """Mean |h * dL/dh| per neuron per layer, with h the output of each layer's gate_proj.

    Runs up to `n_batches` batches of the train split through the model,
    back-propagates the masked next-token cross-entropy, and averages
    |activation × gradient| over batch and sequence positions.

    Temporarily re-enables grad on all model params. On exit — including when
    an exception is raised — the original requires_grad flags are restored,
    gradients are cleared, and the forward hooks registered here are removed.
    The model is left in eval mode.

    Args:
        model: model exposing `.layers[i].mlp.gate_proj`; `model(ids)` is used
            as the logits tensor.
        tokenizer: tokenizer forwarded to load_seqs.
        n_batches: maximum number of batches to accumulate over.
        calib_path: optional calibration file override forwarded to load_seqs.

    Returns:
        list[N_LAYERS] of [D_FFN_i] tensors.
    """
    model.eval()
    # Snapshot freeze state; temporarily unfreeze for Taylor computation
    prev_grad = [p.requires_grad for p in model.parameters()]
    for p in model.parameters():
        p.requires_grad_(True)
    # Created before `try` so the `finally` clause can always detach whatever
    # was registered, even if a later step raises.
    hooks = []
    try:
        scores = [torch.zeros(_d_ffn_at(i), device=DEVICE) for i in range(N_LAYERS)]
        seqs = load_seqs(tokenizer, "train", calib_path=calib_path)[:n_batches * BATCH]
        loader = torch.utils.data.DataLoader(seqs, batch_size=BATCH)
        caches = [None] * N_LAYERS

        def make_hook(i):
            def hook(mod, inp, out):
                # Keep the activation and ask autograd to retain its gradient
                # (non-leaf tensors drop .grad by default).
                caches[i] = out
                out.retain_grad()
            return hook

        for i, layer in enumerate(model.layers):
            hooks.append(layer.mlp.gate_proj.register_forward_hook(make_hook(i)))

        n_seen = 0
        for batch in loader:
            ids = batch["input_ids"].to(DEVICE)
            labels = batch["labels"][:, :-1].to(DEVICE)
            logits = model(ids)
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels.reshape(-1), ignore_index=-100)
            loss.backward()
            for i in range(N_LAYERS):
                if caches[i] is not None and caches[i].grad is not None:
                    s = (caches[i].detach() * caches[i].grad.detach()).abs().mean(dim=(0, 1))
                    scores[i] += s
            model.zero_grad(set_to_none=True)
            n_seen += 1
            if n_seen >= n_batches:
                break

        return [s.detach() / max(n_seen, 1) for s in scores]
    finally:
        # Detach hooks on every exit path so a failure cannot leave them
        # attached to the model.
        for h in hooks:
            h.remove()
        # Restore original freeze state
        for p, g in zip(model.parameters(), prev_grad):
            p.requires_grad_(g)
        model.zero_grad(set_to_none=True)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def init_assignment_logits(init_mode, K, K_const, taylor_scores=None, core_frac=0.5):
    """Build the initial neuron-to-expert assignment logits for every layer.

    Args:
        init_mode: "random" or "taylor" ("em" / "kmeans" are recognised but
            not implemented).
        K: total number of experts (columns of each returned tensor).
        K_const: number of always-on experts; they occupy columns [0, K_const).
        taylor_scores: per-layer saliency tensors, required for "taylor".
        core_frac: in "taylor" mode with ``0 < K_const < K``, the fraction of
            the top ``_prune_k_at(i)`` neurons placed in the always-on experts.

    Returns:
        list of ``N_LAYERS`` CPU float tensors, entry ``i`` of shape
        ``[_d_ffn_at(i), K]``.

    Raises:
        NotImplementedError: for "em" / "kmeans".
        ValueError: for any other unknown ``init_mode``.
        AssertionError: "taylor" without ``taylor_scores``.

    Taylor layout (neurons ranked by descending saliency, base logit -2.0):
        * top ``int(prune_k * core_frac)`` -> +2.0, round-robin over the
          always-on experts;
        * remaining top-``prune_k`` -> +2.0, round-robin over the specialists;
        * everything else -> 0.0, round-robin over all K experts.
    When there is no always-on/specialist split to make (``K_const <= 0`` or
    ``K_const >= K``) the whole top-``prune_k`` set is spread round-robin over
    all K experts. This also avoids the modulo-by-zero that the split would
    otherwise hit when ``K_const == K``.
    """
    def _round_robin(A, neuron_idx, n_experts, value, first_expert=0):
        # Neuron at rank r goes to column first_expert + r % n_experts.
        # Neuron indices are unique, so the indexed assignment has no collisions.
        if neuron_idx.numel() == 0:
            return
        ranks = torch.arange(neuron_idx.numel())
        A[neuron_idx, first_expert + ranks % n_experts] = value

    As = []
    for i in range(N_LAYERS):
        d_ffn_i = _d_ffn_at(i)
        prune_k_i = _prune_k_at(i)
        if init_mode == "random":
            # std=0.5 gives softmax(A/1.0) mild bias but softmax(A/0.01) near one-hot
            # at the end of training -- room for A to grow meaningfully during anneal.
            A = torch.randn(d_ffn_i, K) * 0.5
        elif init_mode == "taylor":
            assert taylor_scores is not None, "taylor init requires scores"
            scores = taylor_scores[i].cpu()            # [D_FFN_i]
            order = scores.argsort(descending=True)    # high-saliency first
            # Softer init (+/-2.0) so the tau anneal 1.0 -> 0.01 has dynamic range
            A = torch.full((d_ffn_i, K), -2.0)
            if 0 < K_const < K:
                # Top core_frac of the prune_k_i active neurons go to the
                # always-on experts, the rest of the active set to specialists.
                n_core = int(prune_k_i * core_frac)
                _round_robin(A, order[:n_core], K_const, 2.0)
                _round_robin(A, order[n_core:prune_k_i], K - K_const, 2.0,
                             first_expert=K_const)
            else:
                _round_robin(A, order[:prune_k_i], K, 2.0)
            # Low-saliency neurons: uniform mild bias, the tau anneal drives assignment
            _round_robin(A, order[prune_k_i:], K, 0.0)
        elif init_mode == "em" or init_mode == "kmeans":
            raise NotImplementedError(
                f"init={init_mode} requires a precomputed file; use --init random|taylor for now")
        else:
            raise ValueError(f"Unknown init_mode: {init_mode}")
        As.append(A)
    return As
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def init_router_weights(init_mode, model, init_As, K, K_const, scale_multiplier=1.0):
    """Produce one router init per layer: a [D_MODEL, K_spec] tensor, or None.

    ``None`` means "let MoEMLP fall back to its own default random init".
    When there are no specialist experts (``K == K_const``) every entry is None.

    Modes:
      - random: list of Nones (MoEMLP draws N(0, 0.02) itself).
      - zero: all-zero W_r (uniform routing at init).
      - centroid: column k is the mean ``gate_proj.weight`` row over the
            neurons whose argmax assignment in ``init_As`` is specialist k,
            L2-normalized per column and scaled to magnitude 0.02.
      - scaled_centroid: same mean, NOT normalized, multiplied by
            ``scale_multiplier`` so the router inherits the base weight scale.

    A specialist with no assigned neurons falls back to the mean over all
    ``gate_proj.weight`` rows.

    Raises:
        ValueError: on an unrecognised ``init_mode`` (checked per layer).
    """
    n_spec = K - K_const
    if n_spec == 0:
        return [None] * N_LAYERS

    def centroid_columns(layer_idx):
        # gate_proj.weight rows are per-neuron input directions: [D_FFN, D_MODEL].
        gate_w = model.layers[layer_idx].mlp.gate_proj.weight.detach().float().cpu()
        # Hard assignment of each neuron to one of the K experts.
        owner = init_As[layer_idx].cpu().argmax(dim=-1)
        cols = torch.zeros(D_MODEL, n_spec)
        for spec_idx in range(n_spec):
            # Specialist spec_idx lives at column K_const + spec_idx of A.
            members = owner == (K_const + spec_idx)
            rows = gate_w[members] if members.any() else gate_w
            cols[:, spec_idx] = rows.mean(dim=0)
        return cols

    router_inits = []
    for layer_idx in range(N_LAYERS):
        if init_mode == "random":
            w_init = None
        elif init_mode == "zero":
            w_init = torch.zeros(D_MODEL, n_spec)
        elif init_mode == "centroid":
            # Unit direction per column, small magnitude comparable to the default init.
            w_init = F.normalize(centroid_columns(layer_idx), dim=0) * 0.02
        elif init_mode == "scaled_centroid":
            # Deliberately un-normalized; scale_multiplier is the only knob.
            w_init = centroid_columns(layer_idx) * scale_multiplier
        else:
            raise ValueError(f"Unknown W_r init_mode: {init_mode}")
        router_inits.append(w_init)
    return router_inits
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# ───────────────────────── MoE MLP module ─────────────────────────────
|
| 598 |
+
|
| 599 |
+
class MoEMLP(nn.Module):
    """
    K experts via per-neuron softmax assignment (MECE mode) OR
    independent sigmoid masks with orthogonality loss (sigmoid_ortho mode).

    K_const always-on experts always apply; K_spec routed experts selected
    via top-K_active_spec routing.

    The module reuses the wrapped MLP's gate/up/down projections and only adds
    two parameters: ``A`` ([D_FFN, K] assignment logits) and, when there are
    specialists, ``W_r`` ([D_MODEL, K_spec] router weights). ``tau`` is a plain
    attribute (initially 1.0) that divides ``A`` before the softmax/sigmoid.
    """
    def __init__(self, base_mlp, K, K_const, K_active_spec, mece_mode, init_A,
                 noise_std=1.0, freeze_base=True, init_W_r=None):
        """
        Args:
            base_mlp: module providing ``gate_proj``, ``up_proj``, ``down_proj``.
            K: total number of experts.
            K_const: number of always-on experts (the first K_const rows of the masks).
            K_active_spec: number of specialist experts fired per token.
            mece_mode: "softmax" or "sigmoid_ortho".
            init_A: initial assignment logits, moved to DEVICE as float32.
            noise_std: router noise scale, applied only in training mode.
            freeze_base: if True, disable grad on the three projections.
            init_W_r: optional router init; otherwise N(0, 0.02).
        """
        super().__init__()
        self.gate_proj = base_mlp.gate_proj
        self.up_proj = base_mlp.up_proj
        self.down_proj = base_mlp.down_proj
        if freeze_base:
            for proj in (self.gate_proj, self.up_proj, self.down_proj):
                for p in proj.parameters():
                    p.requires_grad_(False)

        self.K = K
        self.K_const = K_const
        self.K_spec = K - K_const
        self.K_active_spec = K_active_spec  # # of specialist experts fired per token
        self.mece_mode = mece_mode          # "softmax" | "sigmoid_ortho"
        self.noise_std = noise_std
        self.tau = 1.0

        # Assignment logits
        self.A = nn.Parameter(init_A.to(DEVICE).float())

        if self.K_spec > 0:
            if init_W_r is not None:
                self.W_r = nn.Parameter(init_W_r.to(DEVICE).float())
            else:
                self.W_r = nn.Parameter(torch.zeros(D_MODEL, self.K_spec, device=DEVICE, dtype=torch.float32))
                nn.init.normal_(self.W_r, std=0.02)
        else:
            self.register_parameter("W_r", None)

        # Diagnostics cache: overwritten by every forward pass that routes
        # (K_spec > 0) and read back by aux_loss().
        self._last_logits = None
        self._last_top_idx = None

    def _expert_masks(self):
        """Return [K, D_FFN] -- each expert's soft mask.

        The temperature is floored at 1e-3 to avoid division by ~0.
        """
        if self.mece_mode == "softmax":
            probs = F.softmax(self.A / max(self.tau, 1e-3), dim=-1)  # [D_FFN, K]
            return probs.T.contiguous()                               # [K, D_FFN]
        elif self.mece_mode == "sigmoid_ortho":
            return torch.sigmoid(self.A / max(self.tau, 1e-3)).T.contiguous()
        else:
            raise ValueError(self.mece_mode)

    def forward(self, x):
        """Apply the gated MLP with a per-token neuron mask.

        The mask is the sum of the always-on expert masks plus a
        softmax-weighted mix of the top-k routed specialist masks.
        """
        gate_raw = self.gate_proj(x)                      # [B, T, D_FFN]
        gate_act = F.gelu(gate_raw, approximate="tanh")
        up_act = self.up_proj(x)
        h_pre = gate_act * up_act                         # [B, T, D_FFN]

        masks = self._expert_masks()                      # [K, D_FFN]

        # Always-on core contribution
        d_ffn = self.A.shape[0]                           # per-layer D_FFN
        if self.K_const > 0:
            const_mask = masks[:self.K_const].sum(dim=0)  # [D_FFN]
        else:
            const_mask = torch.zeros(d_ffn, device=x.device, dtype=torch.float32)

        # Routed specialist contribution
        if self.K_spec > 0:
            logits = x.to(torch.float32) @ self.W_r       # [B, T, K_spec]
            if self.training and self.noise_std > 0:
                logits = logits + torch.randn_like(logits) * (self.noise_std / (self.K_spec ** 0.5))
            self._last_logits = logits

            k_act = min(self.K_active_spec, self.K_spec)
            top_vals, top_idx = logits.topk(k_act, dim=-1)  # [B, T, k_act]
            self._last_top_idx = top_idx
            # Renormalize over the selected experts only.
            top_w = F.softmax(top_vals, dim=-1)             # [B, T, k_act]

            spec_masks = masks[self.K_const:]               # [K_spec, D_FFN]
            # NOTE: materializes a [B, T, k_act, D_FFN] intermediate.
            gathered = spec_masks[top_idx]                  # [B, T, k_act, D_FFN]
            spec_combined = (gathered * top_w.unsqueeze(-1)).sum(dim=-2)  # [B, T, D_FFN]
            combined = const_mask.view(1, 1, -1) + spec_combined
        else:
            combined = const_mask.view(1, 1, -1).expand_as(h_pre)

        h = h_pre * combined.to(x.dtype)
        return self.down_proj(h)

    def aux_loss(self, alpha_b=0.01, alpha_z=0.001):
        """Switch balance loss (on specialists only) + router z-loss.

        Uses the logits / top-k indices cached by the most recent forward pass;
        returns a zero scalar if there are no specialists or no cached pass.

        Balance (Switch/GShard top-k generalization):
          f_k = (tokens_selecting_k / total_tokens) / n_selected
          p_k = mean softmax probability for expert k
          balance = alpha_b * K_spec * sum_k f_k p_k
        where n_selected is the number of experts forward() actually picked per
        token (min(K_active_spec, K_spec)), so sum_k f_k = 1. With uniform f and
        p the sum is 1/K_spec and the balance term equals alpha_b.

        z-loss: alpha_z * mean(logsumexp(logits)^2).
        """
        if self.K_spec == 0 or self._last_logits is None:
            return torch.tensor(0.0, device=self.A.device)
        logits = self._last_logits                        # [B, T, K_spec]
        probs = F.softmax(logits, dim=-1)
        top_idx = self._last_top_idx                      # [B, T, k_act]
        hot = F.one_hot(top_idx, self.K_spec).float().sum(dim=-2)  # [B, T, K_spec]
        # Normalize by the number of experts actually selected, which is what
        # forward() used; K_active_spec itself may exceed K_spec.
        n_selected = max(top_idx.size(-1), 1)
        f_k = hot.mean(dim=(0, 1)) / n_selected           # [K_spec]
        p_k = probs.mean(dim=(0, 1))                      # [K_spec]
        balance = alpha_b * self.K_spec * (f_k * p_k).sum()

        lse = torch.logsumexp(logits, dim=-1)             # [B, T]
        z_loss = alpha_z * (lse ** 2).mean()
        return balance + z_loss

    def orth_loss(self):
        """For sigmoid_ortho mode: penalize pairwise expert mask overlap.

        Mean squared off-diagonal entry of the Gram matrix of the L2-normalized
        expert masks. Returns a zero scalar in any other mode.
        """
        if self.mece_mode != "sigmoid_ortho":
            return torch.tensor(0.0, device=self.A.device)
        masks = self._expert_masks()                      # [K, D_FFN]
        # L2-normalize rows, then off-diagonal Gram
        mn = F.normalize(masks, dim=-1)
        gram = mn @ mn.T                                  # [K, K]
        K = gram.size(0)
        off = gram - torch.eye(K, device=gram.device)
        # The epsilon keeps the K == 1 case (no off-diagonal pairs) finite.
        return (off ** 2).sum() / (K * (K - 1) + 1e-8)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
# ───────────────────────────── training ─────────────────────────────
|
| 726 |
+
|
| 727 |
+
def get_tau(step, max_steps, tau_start, tau_end, hold_frac=0.2):
    """Temperature schedule: linear ramp, then a flat hold at ``tau_end``.

    The ramp covers the first ``int(max_steps * (1 - hold_frac))`` steps (at
    least one) and reaches ``tau_end`` on its final step; every later step
    returns ``tau_end`` unchanged. The hold gives the model time to adapt to
    hard masks instead of ending training right as tau bottoms out.
    """
    ramp_len = int(max_steps * (1 - hold_frac))
    if ramp_len < 1:
        ramp_len = 1
    if step < ramp_len:
        # Divide by ramp_len - 1 so the last ramp step lands exactly on tau_end;
        # the max() guards the single-step ramp.
        progress = step / max(1, ramp_len - 1)
        return tau_start + progress * (tau_end - tau_start)
    return tau_end
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def install_moe(model, K, K_const, K_active_spec, mece_mode, init_As, noise_std,
                freeze_base=True, init_W_rs=None):
    """Swap each of the first N_LAYERS layers' ``mlp`` for a MoEMLP wrapper.

    Each wrapper is built around the layer's existing MLP with that layer's
    entry from ``init_As`` (and from ``init_W_rs`` when given; otherwise the
    router uses MoEMLP's default init). The model is modified in place.

    Returns:
        list of the newly installed MoEMLP modules, in layer order.
    """
    router_inits = [None] * N_LAYERS if init_W_rs is None else init_W_rs
    installed = []
    for layer_idx in range(N_LAYERS):
        block = model.layers[layer_idx]
        moe = MoEMLP(
            block.mlp, K, K_const, K_active_spec, mece_mode, init_As[layer_idx],
            noise_std=noise_std, freeze_base=freeze_base,
            init_W_r=router_inits[layer_idx])
        block.mlp = moe
        installed.append(moe)
    return installed
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def main():
|
| 754 |
+
parser = argparse.ArgumentParser()
|
| 755 |
+
parser.add_argument("--phase", type=str, default="A1")
|
| 756 |
+
parser.add_argument("--K", type=int, default=4)
|
| 757 |
+
parser.add_argument("--K_const", type=int, default=0)
|
| 758 |
+
parser.add_argument("--K_active_spec", type=int, default=-1,
|
| 759 |
+
help="# specialists fired per token. Default = round(K_spec * 0.40 / (1 - K_const/K * 0.40)); falls back to max(1, round(K_spec*0.5))")
|
| 760 |
+
parser.add_argument("--loss", choices=["kl", "ce"], default="kl")
|
| 761 |
+
parser.add_argument("--init", choices=["random", "taylor", "em", "kmeans"], default="random")
|
| 762 |
+
parser.add_argument("--core_frac", type=float, default=0.5,
|
| 763 |
+
help="Fraction of PRUNE_K active neurons to concentrate in K_const core (Taylor init only)")
|
| 764 |
+
parser.add_argument("--mece_mode", choices=["softmax", "sigmoid_ortho"], default="softmax")
|
| 765 |
+
parser.add_argument("--tau_start", type=float, default=1.0)
|
| 766 |
+
parser.add_argument("--tau_end", type=float, default=0.01)
|
| 767 |
+
parser.add_argument("--tau_hold_frac", type=float, default=0.2,
|
| 768 |
+
help="Fraction of max_steps to HOLD at tau_end after annealing. Default 0.2 = "
|
| 769 |
+
"anneal over first 80%, hold last 20%. For long continuation runs, set "
|
| 770 |
+
"to e.g. 0.857 to give just 5k anneal steps and 30k hard-tau steps "
|
| 771 |
+
"(out of 35k total).")
|
| 772 |
+
parser.add_argument("--max_steps", type=int, default=2000)
|
| 773 |
+
parser.add_argument("--lr", type=float, default=LR)
|
| 774 |
+
parser.add_argument("--alpha_b", type=float, default=0.01)
|
| 775 |
+
parser.add_argument("--alpha_z", type=float, default=0.001)
|
| 776 |
+
parser.add_argument("--alpha_orth", type=float, default=0.01)
|
| 777 |
+
parser.add_argument("--noise_std", type=float, default=1.0)
|
| 778 |
+
parser.add_argument("--eval_every", type=int, default=200)
|
| 779 |
+
parser.add_argument("--optimizer", choices=["adamw", "adamw8bit"], default="adamw",
|
| 780 |
+
help="adamw8bit uses bitsandbytes 8-bit optimizer — saves ~28GB "
|
| 781 |
+
"optimizer state on 4.65B model, required to --unfreeze_base on H100 80GB")
|
| 782 |
+
parser.add_argument("--freeze_embeddings", action="store_true",
|
| 783 |
+
help="Freeze embed_tokens (+tied lm_head) and embed_tokens_per_layer. "
|
| 784 |
+
"For Gemma-4 E2B these are 2.75B of 5.1B params and embed_tokens_per_layer "
|
| 785 |
+
"is a single 2.35B-element tensor that exceeds bnb 8bit kernel limits. "
|
| 786 |
+
"Freezing them makes --unfreeze_base feasible with plain fp32 AdamW on "
|
| 787 |
+
"the remaining ~2.35B params (~19GB state, fits 80GB).")
|
| 788 |
+
parser.add_argument("--use_lora", action="store_true",
|
| 789 |
+
help="Wrap target Linears with LoRALinear (frozen base + trainable rank-r delta). "
|
| 790 |
+
"Use INSTEAD of full base fine-tuning. Combines naturally with --int4_qat: "
|
| 791 |
+
"LoRA wraps the int4-quantized Linear. Trains only ~10-30M LoRA params + MoE.")
|
| 792 |
+
parser.add_argument("--lora_rank", type=int, default=16,
|
| 793 |
+
help="LoRA rank (low-dim adapter dim). Typical: 8 (less capacity, less overfit), "
|
| 794 |
+
"16 (default), 32 (more capacity).")
|
| 795 |
+
parser.add_argument("--lora_alpha", type=float, default=16.0,
|
| 796 |
+
help="LoRA scaling factor; effective scale = alpha/rank. Default 16/16 = 1.0.")
|
| 797 |
+
parser.add_argument("--W_r_init", choices=["random", "zero", "centroid", "scaled_centroid"], default="random",
|
| 798 |
+
help="Router W_r init: random (default), zero (uniform routing), "
|
| 799 |
+
"centroid (mean W_gate row per Taylor-assigned expert, L2-normalized to 0.02 mag), "
|
| 800 |
+
"scaled_centroid (mean W_gate row per expert, NOT normalized, scaled by --W_r_scale).")
|
| 801 |
+
parser.add_argument("--W_r_scale", type=float, default=1.0,
|
| 802 |
+
help="Multiplier for scaled_centroid init. W_r = scale × mean(W_gate per expert). "
|
| 803 |
+
"Values ~0.1–10 control how 'loud' the router is relative to base weight scale.")
|
| 804 |
+
parser.add_argument("--W_r_lr_mult", type=float, default=1.0,
|
| 805 |
+
help="Learning rate multiplier for router W_r params (and A logits). "
|
| 806 |
+
"E.g., 5.0 trains the router 5× faster than base weights. The router "
|
| 807 |
+
"is ~0.03% of total params and has a specific job — higher LR can "
|
| 808 |
+
"help it converge quickly without destabilizing base-weight training.")
|
| 809 |
+
parser.add_argument("--freeze_A", action="store_true",
|
| 810 |
+
help="Freeze assignment logits A (only router + optionally base train)")
|
| 811 |
+
parser.add_argument("--unfreeze_base", action="store_true",
|
| 812 |
+
help="Train base weights (W_gate/W_up/W_down, attn, norms). Default freezes them.")
|
| 813 |
+
parser.add_argument("--save_checkpoint", type=str, default="",
|
| 814 |
+
help="Save final student state_dict to this path (.pt)")
|
| 815 |
+
parser.add_argument("--save_every", type=int, default=0,
|
| 816 |
+
help="If >0 and --save_checkpoint set, also save an intermediate ckpt every "
|
| 817 |
+
"N max_steps. Filename: <save_checkpoint stem>_step<N>.pt. Use for long "
|
| 818 |
+
"runs where you may want to early-stop without losing progress.")
|
| 819 |
+
parser.add_argument("--shuffle_seed", type=int, default=0,
|
| 820 |
+
help="Seed for the dataloader shuffle. Same seed → same record order. Use a "
|
| 821 |
+
"different seed in continuation runs to expose the model to a new ordering "
|
| 822 |
+
"of the dataset.")
|
| 823 |
+
parser.add_argument("--data_skip", type=int, default=0,
|
| 824 |
+
help="Discard first N samples of the (shuffled) dataloader stream before "
|
| 825 |
+
"training. Combine with same --shuffle_seed as a previous run to start "
|
| 826 |
+
"where it left off — model sees fresh records first.")
|
| 827 |
+
parser.add_argument("--load_checkpoint", type=str, default="",
|
| 828 |
+
help="Load student state_dict from this path BEFORE training (warm-start). "
|
| 829 |
+
"Must be from a prior rung6_moe.py run with matching architecture.")
|
| 830 |
+
parser.add_argument("--calib_path", type=str, default=CALIB_DATA_PATH,
|
| 831 |
+
help="Path to JSONL calibration data for TRAINING. Default: final.jsonl (640 records). "
|
| 832 |
+
"Use bulk.jsonl (~12k records) or trajectories_25k.jsonl (25k) for more data.")
|
| 833 |
+
parser.add_argument("--eval_calib_path", type=str, default="",
|
| 834 |
+
help="Path to JSONL calibration data for EVAL. Default: same as --calib_path. "
|
| 835 |
+
"Set to final.jsonl for consistent eval across curriculum phases.")
|
| 836 |
+
parser.add_argument("--int4_qat", action="store_true",
|
| 837 |
+
help="Enable int4 QAT: wrap target Linears (MLP + attention) with Int4QuantLinear "
|
| 838 |
+
"so forward uses fake-quantized weights (groupwise STE, group_size=128).")
|
| 839 |
+
parser.add_argument("--int4_group_size", type=int, default=32,
|
| 840 |
+
help="Groupwise int4 quant group size. Default 32 matches GGUF Q4_0/Q4_K deploy block size. "
|
| 841 |
+
"128 is another common choice (AWQ-style) with less storage overhead but larger quant error.")
|
| 842 |
+
parser.add_argument("--eval_only", action="store_true",
|
| 843 |
+
help="Skip training; just eval after setup (init + optional checkpoint load + optional "
|
| 844 |
+
"int4 wrap). Useful for measuring untrained-int4 baseline or a specific checkpoint's "
|
| 845 |
+
"eval PPL at tau_end without further optimization.")
|
| 846 |
+
# Knowledge preservation fixes
|
| 847 |
+
parser.add_argument("--diverse_calib_path", type=str, default="",
|
| 848 |
+
help="Path to JSONL (raw 'text' field) for periodic KL-to-base preservation batches. "
|
| 849 |
+
"Usually wikitext or similar pretraining-distribution text.")
|
| 850 |
+
parser.add_argument("--diverse_every_n", type=int, default=4,
|
| 851 |
+
help="Every N optimizer steps, replace the normal CE batch with a KL-to-teacher pass "
|
| 852 |
+
"on diverse data. Default 4 = ~25%% of batches.")
|
| 853 |
+
parser.add_argument("--main_kl_temp", type=float, default=1.0,
|
| 854 |
+
help="Softmax temperature for the MAIN loss when --loss kl. "
|
| 855 |
+
"T>1 softens teacher's argmax commitment. Useful for knowledge "
|
| 856 |
+
"retention but too high (>5) can destabilize Gemma-4 training "
|
| 857 |
+
"due to low teacher entropy.")
|
| 858 |
+
parser.add_argument("--kl_base_lambda", type=float, default=0.5,
|
| 859 |
+
help="Scalar on the diverse-batch KL-to-teacher loss.")
|
| 860 |
+
parser.add_argument("--kl_base_temp", type=float, default=2.0,
|
| 861 |
+
help="Softmax temperature for KL-to-teacher. >1 softens distributions, recovering "
|
| 862 |
+
"tail mass — important when teacher entropy is low (e.g., Gemma-4 E2B). "
|
| 863 |
+
"Try 2-5 for Gemma-3, 5-10 for Gemma-4.")
|
| 864 |
+
parser.add_argument("--w_drift_lambda", type=float, default=0.0,
|
| 865 |
+
help="L2-to-base weight-drift penalty: λ × Σ ‖W_student − W_teacher‖² over trainable "
|
| 866 |
+
"base weights (excluding MoE .A and .W_r). Prevents catastrophic forgetting by "
|
| 867 |
+
"anchoring weights to base. Typical: 1e-6 to 1e-4.")
|
| 868 |
+
parser.add_argument("--real_int4_inplace", action="store_true",
|
| 869 |
+
help="After load_checkpoint, snap target Linear weights to int4 grid in-place (no STE, "
|
| 870 |
+
"no runtime overhead). Simulates deployment — forward uses plain nn.Linear with "
|
| 871 |
+
"already-quantized weights. Combine with --eval_only for the real-int4 benchmark.")
|
| 872 |
+
parser.add_argument("--gaussian_noise_scale", type=float, default=0.0,
|
| 873 |
+
help="Add N(0, scale × p.std()) Gaussian noise to target Linear weights in-place. "
|
| 874 |
+
"Default 0.0 = disabled. 0.129 is the analytical int4 group=32 equivalent.")
|
| 875 |
+
# ── Activation-MSE recovery (mechanism A: generic per-module) ──
|
| 876 |
+
parser.add_argument("--recovery_steps", type=int, default=0,
|
| 877 |
+
help="If >0: run module_recovery.recover_modules_sequentially on every per-layer "
|
| 878 |
+
"MLP after install_moe + wrap_int4 (+ wrap_lora) and BEFORE main training. "
|
| 879 |
+
"Default 0 = disabled.")
|
| 880 |
+
parser.add_argument("--recovery_lr", type=float, default=1e-4,
|
| 881 |
+
help="LR for the generic recovery AdamW (only A and W_r receive grad — base "
|
| 882 |
+
"and LoRA params are not in the trainable set during recovery).")
|
| 883 |
+
parser.add_argument("--recovery_n_batches", type=int, default=8,
|
| 884 |
+
help="# calibration batches sampled from --calib_path for generic recovery.")
|
| 885 |
+
# ── Activation-MSE recovery (mechanism B: specialized MoE per-layer) ──
|
| 886 |
+
parser.add_argument("--moe_recovery_seconds_per_layer", type=float, default=0.0,
|
| 887 |
+
help="If >0: run finetune_moe_per_layer for this many wall-clock seconds per "
|
| 888 |
+
"MLP layer. Pre-caches teacher (X, Y), optimizes A and W_r only. "
|
| 889 |
+
"Default 0 = disabled.")
|
| 890 |
+
parser.add_argument("--moe_recovery_lr", type=float, default=1e-3,
|
| 891 |
+
help="LR for the specialized per-layer recovery (A and W_r are tiny — 1e-3 is fine).")
|
| 892 |
+
parser.add_argument("--moe_recovery_n_calib_records", type=int, default=32,
|
| 893 |
+
help="# calibration records (single-sequence, len MAX_SEQ_LEN) cached for the "
|
| 894 |
+
"specialized recovery. Memory ≈ 2 × N × MAX_SEQ_LEN × hidden × 2 bytes.")
|
| 895 |
+
parser.add_argument("--moe_recovery_use_student_inputs", type=lambda s: s.lower() in ("1", "true", "yes"),
|
| 896 |
+
default=True,
|
| 897 |
+
help="If True (default), refresh student X between layers so each layer sees "
|
| 898 |
+
"error-corrected upstream activations. If False, use teacher X throughout "
|
| 899 |
+
"(matches Sunday's original pipeline).")
|
| 900 |
+
parser.add_argument("--moe_recovery_optimizer", choices=["adam", "muon"], default="adam",
|
| 901 |
+
help="Specialized recovery optimizer. 'muon' uses muon.MuonWithAdam (matrix-aware "
|
| 902 |
+
"Newton-Schulz). A and W_r are both 2D so Muon-eligible.")
|
| 903 |
+
parser.add_argument("--moe_recovery_noise_std", type=float, default=-1.0,
|
| 904 |
+
help="Override MoEMLP router noise during recovery. -1.0 = keep current "
|
| 905 |
+
"MoEMLP setting (default 1.0 from MoE training convention). 0.0 = "
|
| 906 |
+
"deterministic routing for clean per-step loss + meaningful best-state "
|
| 907 |
+
"tracking + train/deploy match. Higher = more router exploration.")
|
| 908 |
+
args = parser.parse_args()
|
| 909 |
+
if not args.eval_calib_path:
|
| 910 |
+
args.eval_calib_path = args.calib_path
|
| 911 |
+
|
| 912 |
+
K_spec = args.K - args.K_const
|
| 913 |
+
assert K_spec >= 0 and args.K_const >= 0 and args.K >= 1
|
| 914 |
+
if args.K_active_spec < 0:
|
| 915 |
+
# Target per-token sparsity = 40% of D_FFN = PRUNE_K neurons.
|
| 916 |
+
# Each expert covers ~D_FFN/K neurons at MECE. K_const always fires (D_FFN/K * K_const).
|
| 917 |
+
# Need K_active_spec such that (K_const + K_active_spec) * D_FFN/K ≈ PRUNE_K
|
| 918 |
+
# → K_active_spec = round(K * PRUNE_P - K_const)
|
| 919 |
+
k_act = max(1, round(args.K * PRUNE_P) - args.K_const) if K_spec > 0 else 0
|
| 920 |
+
args.K_active_spec = k_act
|
| 921 |
+
assert args.K_active_spec <= K_spec
|
| 922 |
+
|
| 923 |
+
os.makedirs("logs", exist_ok=True)
|
| 924 |
+
print(f"=== Rung 6 MoE — phase={args.phase} ===")
|
| 925 |
+
print(f" K={args.K} K_const={args.K_const} K_spec={K_spec} K_active_spec={args.K_active_spec}")
|
| 926 |
+
print(f" mece_mode={args.mece_mode} init={args.init} loss={args.loss}")
|
| 927 |
+
print(f" tau: {args.tau_start} → {args.tau_end} over {args.max_steps} steps")
|
| 928 |
+
# Gemma-4 has two MLP widths (6144 / 12288). Report both layer types' active budgets.
|
| 929 |
+
ratio = (args.K_const + args.K_active_spec) / args.K
|
| 930 |
+
for width_name, d in (("narrow (layers 0-14)", INTERMEDIATE),
|
| 931 |
+
("wide (layers 15+)", INTERMEDIATE_WIDE)):
|
| 932 |
+
eff_active = ratio * d
|
| 933 |
+
prune_k = int(d * PRUNE_P)
|
| 934 |
+
print(f" {width_name}: active ~{eff_active:.0f}/{d} "
|
| 935 |
+
f"(40% target = {prune_k}; diff = {eff_active - prune_k:+.0f})")
|
| 936 |
+
|
| 937 |
+
print(f" freeze_A={args.freeze_A} unfreeze_base={args.unfreeze_base} W_r_init={args.W_r_init}")
|
| 938 |
+
if args.load_checkpoint: print(f" load_checkpoint={args.load_checkpoint}")
|
| 939 |
+
if args.save_checkpoint: print(f" save_checkpoint={args.save_checkpoint}")
|
| 940 |
+
|
| 941 |
+
print(f"Loading teacher & student on {DEVICE}...")
|
| 942 |
+
teacher, tokenizer = load_model()
|
| 943 |
+
teacher.eval()
|
| 944 |
+
for p in teacher.parameters(): p.requires_grad_(False)
|
| 945 |
+
|
| 946 |
+
student, _ = load_model()
|
| 947 |
+
# Note: NO corruption — rung 6 uses the CLEAN IT model.
|
| 948 |
+
freeze_base = not args.unfreeze_base
|
| 949 |
+
if freeze_base:
|
| 950 |
+
for p in student.parameters(): p.requires_grad_(False) # freeze base first
|
| 951 |
+
# If unfreeze_base: leave requires_grad=True on all params (default)
|
| 952 |
+
|
| 953 |
+
# Embedding freeze for Gemma-4 (selectively keep embed_tokens and embed_tokens_per_layer
|
| 954 |
+
# frozen even when the rest of the base is training). Required for Gemma-4 4.65B on 80GB:
|
| 955 |
+
# embed_tokens_per_layer alone is a single 2.35B tensor that breaks bnb 8bit kernels, and
|
| 956 |
+
# embeddings rarely need to move for MoE-preservation work anyway.
|
| 957 |
+
if args.freeze_embeddings:
|
| 958 |
+
n_frozen = 0
|
| 959 |
+
for name, p in student.named_parameters():
|
| 960 |
+
if "embed_tokens" in name: # catches embed_tokens and embed_tokens_per_layer (and tied lm_head)
|
| 961 |
+
p.requires_grad_(False)
|
| 962 |
+
n_frozen += p.numel()
|
| 963 |
+
print(f" Froze embeddings: {n_frozen/1e9:.2f}B params (embed_tokens, embed_tokens_per_layer, tied lm_head)")
|
| 964 |
+
|
| 965 |
+
# Initialization
|
| 966 |
+
taylor_scores = None
|
| 967 |
+
if args.init == "taylor" and not args.load_checkpoint:
|
| 968 |
+
print("Computing Taylor saliency for init...")
|
| 969 |
+
taylor_scores = compute_taylor_saliency(student, tokenizer, n_batches=8, calib_path=args.calib_path)
|
| 970 |
+
init_As = init_assignment_logits(args.init if not args.load_checkpoint else "random",
|
| 971 |
+
args.K, args.K_const, taylor_scores, core_frac=args.core_frac)
|
| 972 |
+
init_W_rs = init_router_weights(args.W_r_init, student, init_As, args.K, args.K_const,
|
| 973 |
+
scale_multiplier=args.W_r_scale)
|
| 974 |
+
|
| 975 |
+
mlp_modules = install_moe(
|
| 976 |
+
student, K=args.K, K_const=args.K_const,
|
| 977 |
+
K_active_spec=args.K_active_spec, mece_mode=args.mece_mode,
|
| 978 |
+
init_As=init_As, noise_std=args.noise_std,
|
| 979 |
+
freeze_base=freeze_base, init_W_rs=init_W_rs)
|
| 980 |
+
|
| 981 |
+
# Optionally freeze A (only router trains) — done AFTER install_moe
|
| 982 |
+
if args.freeze_A:
|
| 983 |
+
for m in mlp_modules:
|
| 984 |
+
m.A.requires_grad_(False)
|
| 985 |
+
print(" A frozen — only router W_r (and base if --unfreeze_base) trains")
|
| 986 |
+
|
| 987 |
+
# Load warm-start checkpoint BEFORE computing trainable params
|
| 988 |
+
if args.load_checkpoint:
|
| 989 |
+
print(f" Loading checkpoint from {args.load_checkpoint}...")
|
| 990 |
+
ckpt = torch.load(args.load_checkpoint, map_location=DEVICE)
|
| 991 |
+
state = ckpt.get('student_state', ckpt) if isinstance(ckpt, dict) else ckpt
|
| 992 |
+
missing, unexpected = student.load_state_dict(state, strict=False)
|
| 993 |
+
print(f" missing={len(missing)} unexpected={len(unexpected)}")
|
| 994 |
+
|
| 995 |
+
# Int4 QAT: wrap target Linears AFTER state_dict load (keys unchanged — subclass of nn.Linear).
|
| 996 |
+
# Must happen BEFORE optimizer creation so parameter references are stable.
|
| 997 |
+
if args.int4_qat:
|
| 998 |
+
Int4QuantLinear._group_size = args.int4_group_size
|
| 999 |
+
n_wrapped = wrap_int4(student)
|
| 1000 |
+
print(f" Int4 QAT: wrapped {n_wrapped} nn.Linear modules (group_size={args.int4_group_size}, "
|
| 1001 |
+
f"range [-7, 7]). Forward uses fake-quant; backward is STE through fp weight.")
|
| 1002 |
+
|
| 1003 |
+
# LoRA: wrap target Linears (incl. Int4QuantLinear) with LoRALinear so base is frozen
|
| 1004 |
+
# and only LoRA A/B + MoE A logits/W_r train. Apply AFTER int4 so the base inside LoRA
|
| 1005 |
+
# is the int4-quantized Linear (deploy-realistic).
|
| 1006 |
+
if args.use_lora:
|
| 1007 |
+
# Pure-LoRA semantics: freeze ALL base params (including attention, norms, scalars
|
| 1008 |
+
# not in LoRA target list). MoE A/W_r and the LoRA adapters added by wrap_lora are
|
| 1009 |
+
# the only trainable things. Overrides --unfreeze_base.
|
| 1010 |
+
for name, p in student.named_parameters():
|
| 1011 |
+
if not (name.endswith(".A") or name.endswith(".W_r")):
|
| 1012 |
+
p.requires_grad_(False)
|
| 1013 |
+
n_wrapped, n_lora_params = wrap_lora(student, rank=args.lora_rank, alpha=args.lora_alpha)
|
| 1014 |
+
print(f" LoRA: wrapped {n_wrapped} Linears with rank={args.lora_rank} alpha={args.lora_alpha} "
|
| 1015 |
+
f"(trainable LoRA params: {n_lora_params/1e6:.2f}M)")
|
| 1016 |
+
|
| 1017 |
+
# Real int4 quantization in-place (deploy simulation — no runtime quant overhead).
|
| 1018 |
+
if args.real_int4_inplace:
|
| 1019 |
+
n_q = apply_int4_inplace(student, group_size=args.int4_group_size)
|
| 1020 |
+
print(f" Real int4 inplace: quantized {n_q} Linear weights to int4 grid "
|
| 1021 |
+
f"(group_size={args.int4_group_size}); weights now on-grid, regular nn.Linear forward.")
|
| 1022 |
+
|
| 1023 |
+
# Gaussian-proxy noise benchmark.
|
| 1024 |
+
if args.gaussian_noise_scale > 0:
|
| 1025 |
+
n_g = apply_gaussian_noise_inplace(student, noise_scale=args.gaussian_noise_scale)
|
| 1026 |
+
print(f" Gaussian noise inplace: added N(0, {args.gaussian_noise_scale} × p.std()) "
|
| 1027 |
+
f"to {n_g} Linear weights.")
|
| 1028 |
+
|
| 1029 |
+
# ────────── Activation-MSE recovery (mechanism A: generic) ──────────
|
| 1030 |
+
# Runs AFTER install_moe + wrap_int4 (+ wrap_lora) so the recovered student
|
| 1031 |
+
# is the deployed one (int4 fake-quant in the loop, MoE routing engaged at
|
| 1032 |
+
# tau_end). Trainable params during recovery: only the MoE params A and W_r
|
| 1033 |
+
# that were already trainable. Everything else (base weights under
|
| 1034 |
+
# --unfreeze_base, LoRA adapters) has requires_grad cleared for recovery and restored afterwards.
|
| 1035 |
+
if args.recovery_steps > 0:
|
| 1036 |
+
# Hard routing during recovery — match deploy-time temperature.
|
| 1037 |
+
for m in mlp_modules: m.tau = args.tau_end
|
| 1038 |
+
# Optionally override router noise during recovery (default -1 = leave as-is).
|
| 1039 |
+
prev_noise = [getattr(m, "noise_std", None) for m in mlp_modules]
|
| 1040 |
+
if args.moe_recovery_noise_std >= 0:
|
| 1041 |
+
for m in mlp_modules:
|
| 1042 |
+
if hasattr(m, "noise_std"): m.noise_std = args.moe_recovery_noise_std
|
| 1043 |
+
print(f"\n [recovery A] generic recover_modules_sequentially "
|
| 1044 |
+
f"steps={args.recovery_steps} lr={args.recovery_lr} "
|
| 1045 |
+
f"n_batches={args.recovery_n_batches} tau={args.tau_end} "
|
| 1046 |
+
f"noise={args.moe_recovery_noise_std if args.moe_recovery_noise_std >= 0 else 'unchanged'}")
|
| 1047 |
+
# Restrict trainable set to MoE params (A, W_r) for the duration of
|
| 1048 |
+
# recovery. Snapshot prior requires_grad so we can restore it for main
|
| 1049 |
+
# training (e.g., LoRA adapters that should keep training afterwards).
|
| 1050 |
+
prev_requires_grad = {n: p.requires_grad for n, p in student.named_parameters()}
|
| 1051 |
+
# Restrict to A/W_r — but RESPECT --freeze_A: don't enable A if it was
|
| 1052 |
+
# frozen pre-recovery. Same for W_r (in case caller froze it).
|
| 1053 |
+
for n, p in student.named_parameters():
|
| 1054 |
+
is_moe = n.endswith(".A") or n.endswith(".W_r")
|
| 1055 |
+
p.requires_grad_(is_moe and prev_requires_grad[n])
|
| 1056 |
+
# Pull `recovery_n_batches` calibration batches (input_ids only).
|
| 1057 |
+
rec_seqs = load_seqs(tokenizer, "train", calib_path=args.calib_path)
|
| 1058 |
+
rec_seqs = rec_seqs[:args.recovery_n_batches * BATCH]
|
| 1059 |
+
rec_loader = torch.utils.data.DataLoader(rec_seqs, batch_size=BATCH)
|
| 1060 |
+
rec_input_ids = [batch["input_ids"] for batch in rec_loader][:args.recovery_n_batches]
|
| 1061 |
+
if not rec_input_ids:
|
| 1062 |
+
print(" [recovery A] no calibration data — skipping")
|
| 1063 |
+
else:
|
| 1064 |
+
n_train_per_mlp = sum(
|
| 1065 |
+
p.numel() for n, p in mlp_modules[0].named_parameters(recurse=False)
|
| 1066 |
+
if p.requires_grad and n in ("A", "W_r")
|
| 1067 |
+
)
|
| 1068 |
+
print(f" [recovery A] per-layer MoE trainable params: {n_train_per_mlp}")
|
| 1069 |
+
rec_results = recover_modules_via_generic_pipeline(
|
| 1070 |
+
student=student, teacher=teacher,
|
| 1071 |
+
calibration_input_ids=rec_input_ids,
|
| 1072 |
+
n_layers=N_LAYERS,
|
| 1073 |
+
steps=args.recovery_steps,
|
| 1074 |
+
lr=args.recovery_lr,
|
| 1075 |
+
device=DEVICE,
|
| 1076 |
+
)
|
| 1077 |
+
for r in rec_results:
|
| 1078 |
+
print(f" {r['name']} in_mse={r['input_mse']:.4e} "
|
| 1079 |
+
f"out_pre={r['output_mse_before']:.4e} out_post={r['output_mse_after']:.4e}")
|
| 1080 |
+
# Restore prior requires_grad state, tau, and noise.
|
| 1081 |
+
for n, p in student.named_parameters():
|
| 1082 |
+
p.requires_grad_(prev_requires_grad[n])
|
| 1083 |
+
for m in mlp_modules: m.tau = args.tau_start
|
| 1084 |
+
if args.moe_recovery_noise_std >= 0:
|
| 1085 |
+
for m, n in zip(mlp_modules, prev_noise):
|
| 1086 |
+
if hasattr(m, "noise_std") and n is not None: m.noise_std = n
|
| 1087 |
+
|
| 1088 |
+
# ────────── Activation-MSE recovery (mechanism B: specialized MoE) ──────────
|
| 1089 |
+
# Pre-cache (X, Y) per layer once via teacher forward, then per-layer
|
| 1090 |
+
# time-budgeted optimization of A and W_r only with student-input
|
| 1091 |
+
# propagation between layers.
|
| 1092 |
+
if args.moe_recovery_seconds_per_layer > 0:
|
| 1093 |
+
# Hard routing during recovery — match deploy-time temperature.
|
| 1094 |
+
for m in mlp_modules: m.tau = args.tau_end
|
| 1095 |
+
print(f"\n [recovery B] finetune_moe_per_layer "
|
| 1096 |
+
f"sec/layer={args.moe_recovery_seconds_per_layer} "
|
| 1097 |
+
f"lr={args.moe_recovery_lr} n_calib={args.moe_recovery_n_calib_records} "
|
| 1098 |
+
f"use_student_inputs={args.moe_recovery_use_student_inputs} "
|
| 1099 |
+
f"opt={args.moe_recovery_optimizer} tau={args.tau_end}")
|
| 1100 |
+
moe_rec_seqs = load_seqs(tokenizer, "train", calib_path=args.calib_path)
|
| 1101 |
+
moe_rec_seqs = moe_rec_seqs[:args.moe_recovery_n_calib_records * BATCH]
|
| 1102 |
+
moe_rec_loader = torch.utils.data.DataLoader(moe_rec_seqs, batch_size=BATCH)
|
| 1103 |
+
moe_rec_input_ids = [b["input_ids"] for b in moe_rec_loader][:args.moe_recovery_n_calib_records]
|
| 1104 |
+
if not moe_rec_input_ids:
|
| 1105 |
+
print(" [recovery B] no calibration data — skipping")
|
| 1106 |
+
else:
|
| 1107 |
+
n_train_per_mlp = sum(
|
| 1108 |
+
p.numel() for n, p in mlp_modules[0].named_parameters(recurse=False)
|
| 1109 |
+
if p.requires_grad and n in ("A", "W_r")
|
| 1110 |
+
)
|
| 1111 |
+
print(f" [recovery B] per-layer MoE trainable params: {n_train_per_mlp}")
|
| 1112 |
+
moe_rec_results = finetune_moe_per_layer(
|
| 1113 |
+
student=student, teacher=teacher,
|
| 1114 |
+
calibration_input_ids=moe_rec_input_ids,
|
| 1115 |
+
n_layers=N_LAYERS,
|
| 1116 |
+
seconds_per_layer=args.moe_recovery_seconds_per_layer,
|
| 1117 |
+
lr=args.moe_recovery_lr,
|
| 1118 |
+
optimizer=args.moe_recovery_optimizer,
|
| 1119 |
+
use_student_inputs=args.moe_recovery_use_student_inputs,
|
| 1120 |
+
device=DEVICE,
|
| 1121 |
+
tau_end=args.tau_end,
|
| 1122 |
+
noise_std=(None if args.moe_recovery_noise_std < 0 else args.moe_recovery_noise_std),
|
| 1123 |
+
)
|
| 1124 |
+
# Restore tau to start for main training.
|
| 1125 |
+
for m in mlp_modules: m.tau = args.tau_start
|
| 1126 |
+
|
| 1127 |
+
trainable_params = [p for p in student.parameters() if p.requires_grad]
|
| 1128 |
+
n_train = sum(p.numel() for p in trainable_params)
|
| 1129 |
+
moe_params_max = sum(_d_ffn_at(i) * args.K for i in range(N_LAYERS)) \
|
| 1130 |
+
+ N_LAYERS * D_MODEL * max(K_spec, 0)
|
| 1131 |
+
trainable_base = sum(p.numel() for n, p in student.named_parameters()
|
| 1132 |
+
if p.requires_grad and not (n.endswith(".A") or n.endswith(".W_r")))
|
| 1133 |
+
trainable_moe = sum(p.numel() for n, p in student.named_parameters()
|
| 1134 |
+
if p.requires_grad and (n.endswith(".A") or n.endswith(".W_r")))
|
| 1135 |
+
print(f" Trainable params: {n_train/1e6:.3f}M "
|
| 1136 |
+
f"(MoE: {trainable_moe/1e6:.3f}M / max {moe_params_max/1e6:.3f}M, "
|
| 1137 |
+
f"base trainable: {trainable_base/1e6:.2f}M)")
|
| 1138 |
+
if freeze_base and not args.freeze_A:
|
| 1139 |
+
assert trainable_base == 0, f"freeze_base=True but {trainable_base} base params are trainable"
|
| 1140 |
+
assert trainable_moe <= moe_params_max * 1.01, "Too many MoE params trainable"
|
| 1141 |
+
if args.freeze_A:
|
| 1142 |
+
assert trainable_moe <= N_LAYERS * D_MODEL * max(K_spec, 0) * 1.01, \
|
| 1143 |
+
"freeze_A=True but A appears to be trainable"
|
| 1144 |
+
|
| 1145 |
+
# Eval-only mode: skip training entirely, jump to final eval at tau_end.
|
| 1146 |
+
if args.eval_only:
|
| 1147 |
+
print(f" Eval-only mode — skipping training, evaluating at tau={args.tau_end}")
|
| 1148 |
+
print(f" Eval data: {args.eval_calib_path}")
|
| 1149 |
+
for m in mlp_modules: m.tau = args.tau_end
|
| 1150 |
+
final_ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
|
| 1151 |
+
print(f"\n=== Eval-only PPL (tau={args.tau_end}): {final_ppl:.4f} "
|
| 1152 |
+
f"baseline(bottom60 CE)={BASELINE_PPL:.4f} clean={CLEAN_PPL:.4f} ===")
|
| 1153 |
+
out = {
|
| 1154 |
+
"phase": args.phase, "config": vars(args),
|
| 1155 |
+
"final_ppl": final_ppl,
|
| 1156 |
+
"baseline_ppl": BASELINE_PPL, "clean_ppl": CLEAN_PPL,
|
| 1157 |
+
"ppl_curve": [], "eval_only": True,
|
| 1158 |
+
}
|
| 1159 |
+
os.makedirs("logs", exist_ok=True)
|
| 1160 |
+
out_path = f"logs/rung6_moe_{args.phase}_results.json"
|
| 1161 |
+
with open(out_path, "w") as f:
|
| 1162 |
+
json.dump(out, f, indent=2)
|
| 1163 |
+
print(f"Saved to {out_path}")
|
| 1164 |
+
return
|
| 1165 |
+
|
| 1166 |
+
# Split params into MoE (A + W_r) vs base for per-group LR.
|
| 1167 |
+
# --W_r_lr_mult multiplies the MoE group's LR relative to base_params' args.lr.
|
| 1168 |
+
moe_group_params = [p for n, p in student.named_parameters()
|
| 1169 |
+
if p.requires_grad and (n.endswith(".A") or n.endswith(".W_r"))]
|
| 1170 |
+
base_group_params = [p for n, p in student.named_parameters()
|
| 1171 |
+
if p.requires_grad and not (n.endswith(".A") or n.endswith(".W_r"))]
|
| 1172 |
+
param_groups = [
|
| 1173 |
+
{"params": base_group_params, "lr": args.lr},
|
| 1174 |
+
{"params": moe_group_params, "lr": args.lr * args.W_r_lr_mult},
|
| 1175 |
+
]
|
| 1176 |
+
print(f" LR: base={args.lr:.2e} MoE(A+W_r)={args.lr * args.W_r_lr_mult:.2e} "
|
| 1177 |
+
f"(multiplier={args.W_r_lr_mult})")
|
| 1178 |
+
if args.optimizer == "adamw8bit":
|
| 1179 |
+
if not _HAS_BNB:
|
| 1180 |
+
raise RuntimeError("bitsandbytes not installed — pip install bitsandbytes")
|
| 1181 |
+
# Paged variant handles huge tensors (Gemma-4's embed_tokens_per_layer is 2.35B params,
|
| 1182 |
+
# exceeds non-paged bnb kernel grid limits → "invalid configuration argument").
|
| 1183 |
+
optimizer = bnb.optim.PagedAdamW8bit(param_groups, weight_decay=0.01)
|
| 1184 |
+
print(f" Using bnb.optim.PagedAdamW8bit (~28GB optimizer-state savings, "
|
| 1185 |
+
f"paged to handle Gemma-4's 2.35B embed_tokens_per_layer)")
|
| 1186 |
+
else:
|
| 1187 |
+
optimizer = AdamW(param_groups, weight_decay=0.01)
|
| 1188 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=args.max_steps, eta_min=args.lr * 0.1)
|
| 1189 |
+
|
| 1190 |
+
print(f" Train data: {args.calib_path}")
|
| 1191 |
+
print(f" Eval data: {args.eval_calib_path}")
|
| 1192 |
+
# When train and eval paths differ, use ALL records of train file (no need to withhold 20%
|
| 1193 |
+
# since eval comes from a separate file).
|
| 1194 |
+
train_split = "all" if args.calib_path != args.eval_calib_path else "train"
|
| 1195 |
+
seqs = load_seqs(tokenizer, train_split, calib_path=args.calib_path)
|
| 1196 |
+
print(f" Loaded {len(seqs)} train sequences of {MAX_SEQ_LEN} tokens = {len(seqs)*MAX_SEQ_LEN/1e6:.2f}M tokens"
|
| 1197 |
+
f" (split={train_split})")
|
| 1198 |
+
# Deterministic shuffle: same --shuffle_seed reproduces the same record order.
|
| 1199 |
+
# Use a different seed in continuation runs to expose model to NEW orderings of
|
| 1200 |
+
# the dataset (avoids replaying the same trajectory the prior run already trained on).
|
| 1201 |
+
g = torch.Generator(); g.manual_seed(args.shuffle_seed)
|
| 1202 |
+
loader = torch.utils.data.DataLoader(seqs, BATCH, shuffle=True, generator=g)
|
| 1203 |
+
loader_iter = iter(loader)
|
| 1204 |
+
# Optional skip: discard the first N batches (BATCH samples each) of the shuffled stream before training begins.
|
| 1205 |
+
# Useful when a previous run with the same shuffle_seed already consumed N batches.
|
| 1206 |
+
if args.data_skip > 0:
|
| 1207 |
+
skipped = 0
|
| 1208 |
+
for _ in range(args.data_skip):
|
| 1209 |
+
try:
|
| 1210 |
+
next(loader_iter); skipped += 1
|
| 1211 |
+
except StopIteration:
|
| 1212 |
+
loader_iter = iter(loader)
|
| 1213 |
+
next(loader_iter); skipped += 1
|
| 1214 |
+
print(f" Skipped first {skipped} samples (data_skip={args.data_skip})")
|
| 1215 |
+
|
| 1216 |
+
# Optional knowledge-preservation: load diverse corpus + cache teacher base params.
|
| 1217 |
+
diverse_loader_iter = None
|
| 1218 |
+
diverse_dataset_obj = None
|
| 1219 |
+
if args.diverse_calib_path:
|
| 1220 |
+
print(f" Diverse corpus (KL-to-base): {args.diverse_calib_path}")
|
| 1221 |
+
diverse_seqs = load_seqs(tokenizer, "all", calib_path=args.diverse_calib_path, raw_text=True)
|
| 1222 |
+
print(f" {len(diverse_seqs)} sequences, every {args.diverse_every_n} steps, "
|
| 1223 |
+
f"λ={args.kl_base_lambda}, T={args.kl_base_temp}")
|
| 1224 |
+
diverse_dataset_obj = torch.utils.data.DataLoader(diverse_seqs, BATCH, shuffle=True)
|
| 1225 |
+
diverse_loader_iter = iter(diverse_dataset_obj)
|
| 1226 |
+
|
| 1227 |
+
teacher_param_map = None
|
| 1228 |
+
if args.w_drift_lambda > 0:
|
| 1229 |
+
print(f" W-drift penalty active: λ={args.w_drift_lambda} on trainable base params")
|
| 1230 |
+
teacher_param_map = {n: p.detach() for n, p in teacher.named_parameters()}
|
| 1231 |
+
|
| 1232 |
+
step, accum_loss = 0, 0.0
|
| 1233 |
+
optimizer.zero_grad()
|
| 1234 |
+
t0 = time.time()
|
| 1235 |
+
curve = []
|
| 1236 |
+
|
| 1237 |
+
while step < args.max_steps:
|
| 1238 |
+
tau = get_tau(step, args.max_steps, args.tau_start, args.tau_end, hold_frac=args.tau_hold_frac)
|
| 1239 |
+
for m in mlp_modules: m.tau = tau
|
| 1240 |
+
|
| 1241 |
+
student.train()
|
| 1242 |
+
use_diverse = (diverse_loader_iter is not None and step > 0 and (step % args.diverse_every_n == 0))
|
| 1243 |
+
|
| 1244 |
+
if use_diverse:
|
| 1245 |
+
# Pretraining-distribution preservation batch: KL-to-teacher at temperature T.
|
| 1246 |
+
try: batch = next(diverse_loader_iter)
|
| 1247 |
+
except StopIteration:
|
| 1248 |
+
diverse_loader_iter = iter(diverse_dataset_obj); batch = next(diverse_loader_iter)
|
| 1249 |
+
ids = batch["input_ids"].to(DEVICE)
|
| 1250 |
+
with torch.no_grad():
|
| 1251 |
+
t_logits = teacher(ids)
|
| 1252 |
+
s_logits = student(ids)
|
| 1253 |
+
# High-temperature KL: softens sharp teacher distributions to carry tail signal.
|
| 1254 |
+
main_loss = args.kl_base_lambda * kl_loss(s_logits[:, :-1], t_logits[:, :-1], temp=args.kl_base_temp)
|
| 1255 |
+
else:
|
| 1256 |
+
# Normal CE/KL batch on IT trajectories.
|
| 1257 |
+
try: batch = next(loader_iter)
|
| 1258 |
+
except StopIteration:
|
| 1259 |
+
loader_iter = iter(loader); batch = next(loader_iter)
|
| 1260 |
+
|
| 1261 |
+
ids = batch["input_ids"].to(DEVICE)
|
| 1262 |
+
labels = batch["labels"][:, :-1].to(DEVICE)
|
| 1263 |
+
with torch.no_grad():
|
| 1264 |
+
t_logits = teacher(ids)
|
| 1265 |
+
s_logits = student(ids)
|
| 1266 |
+
|
| 1267 |
+
if args.loss == "kl":
|
| 1268 |
+
# Mask = positions where labels != -100 (i.e., assistant response only).
|
| 1269 |
+
# Same masking we apply to CE — keeps "don't train on prompt tokens" consistent.
|
| 1270 |
+
kl_mask = (labels != -100)
|
| 1271 |
+
main_loss = kl_loss(s_logits[:, :-1], t_logits[:, :-1],
|
| 1272 |
+
temp=args.main_kl_temp, mask=kl_mask)
|
| 1273 |
+
else:
|
| 1274 |
+
main_loss = ce_loss(s_logits[:, :-1], labels)
|
| 1275 |
+
|
| 1276 |
+
# Aux losses apply on every batch — functions of module state, not batch content.
|
| 1277 |
+
aux = sum(m.aux_loss(args.alpha_b, args.alpha_z) for m in mlp_modules)
|
| 1278 |
+
orth = sum(m.orth_loss() for m in mlp_modules) * args.alpha_orth
|
| 1279 |
+
|
| 1280 |
+
# Optional: weight-drift penalty on trainable base params (EWC-lite).
|
| 1281 |
+
drift = torch.tensor(0.0, device=DEVICE)
|
| 1282 |
+
if args.w_drift_lambda > 0:
|
| 1283 |
+
for n, p in student.named_parameters():
|
| 1284 |
+
if not p.requires_grad: continue
|
| 1285 |
+
if n.endswith(".A") or n.endswith(".W_r"): continue
|
| 1286 |
+
t = teacher_param_map.get(n) if teacher_param_map is not None else None
|
| 1287 |
+
if t is not None and t.shape == p.shape:
|
| 1288 |
+
drift = drift + ((p - t) ** 2).sum()
|
| 1289 |
+
drift = drift * args.w_drift_lambda
|
| 1290 |
+
|
| 1291 |
+
loss = (main_loss + aux + orth + drift) / GRAD_ACCUM
|
| 1292 |
+
loss.backward()
|
| 1293 |
+
accum_loss += loss.item()
|
| 1294 |
+
|
| 1295 |
+
if (step + 1) % GRAD_ACCUM == 0:
|
| 1296 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
| 1297 |
+
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
| 1298 |
+
|
| 1299 |
+
if (step + 1) % args.eval_every == 0:
|
| 1300 |
+
# Diagnostic metrics (argmax-based hard assignment regardless of τ)
|
| 1301 |
+
with torch.no_grad():
|
| 1302 |
+
avg_entropy = 0.0; avg_jaccard = 0.0
|
| 1303 |
+
for m in mlp_modules:
|
| 1304 |
+
probs = F.softmax(m.A / max(tau, 1e-3), dim=-1) # [D_FFN, K]
|
| 1305 |
+
ent = -(probs * (probs.clamp_min(1e-8)).log()).sum(-1).mean().item()
|
| 1306 |
+
avg_entropy += ent
|
| 1307 |
+
# Hard assignment: each neuron → argmax expert
|
| 1308 |
+
hard = F.one_hot(probs.argmax(dim=-1), args.K).float().T # [K, D_FFN]
|
| 1309 |
+
inter = hard @ hard.T # [K, K]
|
| 1310 |
+
sz = hard.sum(dim=-1, keepdim=True) # [K, 1]
|
| 1311 |
+
union = sz + sz.T - inter
|
| 1312 |
+
jac_off = (inter / union.clamp_min(1.0))
|
| 1313 |
+
jac_off = jac_off - torch.diag(torch.diag(jac_off)) # zero diagonal
|
| 1314 |
+
avg_jaccard += jac_off.sum().item() / (args.K * (args.K - 1) + 1e-8)
|
| 1315 |
+
avg_entropy /= len(mlp_modules)
|
| 1316 |
+
avg_jaccard /= len(mlp_modules)
|
| 1317 |
+
ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
|
| 1318 |
+
curve.append({
|
| 1319 |
+
"step": step + 1, "ppl": ppl, "tau": tau,
|
| 1320 |
+
"assign_entropy": avg_entropy, "jaccard": avg_jaccard,
|
| 1321 |
+
})
|
| 1322 |
+
print(f" step={step+1:4d} tau={tau:.4f} loss={accum_loss*GRAD_ACCUM:.4f} "
|
| 1323 |
+
f"ppl={ppl:.4f} H(A)={avg_entropy:.3f} Jac={avg_jaccard:.4f} "
|
| 1324 |
+
f"t={time.time()-t0:.0f}s")
|
| 1325 |
+
accum_loss = 0.0
|
| 1326 |
+
|
| 1327 |
+
# Intermediate ckpt save (--save_every) — single rolling file, OVERWRITES previous.
|
| 1328 |
+
# Filename: <save_checkpoint stem>_intermediate.pt — only one extra ckpt on disk
|
| 1329 |
+
# at any time. Read 'step' field of the saved dict to know which step it was at.
|
| 1330 |
+
if args.save_every and args.save_checkpoint and (step + 1) % args.save_every == 0:
|
| 1331 |
+
stem, ext = os.path.splitext(args.save_checkpoint)
|
| 1332 |
+
inter_path = f"{stem}_intermediate{ext}"
|
| 1333 |
+
os.makedirs(os.path.dirname(inter_path) or ".", exist_ok=True)
|
| 1334 |
+
torch.save({
|
| 1335 |
+
'student_state': student.state_dict(),
|
| 1336 |
+
'config': vars(args),
|
| 1337 |
+
'step': step + 1,
|
| 1338 |
+
}, inter_path)
|
| 1339 |
+
print(f" [intermediate] overwrote {inter_path} (step {step+1})")
|
| 1340 |
+
|
| 1341 |
+
step += 1
|
| 1342 |
+
|
| 1343 |
+
# Final eval at tau_end
|
| 1344 |
+
for m in mlp_modules: m.tau = args.tau_end
|
| 1345 |
+
final_ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
|
| 1346 |
+
print(f"\n=== Final PPL (tau={args.tau_end}): {final_ppl:.4f} "
|
| 1347 |
+
f"baseline(bottom60 CE)={BASELINE_PPL:.4f} clean={CLEAN_PPL:.4f} ===")
|
| 1348 |
+
|
| 1349 |
+
out = {
|
| 1350 |
+
"phase": args.phase, "config": vars(args),
|
| 1351 |
+
"final_ppl": final_ppl,
|
| 1352 |
+
"baseline_ppl": BASELINE_PPL, "clean_ppl": CLEAN_PPL,
|
| 1353 |
+
"ppl_curve": curve,
|
| 1354 |
+
}
|
| 1355 |
+
os.makedirs("logs", exist_ok=True)
|
| 1356 |
+
out_path = f"logs/rung6_moe_{args.phase}_results.json"
|
| 1357 |
+
with open(out_path, "w") as f:
|
| 1358 |
+
json.dump(out, f, indent=2)
|
| 1359 |
+
print(f"Saved to {out_path}")
|
| 1360 |
+
|
| 1361 |
+
if args.save_checkpoint:
|
| 1362 |
+
os.makedirs(os.path.dirname(args.save_checkpoint) or ".", exist_ok=True)
|
| 1363 |
+
torch.save({
|
| 1364 |
+
'student_state': student.state_dict(),
|
| 1365 |
+
'config': vars(args),
|
| 1366 |
+
'final_ppl': final_ppl,
|
| 1367 |
+
}, args.save_checkpoint)
|
| 1368 |
+
print(f"Saved checkpoint to {args.save_checkpoint}")
|
| 1369 |
+
|
| 1370 |
+
|
| 1371 |
+
# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|