Update modeling_steerling.py
modeling_steerling.py (ADDED, +1574 -0)

@@ -0,0 +1,1574 @@
from __future__ import annotations
# Auto-generated by scripts/build_hf_files_v3.py — do not edit manually.

import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import os
from functools import partial
from typing import TYPE_CHECKING, Any
from dataclasses import dataclass
from torch import Tensor
import math
import warnings


# ======================================================================
# steerling/models/layers/primitives.py
# ======================================================================

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    """

    def __init__(self, config, size: int | None=None):
        super().__init__()
        self.eps = getattr(config, 'norm_eps', 1e-05)
        norm_size = size if size is not None else config.n_embd
        self.weight = nn.Parameter(torch.ones(norm_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        og = x.dtype
        x = x.float()
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return (self.weight * x).to(og)
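
# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): RMSNorm
# rescales each vector as x_hat = weight * x / sqrt(mean(x^2) + eps), with
# the statistic computed in float32 and the result cast back to the input
# dtype. The config object below is a hypothetical stand-in; only `n_embd`
# and `norm_eps` are read.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(n_embd=64, norm_eps=1e-5)
#   norm = RMSNorm(cfg)
#   x = torch.randn(2, 10, 64, dtype=torch.bfloat16)
#   y = norm(x)   # same shape and dtype as x
# ----------------------------------------------------------------------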

class BufferCache:
    """Simple cache for storing tensors (used by RotaryEmbedding)."""

    def __init__(self):
        self._cache: dict[str, torch.Tensor] = {}

    def get(self, key: str) -> torch.Tensor | None:
        return self._cache.get(key)

    def __setitem__(self, key: str, value: torch.Tensor):
        self._cache[key] = value

    def __getitem__(self, key: str) -> torch.Tensor:
        return self._cache[key]

class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Applies rotary embeddings to queries and keys for position information.

    Args:
        dim: Dimension of the rotary embeddings (typically head_dim)
        max_seq_len: Maximum sequence length to cache
        base: Base for inverse frequency computation (theta)
        rope_full_precision: Whether to compute RoPE in full precision
    """

    def __init__(self, dim: int, max_seq_len: int=2048, base: float=10000.0, rope_full_precision: bool=True):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rope_theta = base
        self.rope_full_precision = rope_full_precision
        self.__cache = BufferCache()
        self.get_rotary_embedding(max_seq_len, torch.device('cpu'))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Get or compute rotary embeddings for given sequence length."""
        pos_sin = self.__cache.get('rope_pos_sin')
        pos_cos = self.__cache.get('rope_pos_cos')
        if pos_sin is not None and pos_cos is not None and (pos_sin.shape[-2] >= seq_len) and (pos_cos.shape[-2] >= seq_len):
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                self.__cache['rope_pos_sin'] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                self.__cache['rope_pos_cos'] = pos_cos
            return (pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :])
        with torch.autocast(device.type, enabled=False):
            inv_freq = 1.0 / self.rope_theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim)
            seq = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = torch.outer(seq, inv_freq)
            positions = torch.cat((freqs, freqs), dim=-1)
            pos_sin = positions.sin()[None, None, :, :]
            pos_cos = positions.cos()[None, None, :, :]
        self.__cache['rope_pos_sin'] = pos_sin
        self.__cache['rope_pos_cos'] = pos_cos
        return (pos_sin, pos_cos)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input."""
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Apply rotary position embeddings to input tensor."""
        return (t * pos_cos + self.rotate_half(t) * pos_sin).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys."""
        if self.rope_full_precision:
            q_, k_ = (q.float(), k.float())
        else:
            q_, k_ = (q, k)
        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = (q_.shape[-2], k_.shape[-2])
            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
            pos_sin = pos_sin.type_as(q_)
            pos_cos = pos_cos.type_as(q_)
            q_ = self.apply_rotary_pos_emb(pos_sin[:, :, key_len - query_len:key_len, :], pos_cos[:, :, key_len - query_len:key_len, :], q_)
            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return (q_.type_as(q), k_.type_as(k))
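
# ----------------------------------------------------------------------
# Sketch (illustrative only): this is the "rotate-half" RoPE formulation,
# q' = q * cos(theta_i * m) + rotate_half(q) * sin(theta_i * m), with
# theta_i = base^(-2i/dim) (the inv_freq above) and m the absolute
# position; queries are rotated with the last T_q positions so they line
# up with the key positions. Shapes below are made up for illustration.
#
#   rope = RotaryEmbedding(dim=64, max_seq_len=256)
#   q = torch.randn(1, 8, 16, 64)   # (B, n_head, T_q, head_dim)
#   k = torch.randn(1, 8, 16, 64)   # (B, n_kv_head, T_kv, head_dim)
#   q_rot, k_rot = rope(q, k)       # same shapes, positions encoded
# ----------------------------------------------------------------------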

class MLP(nn.Module):
    """
    Multi-Layer Perceptron with SwiGLU or standard activation.

    Args:
        config: Model config with n_embd, mlp_ratio, use_bias, mlp_type, activation
    """

    def __init__(self, config):
        super().__init__()
        if hasattr(config, 'intermediate_size') and config.intermediate_size is not None:
            intermediate_size = config.intermediate_size
        else:
            intermediate_size = getattr(config, 'mlp_ratio', 4) * config.n_embd
        use_bias = config.use_bias
        mlp_type = config.mlp_type
        if mlp_type == 'swiglu':
            self.c_fc = nn.Linear(config.n_embd, 2 * intermediate_size, bias=use_bias)
            self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
            self.activation = None
        else:
            self.c_fc = nn.Linear(config.n_embd, intermediate_size, bias=use_bias)
            self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
            act_map = {'gelu': nn.GELU(approximate='tanh'), 'relu': nn.ReLU(), 'silu': nn.SiLU()}
            self.activation = act_map[config.activation]
        self.c_proj.SCALE_INIT = 1
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mlp_type = getattr(self.config, 'mlp_type', 'swiglu')
        if mlp_type == 'swiglu':
            gate_up = self.c_fc(x)
            up, gate = gate_up.chunk(2, dim=-1)
            intermediate = F.silu(gate) * up
        else:
            intermediate = self.c_fc(x)
            intermediate = self.activation(intermediate)
        return self.c_proj(intermediate)
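
# ----------------------------------------------------------------------
# Sketch (illustrative only): in the SwiGLU path the single fused c_fc
# projection is split into (up, gate) halves and combined as
# silu(gate) * up, i.e. SwiGLU(x) = SiLU(x @ W_g) * (x @ W_u) with one
# matmul instead of two. Hypothetical config for illustration:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(n_embd=64, mlp_ratio=4, use_bias=False,
#                         mlp_type='swiglu', activation='silu',
#                         intermediate_size=None)
#   mlp = MLP(cfg)
#   y = mlp(torch.randn(2, 10, 64))   # -> (2, 10, 64)
# ----------------------------------------------------------------------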

# ======================================================================
# steerling/models/layers/causal_diffusion_layers.py
# ======================================================================

logger = logging.getLogger(__name__)
try:
    from torch.nn.attention.flex_attention import BlockMask, _dense_to_ordered, flex_attention
    _FLEX_ATTN_AVAILABLE = True
except ImportError:
    _FLEX_ATTN_AVAILABLE = False
    BlockMask: Any = None
    flex_attention: Any = None
    _dense_to_ordered: Any = None
if os.environ.get('STEERLING_USE_FLEX_ATTN', '0') != '1':
    _FLEX_ATTN_AVAILABLE = False
if TYPE_CHECKING:
    from torch.nn.attention.flex_attention import BlockMask as BlockMaskType
    from steerling.configs.causal_diffusion import CausalDiffusionConfig
if torch.cuda.is_available() and _FLEX_ATTN_AVAILABLE:
    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
else:
    compiled_flex_attention = flex_attention

def block_causal_mask_mod(b: Any, h: Any, q_idx: torch.Tensor, kv_idx: torch.Tensor, *, block_size: int) -> torch.Tensor:
    """Block-causal mask: causal across blocks, bidirectional within blocks."""
    return q_idx // block_size >= kv_idx // block_size
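
# ----------------------------------------------------------------------
# Illustration (not part of the module's API): with block_size=2 and
# T=4, positions {0, 1} form block 0 and {2, 3} form block 1, so each
# token attends to its whole block plus all earlier blocks (rows =
# queries, columns = keys, 1 = attend):
#
#   q_idx = torch.arange(4).unsqueeze(1)
#   kv_idx = torch.arange(4).unsqueeze(0)
#   mask = (q_idx // 2 >= kv_idx // 2).int()
#   # tensor([[1, 1, 0, 0],
#   #         [1, 1, 0, 0],
#   #         [1, 1, 1, 1],
#   #         [1, 1, 1, 1]])
# ----------------------------------------------------------------------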

def fast_create_block_causal_mask(attn_block_size: int, seq_length: int, mask_block_size: int, device: torch.device) -> BlockMaskType:
    """
    Fast block-causal mask creation for flex_attention.

    Analytically computes the sparse block structure instead of evaluating
    the mask function at every position.
    """
    if not _FLEX_ATTN_AVAILABLE or _dense_to_ordered is None or BlockMask is None:
        raise RuntimeError('flex_attention not available')
    num_mask_blocks = -(-seq_length // mask_block_size)
    attn_blocks_per_mask_block, rem = divmod(mask_block_size, attn_block_size)
    if rem != 0:
        raise ValueError(f'mask_block_size ({mask_block_size}) must be divisible by attn_block_size ({attn_block_size})')
    num_attn_blocks = num_mask_blocks * attn_blocks_per_mask_block
    lowres_attn_mask = torch.tril(torch.ones(num_attn_blocks, num_attn_blocks, dtype=torch.bool, device=device))
    block_attn_count = lowres_attn_mask.reshape(num_mask_blocks, attn_blocks_per_mask_block, num_mask_blocks, attn_blocks_per_mask_block).permute(0, 2, 1, 3).sum(dim=[-2, -1])
    max_count = attn_blocks_per_mask_block * attn_blocks_per_mask_block
    full_block_mask = block_attn_count == max_count
    if seq_length % mask_block_size > 0:
        full_block_mask[-1, :] = False
    normal_block_mask = (block_attn_count > 0) & ~full_block_mask
    kv_num_blocks, kv_indices = _dense_to_ordered(normal_block_mask)
    full_kv_num_blocks, full_kv_indices = _dense_to_ordered(full_block_mask)
    q_num_blocks, q_indices = _dense_to_ordered(normal_block_mask.transpose(-2, -1))
    full_q_num_blocks, full_q_indices = _dense_to_ordered(full_block_mask.transpose(-2, -1))
    return BlockMask(seq_lengths=(seq_length, seq_length), kv_num_blocks=kv_num_blocks[None, None, ...], kv_indices=kv_indices[None, None, ...], full_kv_num_blocks=full_kv_num_blocks[None, None, ...], full_kv_indices=full_kv_indices[None, None, ...], q_num_blocks=q_num_blocks[None, None, ...], q_indices=q_indices[None, None, ...], full_q_num_blocks=full_q_num_blocks[None, None, ...], full_q_indices=full_q_indices[None, None, ...], mask_mod=partial(block_causal_mask_mod, block_size=attn_block_size), BLOCK_SIZE=(mask_block_size, mask_block_size))

def sdpa_with_block_causal_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff_block_size: int, mask_cache: dict[str, torch.Tensor], enable_gqa: bool=False) -> torch.Tensor:
    """Fallback using SDPA with dense mask when flex_attention unavailable."""
    B, H, T, D = q.shape
    device = q.device
    dtype = q.dtype
    cache_key = f'sdpa_{T}_{device}_{dtype}'
    if cache_key not in mask_cache:
        q_idx = torch.arange(T, device=device).unsqueeze(1)
        kv_idx = torch.arange(T, device=device).unsqueeze(0)
        bool_mask = q_idx // diff_block_size >= kv_idx // diff_block_size
        attn_mask = torch.zeros(T, T, device=device, dtype=dtype)
        attn_mask.masked_fill_(~bool_mask, float('-inf'))
        mask_cache[cache_key] = attn_mask
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask_cache[cache_key], dropout_p=0.0, is_causal=False, enable_gqa=enable_gqa)
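
# ----------------------------------------------------------------------
# Rough cost check (illustrative only): the SDPA fallback materializes a
# dense (T, T) additive mask per (T, dtype, device) cache entry, while
# the BlockMask above stores only per-block index lists. Assuming bf16:
#
#   T = 8192
#   dense_bytes = T * T * 2   # 134_217_728 bytes, i.e. 128 MiB
# ----------------------------------------------------------------------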

class BlockCausalAttention(nn.Module):
    """Block-causal self-attention with FlexAttention and optional GQA."""
    FLEX_MASK_BLOCK_SIZE = 128

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        if not hasattr(config, 'diff_block_size'):
            raise ValueError("BlockCausalAttention requires 'diff_block_size' in config.")
        assert config.n_embd % config.n_head == 0
        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        n_kv = getattr(config, 'n_kv_heads', None)
        self.n_kv_heads = self.n_head if n_kv is None else int(n_kv)
        if self.n_kv_heads <= 0:
            raise ValueError(f'n_kv_heads must be >= 1 (got {self.n_kv_heads})')
        if self.n_head % self.n_kv_heads != 0:
            raise ValueError(f'n_head ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})')
        self.kv_repeat = self.n_head // self.n_kv_heads
        use_bias = getattr(config, 'use_bias', False)
        kv_out = self.n_kv_heads * self.head_dim
        attn_out = self.n_embd + 2 * kv_out
        self.c_attn = nn.Linear(config.n_embd, attn_out, bias=use_bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=use_bias)
        self.c_proj.SCALE_INIT = 1
        if getattr(config, 'use_qk_norm', False):
            if getattr(config, 'use_rms_norm', True):
                self.q_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
                self.k_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
            else:
                self.q_norm = nn.LayerNorm(self.head_dim)
                self.k_norm = nn.LayerNorm(self.head_dim)
        else:
            self.q_norm = None
            self.k_norm = None
        if getattr(config, 'use_rope', True):
            self.rope: RotaryEmbedding | None = RotaryEmbedding(dim=self.head_dim, max_seq_len=config.block_size, base=getattr(config, 'rope_base', 500000.0), rope_full_precision=getattr(config, 'rope_full_precision', True))
        else:
            self.rope = None
        self._mask_cache: dict = {}
        self._sdpa_mask_cache: dict[str, torch.Tensor] = {}
        self._logged_attention_mode = False

    def _get_block_mask(self, T: int, device: torch.device):
        cache_key = f'flex_{T}_{device}'
        if cache_key not in self._mask_cache:
            diff_block_size = self.config.diff_block_size
            mask_block_size = self.FLEX_MASK_BLOCK_SIZE
            if mask_block_size % diff_block_size != 0:
                mask_block_size = diff_block_size * (mask_block_size // diff_block_size)
                if mask_block_size == 0:
                    mask_block_size = diff_block_size
            self._mask_cache[cache_key] = fast_create_block_causal_mask(attn_block_size=diff_block_size, seq_length=T, mask_block_size=mask_block_size, device=device)
        return self._mask_cache[cache_key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()
        device = x.device
        use_flex = _FLEX_ATTN_AVAILABLE and x.is_cuda and (flex_attention is not None)
        if not self._logged_attention_mode:
            self._logged_attention_mode = True
            mode = 'flex_attention' if use_flex else 'SDPA fallback'
            logger.debug(f'[CausalDiffusion] Using {mode} with GQA (n_head={self.n_head}, n_kv_heads={self.n_kv_heads})')
        qkv = self.c_attn(x)
        clip_qkv = getattr(self.config, 'clip_qkv', None)
        if clip_qkv is not None:
            qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
        kv_dim = self.n_kv_heads * self.head_dim
        q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2)
        q = q.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q)
            k = self.k_norm(k)
        if self.rope is not None:
            q, k = self.rope(q, k)
        if use_flex:
            block_mask = self._get_block_mask(T, device)
            assert flex_attention is not None and compiled_flex_attention is not None
            if q.is_cuda:
                y = compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
            else:
                y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
        else:
            y = sdpa_with_block_causal_mask(q, k, v, diff_block_size=self.config.diff_block_size, mask_cache=self._sdpa_mask_cache, enable_gqa=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        y = self.c_proj(y)
        return y
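
# ----------------------------------------------------------------------
# Note (illustrative arithmetic, not from the source): with GQA the fused
# c_attn projection emits n_embd query channels but only
# n_kv_heads * head_dim channels each for K and V; enable_gqa=True lets
# flex_attention / SDPA broadcast KV heads across query-head groups, so
# no repeat_interleave copy of K/V is needed. E.g. n_head=8,
# n_kv_heads=2, head_dim=64 gives a c_attn output of
# 512 + 2 * 128 = 768 channels instead of 3 * 512 = 1536.
# ----------------------------------------------------------------------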

class CausalDiffusionBlock(nn.Module):
    """Transformer block for CausalDiffusionLM (block-causal attention + MLP)."""

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        use_rms_norm = getattr(config, 'use_rms_norm', True)
        if use_rms_norm:
            self.ln_1: nn.Module = RMSNorm(config)
            self.ln_2: nn.Module = RMSNorm(config)
        else:
            self.ln_1 = nn.LayerNorm(config.n_embd)
            self.ln_2 = nn.LayerNorm(config.n_embd)
        self.norm_order = getattr(config, 'norm_order', 'post')
        self.attn = BlockCausalAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm_order == 'pre':
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
        else:
            x = x + self.ln_1(self.attn(x))
            x = x + self.ln_2(self.mlp(x))
        return x
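
# ----------------------------------------------------------------------
# Note on norm_order (reading of the code above, not from the source
# docs): 'pre' normalizes the sublayer input, x + attn(norm(x)), the
# usual pre-LN residual; the default 'post' here normalizes the sublayer
# *output* before the residual add, x + norm(attn(x)), which differs
# from classic post-LN (norm(x + attn(x))).
# ----------------------------------------------------------------------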

# ======================================================================
# steerling/models/causal_diffusion.py
# ======================================================================

class CausalDiffusionLM(nn.Module):
    """
    CausalDiffusionLM transformer backbone with block-causal attention.

    Pure compute graph — no training code, no loss logic.

    Args:
        config: CausalDiffusionConfig with model hyperparameters
        vocab_size: Vocabulary size (including special tokens)
    """

    def __init__(self, config: CausalDiffusionConfig, vocab_size: int) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, config.n_embd)
        self.blocks = nn.ModuleList([CausalDiffusionBlock(config) for _ in range(config.n_layers)])
        if config.use_rms_norm:
            self.ln_f: nn.Module = RMSNorm(config)
        else:
            self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        if config.weight_sharing:
            self.tok_emb.weight = self.lm_head.weight

    def forward(self, input_ids: torch.Tensor, *, input_embeds: torch.Tensor | None=None, return_hidden: bool=False) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token indices [B, T] (may contain mask tokens)
            input_embeds: Pre-computed embeddings [B, T, D]. If provided, input_ids is ignored.
            return_hidden: If True, return hidden states before lm_head.

        Returns:
            logits [B, T, vocab_size] or hidden_states [B, T, n_embd]
        """
        if input_embeds is not None:
            x = input_embeds
        elif input_ids is not None:
            x = self.tok_emb(input_ids)
        else:
            raise ValueError('Either input_ids or input_embeds must be provided')
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        if return_hidden:
            return x
        return self.lm_head(x)
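
    # ------------------------------------------------------------------
    # Usage sketch (illustrative only; the config below is a hypothetical
    # stand-in whose field names mirror the attribute reads in this file,
    # and the SDPA fallback needs a PyTorch recent enough to accept the
    # enable_gqa kwarg):
    #
    #   from types import SimpleNamespace
    #   cfg = SimpleNamespace(n_embd=128, n_head=4, n_layers=2,
    #                         n_kv_heads=None, block_size=256,
    #                         diff_block_size=16, mlp_type='swiglu',
    #                         mlp_ratio=4, use_bias=False, use_rms_norm=True,
    #                         norm_eps=1e-5, weight_sharing=True,
    #                         intermediate_size=None)
    #   model = CausalDiffusionLM(cfg, vocab_size=1000)
    #   logits = model(torch.randint(0, 1000, (2, 64)))  # (2, 64, 1000)
    # ------------------------------------------------------------------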

    def get_num_params(self, non_embedding: bool=True) -> int:
        """Return number of parameters."""
        n_params = sum((p.numel() for p in self.parameters()))
        if non_embedding:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def _restore_weight_tying(self) -> None:
        """Re-establish weight tying after to_empty() or device transfer."""
        if self.config.weight_sharing:
            self.tok_emb.weight = self.lm_head.weight

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize model weights (used for fresh models, not loaded checkpoints)."""
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                std *= (2 * self.config.n_layers) ** (-0.5)
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            torch.nn.init.ones_(module.weight)

# ======================================================================
# steerling/models/interpretable/outputs.py
# ======================================================================

@dataclass
class InterpretableOutput:
    """
    Full output from InterpretableCausalDiffusionLM, containing all
    decomposition components for attribution and analysis.
    """
    hidden: Tensor
    known_features: Tensor
    known_logits: Tensor | None
    known_gt_features: Tensor | None
    known_predicted: Tensor
    known_weights: Tensor | None
    known_topk_indices: Tensor | None
    known_topk_logits: Tensor | None
    unk: Tensor
    unk_hat: Tensor | None
    unk_for_lm: Tensor
    unknown_logits: Tensor | None
    unknown_weights: Tensor | None
    unknown_topk_indices: Tensor | None
    unknown_topk_logits: Tensor | None
    composed: Tensor
    epsilon: Tensor | None
    epsilon_true: Tensor | None

# ======================================================================
# steerling/models/interpretable/concept_head.py
# ======================================================================

logger = logging.getLogger(__name__)
LARGE_CONCEPT_THRESHOLD = 50000

@dataclass
class ConceptHeadOutput:
    """Output from ConceptHead forward pass.

    Attributes:
        features: Final concept features after teacher forcing/intervention (B, T, D)
        gt_features: Ground truth pooled features. None for unknown heads. (B, T, D) or None
        logits: Full concept logits (B, T, C). Only set if return_logits=True. Usually None.
        predicted: Predicted features before teacher forcing mixing (B, T, D)
        weights: Full concept weights (B, T, C). Only set if return_logits=True. Usually None.
        topk_indices: Top-k concept indices (B, T, k). Set when using streaming top-k.
        topk_logits: Logits for top-k concepts (B, T, k). Set when using streaming top-k.
        hidden: Hidden states passed to this head (B, T, D). Stored for attribution.
    """
    features: Tensor
    gt_features: Tensor | None
    logits: Tensor | None
    predicted: Tensor
    weights: Tensor | None = None
    topk_indices: Tensor | None = None
    topk_logits: Tensor | None = None
    hidden: Tensor | None = None

class ConceptHead(nn.Module):
    """
    Concept decomposition head supporting both known and unknown concepts.
    Memory-efficient implementation that avoids (B, T, C) allocations by default.

    Modes:
        - Known (is_unknown=False): Supports GT, teacher forcing, top-k, interventions
        - Unknown (is_unknown=True): No GT, no teacher forcing

    Architectures:
        - use_attention=False: Linear predictor (n_embd -> n_concepts)
        - use_attention=True: Query projection + sigmoid attention over embeddings

    Factorization (for large unknown heads):
        - factorize=False: Dense embeddings (C, D) and predictor (D, C)
        - factorize=True: Factorized embeddings (C, r) @ (r, D) where r << D
          Reduces memory by ~10-20x for large C

    Memory Safety:
        - Unknown heads with n_concepts > 50k cannot use dense operations
        - Interventions are only supported for known heads
        - return_logits=True is forbidden for large unknown heads
        - All tensor indexing uses F.embedding for DTensor safety

    Args:
        n_concepts: Number of concepts (C)
        concept_dim: Dimension of concept embeddings (should equal n_embd)
        n_embd: Model hidden dimension
        is_unknown: If True, skip GT pooling and teacher forcing
        use_attention: If True, use attention; else use linear predictor
        topk: Top-k sparsity for concept weights. None = no sparsity.
        block_size: Block size for memory-efficient operations
        pad_multiple: Pad n_concepts to a multiple of this for efficiency
        store_unknown_weights: If True and use_attention & is_unknown, store logits/weights
        apply_topk_to_unknown: If True, also apply top-k to unknown concepts
        topk_on_logits: If True, apply top-k on logits (then sigmoid). If False, on weights.
        teacher_force_alpha: If None, hard TF. If in [0,1], soft mixing.
        factorize: If True, use low-rank factorized embeddings
        factorize_rank: Rank for factorization (r). Lower = less memory, less expressivity.
    """

    class ConceptPooling(nn.Module):
        """Memory-efficient sum pooling using scatter-add."""

        def __init__(self, concept_dim: int):
            super().__init__()
            self.concept_dim = concept_dim

        def forward(self, concept_ids: Tensor, concept_mask: Tensor, concept_embeddings: nn.Embedding) -> Tensor:
            """
            Pool concept embeddings based on ground truth IDs.
            Uses scatter-add to avoid (B, T, K, D) allocation when K is sparse.

            Args:
                concept_ids: (B, T, K) concept indices, -1 for invalid
                concept_mask: (B, T, K) boolean mask for valid concepts
                concept_embeddings: Embedding layer to look up

            Returns:
                Pooled features (B, T, D)
            """
            B, T, K = concept_ids.shape
            D = concept_embeddings.embedding_dim
            device = concept_ids.device
            valid_mask = concept_mask & (concept_ids != -1)
            pooled = torch.zeros(B, T, D, device=device, dtype=concept_embeddings.weight.dtype)
            if not valid_mask.any():
                return pooled
            b_idx, t_idx, k_idx = torch.where(valid_mask)
            c_ids = concept_ids[b_idx, t_idx, k_idx].long()
            emb = concept_embeddings(c_ids)
            flat_idx = b_idx * T + t_idx
            flat_idx = flat_idx.unsqueeze(-1).expand(-1, D)
            pooled_flat = pooled.view(B * T, D)
            pooled_flat.scatter_add_(0, flat_idx, emb)
            return pooled.view(B, T, D)
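
    # ------------------------------------------------------------------
    # Sketch (illustrative only): scatter_add_ accumulates each valid
    # concept embedding into its flattened (batch, token) row, so memory
    # stays O(#valid ids * D) rather than O(B * T * K * D).
    #
    #   emb = nn.Embedding(10, 4)
    #   pool = ConceptHead.ConceptPooling(concept_dim=4)
    #   ids = torch.tensor([[[1, 3, -1]]])               # (B=1, T=1, K=3)
    #   mask = torch.tensor([[[True, True, False]]])
    #   out = pool(ids, mask, emb)   # out[0, 0] == emb.weight[1] + emb.weight[3]
    # ------------------------------------------------------------------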

    def __init__(self, n_concepts: int, concept_dim: int, n_embd: int, is_unknown: bool=False, use_attention: bool=False, topk: int | None=16, topk_features: int | None=None, block_size: int=8192, *, pad_multiple: int=16, store_unknown_weights: bool=False, apply_topk_to_unknown: bool=False, topk_on_logits: bool=False, factorize: bool=False, factorize_rank: int=256):
        super().__init__()
        self.n_concepts = n_concepts
        self.concept_dim = concept_dim
        self.n_embd = n_embd
        self.is_unknown = is_unknown
        self.use_attention = use_attention
        self.topk = topk
        self.topk_features = topk_features if topk_features is not None else topk
        self.block_size = block_size
        self.pad_multiple = pad_multiple
        self.store_unknown_weights = store_unknown_weights
        self.apply_topk_to_unknown = apply_topk_to_unknown
        self.topk_on_logits = topk_on_logits
        self.factorize = factorize
        self.factorize_rank = factorize_rank
        self._is_large = n_concepts > LARGE_CONCEPT_THRESHOLD
        self.n_concepts_padded = (n_concepts + pad_multiple - 1) // pad_multiple * pad_multiple
        if factorize:
            self.embedding_coef = nn.Embedding(self.n_concepts_padded, factorize_rank)
            self.embedding_basis = nn.Linear(factorize_rank, concept_dim, bias=False)
            self.concept_embedding = None
            if not use_attention:
                self.predictor_down = nn.Linear(n_embd, factorize_rank, bias=False)
                self.predictor_up = nn.Linear(factorize_rank, self.n_concepts_padded, bias=False)
                self.concept_predictor = None
            else:
                self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
                self.predictor_down = None
                self.predictor_up = None
                self.concept_predictor = None
            dense_params = n_concepts * concept_dim * 2
            factorized_params = n_concepts * factorize_rank + factorize_rank * concept_dim + (n_embd * factorize_rank + factorize_rank * n_concepts if not use_attention else 0)
            logger.info(f'[ConceptHead] Factorized mode: {n_concepts} concepts, rank={factorize_rank}')
            logger.info(f'[ConceptHead] Memory: {dense_params * 2 / 1000000000.0:.2f} GB (dense) -> {factorized_params * 2 / 1000000000.0:.2f} GB (factorized) = {(1 - factorized_params / dense_params) * 100:.1f}% reduction')
        else:
            self.concept_embedding = nn.Embedding(self.n_concepts_padded, concept_dim)
            self.embedding_coef = None
            self.embedding_basis = None
            if use_attention:
                self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
                self.concept_predictor = None
            else:
                self.concept_predictor = nn.Linear(n_embd, self.n_concepts_padded, bias=False)
            self.predictor_down = None
            self.predictor_up = None
        self.concept_pooling = self.ConceptPooling(concept_dim)
        if self.topk_features != self.topk:
            logger.info(f"[ConceptHead] {('Unknown' if is_unknown else 'Known')} head: topk={self.topk} (loss), topk_features={self.topk_features} (features)")
        if is_unknown and apply_topk_to_unknown:
            logger.info(f'[ConceptHead] Unknown head: apply_topk_to_unknown=True, topk={self.topk}')
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with small values."""
        if self.factorize:
            nn.init.normal_(self.embedding_coef.weight, mean=0.0, std=0.02)
            nn.init.normal_(self.embedding_basis.weight, mean=0.0, std=0.02)
            if self.predictor_down is not None:
                nn.init.normal_(self.predictor_down.weight, mean=0.0, std=0.02)
            if self.predictor_up is not None:
                nn.init.normal_(self.predictor_up.weight, mean=0.0, std=0.02)
        else:
            if self.concept_embedding is not None:
                nn.init.normal_(self.concept_embedding.weight, mean=0.0, std=0.02)
            if self.concept_predictor is not None:
                nn.init.normal_(self.concept_predictor.weight, mean=0.0, std=0.02)
        if hasattr(self, 'concept_query_projection') and self.concept_query_projection is not None:
            nn.init.normal_(self.concept_query_projection.weight, mean=0.0, std=0.02)

    def _check_dense_allowed(self, operation: str) -> None:
        """Raise error if dense operations are requested for large unknown heads."""
        if self.is_unknown and self._is_large:
            raise ValueError(f'{operation} requested for unknown head with {self.n_concepts} concepts. This would allocate multi-GB tensors. Use streaming mode instead. (Threshold: {LARGE_CONCEPT_THRESHOLD})')

    @staticmethod
    def _safe_index(weight: Tensor, indices: Tensor) -> Tensor:
        """
        DTensor-safe indexing using F.embedding.

        Replaces weight[indices] which crashes under FSDP2/DTensor.

        Args:
            weight: (N, D) weight matrix
            indices: (...) indices to select

        Returns:
            (..., D) selected embeddings
        """
        original_shape = indices.shape
        flat_indices = indices.reshape(-1)
        flat_result = F.embedding(flat_indices, weight)
        return flat_result.reshape(*original_shape, -1)
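
    # ------------------------------------------------------------------
    # Sketch (illustrative only): per the docstring above, F.embedding is
    # the sharding-friendly lookup under FSDP2/DTensor, while advanced
    # indexing on a DTensor-backed parameter is not supported. On plain
    # tensors the two are interchangeable:
    #
    #   w = torch.randn(100, 8)
    #   idx = torch.randint(0, 100, (2, 3, 5))
    #   assert torch.equal(ConceptHead._safe_index(w, idx), w[idx])
    # ------------------------------------------------------------------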

    def _get_embedding_weight(self) -> Tensor:
        """
        Get full embedding matrix.

        For dense: returns concept_embedding.weight
        For factorized: computes coef @ basis (materializes full matrix)

        Returns:
            (C, D) embedding matrix
        """
        if self.concept_embedding is not None:
            return self.concept_embedding.weight
        else:
            return self.embedding_basis(self.embedding_coef.weight)

    def _get_embedding(self, indices: Tensor) -> Tensor:
        """
        Get embeddings for specific indices (DTensor-safe).

        For dense: uses F.embedding
        For factorized: looks up coef, then applies basis

        Args:
            indices: (...) concept indices

        Returns:
            (..., D) embeddings
        """
        if self.concept_embedding is not None:
            return self.concept_embedding(indices)
        else:
            coef = self.embedding_coef(indices)
            return self.embedding_basis(coef)
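
    # ------------------------------------------------------------------
    # Arithmetic sketch (illustrative numbers, not from the source): the
    # low-rank factorization stores E = coef @ basis with coef (C, r) and
    # basis (r, D), so the embedding table costs C*r + r*D parameters
    # instead of C*D. E.g. C=1_000_000 concepts, D=4096, r=256:
    #
    #   dense      = 1_000_000 * 4096              # ~4.10e9 params
    #   factorized = 1_000_000 * 256 + 256 * 4096  # ~2.56e8 params (~16x less)
    # ------------------------------------------------------------------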

    def _get_predictor_weight(self) -> Tensor | None:
        """
        Get full predictor weight matrix (for linear path only).

        Returns:
            (C, D) predictor weight, or None if using attention
        """
        if self.concept_predictor is not None:
            return self.concept_predictor.weight
        elif self.predictor_down is not None and self.predictor_up is not None:
            return self.predictor_up.weight @ self.predictor_down.weight
        else:
            return None

    @staticmethod
    def _merge_topk(topv: Tensor, topi: Tensor, v_blk: Tensor, i_blk: Tensor, k: int) -> tuple[Tensor, Tensor]:
        """Efficient merge of two top-k sets. Memory: O(BT × 2k)."""
        cand_v = torch.cat([topv, v_blk], dim=1)
        cand_i = torch.cat([topi, i_blk], dim=1)
        new_v, sel = torch.topk(cand_v, k, dim=1)
        new_i = torch.gather(cand_i, 1, sel)
        return (new_v, new_i)
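
    # ------------------------------------------------------------------
    # Sketch (illustrative only): _merge_topk maintains a running top-k
    # while streaming over concept blocks: concatenate the current best k
    # with the block's best k (2k candidates), re-select the top k, and
    # gather the matching indices.
    #
    #   topv = torch.tensor([[0.9, 0.5]]); topi = torch.tensor([[3, 7]])
    #   v_blk = torch.tensor([[0.8, 0.1]]); i_blk = torch.tensor([[12, 15]])
    #   v, i = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k=2)
    #   # v == [[0.9, 0.8]], i == [[3, 12]]
    # ------------------------------------------------------------------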

    @staticmethod
    def linear_block_features(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
        """
        Memory-efficient linear prediction without materializing (B, T, C).

        Args:
            hidden: (B, T, D)
            predictor_weight: (C, D)
            embeddings: (C, D)
            block_size: Concepts per block

        Returns:
            Features (B, T, D)
        """
        B, T, D = hidden.shape
        C = predictor_weight.size(0)
        output = torch.zeros(B, T, D, dtype=hidden.dtype, device=hidden.device)
        flat_h = hidden.reshape(-1, D)
        W_t = predictor_weight.t().contiguous()
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            logits_block = (flat_h @ W_t[:, start:end]).to(torch.float32)
            logits_block = logits_block.clamp(-15, 15)
            weights_block = torch.sigmoid(logits_block)
            E_block = embeddings[start:end].to(weights_block.dtype)
            output.add_((weights_block @ E_block).reshape(B, T, D))
        return output.to(hidden.dtype)

    @staticmethod
    def attention_block_features(query: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
        """Memory-efficient attention features without materializing (B, T, C)."""
        B, T, D = query.shape
        C = embeddings.shape[0]
        scale = 1.0 / math.sqrt(D)
        flat_q = query.reshape(-1, D)
        emb_T = embeddings.t().contiguous()
        output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
            scores = scores.clamp(-15, 15)
            weights = torch.sigmoid(scores)
            output.add_(weights @ embeddings[start:end].to(weights.dtype))
        return output.reshape(B, T, D).to(query.dtype)
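
    # ------------------------------------------------------------------
    # Arithmetic sketch (illustrative numbers, not from the source):
    # chunking over concepts bounds the live logits at (B*T, block_size)
    # instead of (B*T, C). E.g. B*T=16_384 tokens and C=1_000_000
    # concepts in fp32:
    #
    #   full    = 16_384 * 1_000_000 * 4   # ~65.5 GB
    #   chunked = 16_384 * 4096 * 4        # ~268 MB per block
    # ------------------------------------------------------------------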
| 759 |
+
|
| 760 |
+
@staticmethod
|
| 761 |
+
def linear_features_topk_streaming(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
|
| 762 |
+
"""
|
| 763 |
+
Memory-efficient linear prediction with streaming top-k.
|
| 764 |
+
|
| 765 |
+
Uses merge-k-with-k to keep memory O(BT × k), not O(BT × block_size).
|
| 766 |
+
|
| 767 |
+
Args:
|
| 768 |
+
hidden: (B, T, D)
|
| 769 |
+
predictor_weight: (C, D)
|
| 770 |
+
embeddings: (C, D)
|
| 771 |
+
k: Number of top concepts
|
| 772 |
+
block_size: Concepts per block
|
| 773 |
+
topk_on_logits: If True, select top-k by logits; else by sigmoid
|
| 774 |
+
|
| 775 |
+
Returns:
|
| 776 |
+
features: (B, T, D) weighted concept features
|
| 777 |
+
topk_indices: (B, T, k) indices of top-k concepts
|
| 778 |
+
topk_logits: (B, T, k) logits for top-k concepts
|
| 779 |
+
"""
|
| 780 |
+
B, T, D = hidden.shape
|
| 781 |
+
C = predictor_weight.size(0)
|
| 782 |
+
BT = B * T
|
| 783 |
+
device = hidden.device
|
| 784 |
+
k = min(k, C)
|
| 785 |
+
flat_h = hidden.reshape(BT, D)
|
| 786 |
+
W_t = predictor_weight.t().contiguous()
|
| 787 |
+
topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
|
| 788 |
+
topi = torch.zeros((BT, k), device=device, dtype=torch.long)
|
| 789 |
+
for start in range(0, C, block_size):
|
| 790 |
+
end = min(start + block_size, C)
|
| 791 |
+
logits_blk = (flat_h @ W_t[:, start:end]).to(torch.float32).clamp_(-15, 15)
|
| 792 |
+
vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
|
| 793 |
+
blk_k = min(k, end - start)
|
| 794 |
+
v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
|
| 795 |
+
i_blk = idx_blk + start
|
| 796 |
+
if blk_k < k:
|
| 797 |
+
pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
|
| 798 |
+
pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
|
| 799 |
+
v_blk = torch.cat([v_blk, pad_v], dim=1)
|
| 800 |
+
i_blk = torch.cat([i_blk, pad_i], dim=1)
|
| 801 |
+
topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
|
| 802 |
+
W_sel = ConceptHead._safe_index(predictor_weight, topi)
|
| 803 |
+
logits_sel = torch.einsum('bd,bkd->bk', flat_h.to(torch.float32), W_sel.to(torch.float32))
|
| 804 |
+
logits_sel = logits_sel.clamp(-15, 15)
|
| 805 |
+
del W_sel
|
| 806 |
+
weights_sel = torch.sigmoid(logits_sel)
|
| 807 |
+
E_sel = ConceptHead._safe_index(embeddings, topi)
|
| 808 |
+
features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
|
| 809 |
+
return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
| 810 |
+
|
| 811 |
+
@staticmethod
|
| 812 |
+
def attention_features_topk_streaming(query: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
|
| 813 |
+
"""Memory-efficient attention with streaming top-k."""
|
| 814 |
+
B, T, D = query.shape
|
| 815 |
+
C = embeddings.shape[0]
|
| 816 |
+
BT = B * T
|
| 817 |
+
device = query.device
|
| 818 |
+
scale = 1.0 / math.sqrt(D)
|
| 819 |
+
k = min(k, C)
|
| 820 |
+
flat_q = query.reshape(BT, D)
|
| 821 |
+
emb_T = embeddings.t().contiguous()
|
| 822 |
+
topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
|
| 823 |
+
topi = torch.zeros((BT, k), device=device, dtype=torch.long)
|
| 824 |
+
for start in range(0, C, block_size):
|
| 825 |
+
end = min(start + block_size, C)
|
| 826 |
+
logits_blk = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
|
| 827 |
+
logits_blk = logits_blk.clamp(-15, 15)
|
| 828 |
+
vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
|
| 829 |
+
blk_k = min(k, end - start)
|
| 830 |
+
v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
|
| 831 |
+
i_blk = idx_blk + start
|
| 832 |
+
if blk_k < k:
|
| 833 |
+
pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
|
| 834 |
+
pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
|
| 835 |
+
v_blk = torch.cat([v_blk, pad_v], dim=1)
|
| 836 |
+
i_blk = torch.cat([i_blk, pad_i], dim=1)
|
| 837 |
+
topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
|
| 838 |
+
E_sel = ConceptHead._safe_index(embeddings, topi)
|
| 839 |
+
logits_sel = torch.einsum('bd,bkd->bk', flat_q.to(torch.float32), E_sel.to(torch.float32)) * scale
|
| 840 |
+
logits_sel = logits_sel.clamp(-15, 15)
|
| 841 |
+
weights_sel = torch.sigmoid(logits_sel)
|
| 842 |
+
features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
|
| 843 |
+
return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
| 844 |
+
|
| 845 |
+
def attention_block_features_factorized(self, query: Tensor, block_size: int=4096) -> Tensor:
|
| 846 |
+
"""
|
| 847 |
+
Memory-efficient factorized attention over ALL concepts.
|
| 848 |
+
|
| 849 |
+
Uses factorized scoring and feature computation:
|
| 850 |
+
- Scoring: (query @ basis.T) @ coef.T instead of query @ E.T
|
| 851 |
+
- Features: (weights @ coef) @ basis instead of weights @ E
|
| 852 |
+
|
| 853 |
+
FLOPs: O(BT * r * (D + C)) instead of O(BT * D * C)
|
| 854 |
+
|
| 855 |
+
Args:
|
| 856 |
+
query: (B, T, D) query vectors from concept_query_projection
|
| 857 |
+
block_size: Concepts per block for chunked processing
|
| 858 |
+
|
| 859 |
+
Returns:
|
| 860 |
+
(B, T, D) weighted concept features
|
| 861 |
+
"""
|
| 862 |
+
assert self.factorize, 'Only valid for factorized head'
|
| 863 |
+
B, T, D = query.shape
|
| 864 |
+
BT = B * T
|
| 865 |
+
C = self.n_concepts
|
| 866 |
+
_ = self.factorize_rank
|
| 867 |
+
device = query.device
|
| 868 |
+
scale = 1.0 / math.sqrt(D)
|
| 869 |
+
flat_q = query.reshape(BT, D)
|
| 870 |
+
coef = self.embedding_coef.weight[:C]
|
| 871 |
+
basis_weight = self.embedding_basis.weight
|
| 872 |
+
q_compressed = flat_q @ basis_weight
|
| 873 |
+
output = torch.zeros(BT, D, dtype=query.dtype, device=device)
|
| 874 |
+
_ = (C + block_size - 1) // block_size
|
| 875 |
+
for _block_idx, start in enumerate(range(0, C, block_size)):
|
| 876 |
+
end = min(start + block_size, C)
|
| 877 |
+
coef_chunk = coef[start:end]
|
| 878 |
+
scores_chunk = (q_compressed @ coef_chunk.T).float() * scale
|
| 879 |
+
scores_chunk = scores_chunk.clamp(-15, 15)
|
| 880 |
+
weights_chunk = torch.sigmoid(scores_chunk)
|
| 881 |
+
weighted_coef = weights_chunk @ coef_chunk.float()
|
| 882 |
+
features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
|
| 883 |
+
output.add_(features_chunk)
|
| 884 |
+
return output.reshape(B, T, D).to(query.dtype)
|
| 885 |
+
|
| 886 |
+
    def attention_features_topk_factorized(self, query: Tensor, k: int, block_size: int = 4096) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient factorized attention with streaming top-k.

        Pass 1: Find top-k concepts using factorized scoring
        Pass 2: Compute features using only top-k embeddings

        Args:
            query: (B, T, D) query vectors
            k: Number of top concepts per token
            block_size: Concepts per block

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) top-k concept indices
            topk_logits: (B, T, k) logits for top-k concepts
        """
        assert self.factorize, 'Only valid for factorized head'
        B, T, D = query.shape
        BT = B * T
        C = self.n_concepts
        device = query.device
        scale = 1.0 / math.sqrt(D)
        k = min(k, C)
        flat_q = query.reshape(BT, D)
        coef = self.embedding_coef.weight[:C]
        basis_weight = self.embedding_basis.weight
        q_compressed = flat_q @ basis_weight
        # Running top-k state, kept in fp32 to match the fp32 block scores.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=torch.float32)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            coef_chunk = coef[start:end]
            scores_chunk = (q_compressed.float() @ coef_chunk.T.float()) * scale
            scores_chunk = scores_chunk.clamp(-15, 15)
            blk_k = min(k, end - start)
            v_chunk, idx_chunk = torch.topk(scores_chunk, blk_k, dim=1)
            i_chunk = idx_chunk + start
            if blk_k < k:
                # Pad short blocks with -inf so the merge always sees k candidates.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_chunk = torch.cat([v_chunk, pad_v], dim=1)
                i_chunk = torch.cat([i_chunk, pad_i], dim=1)
            topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
        # Pass 2: recompute exact logits and features only for the selected concepts.
        coef_sel = self.embedding_coef(topi)
        logits_sel = torch.einsum('br,bkr->bk', q_compressed.float(), coef_sel.float()) * scale
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
        features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))

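    # Note on `_merge_topk` (defined elsewhere in this class; the sketch below is
    # an assumption about its contract, not a copy of it): it merges the running
    # (values, indices) with one block's candidates and keeps the best k, e.g.
    #   both_v = torch.cat([topv, v_chunk], dim=1)
    #   both_i = torch.cat([topi, i_chunk], dim=1)
    #   topv, sel = torch.topk(both_v, k, dim=1)
    #   topi = torch.gather(both_i, 1, sel)
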
    def linear_block_features_factorized(self, hidden: Tensor, block_size: int = 4096) -> Tensor:
        """
        Memory-efficient factorized linear prediction over ALL concepts.

        Uses factorized predictor: logits = hidden @ down @ up.T
        Uses factorized embeddings: features = weights @ coef @ basis

        Args:
            hidden: (B, T, D) hidden states
            block_size: Concepts per block

        Returns:
            (B, T, D) weighted concept features
        """
        assert self.factorize, 'Only valid for factorized head'
        assert self.predictor_down is not None, 'Linear path requires predictor'
        B, T, D = hidden.shape
        BT = B * T
        C = self.n_concepts
        device = hidden.device
        flat_h = hidden.reshape(BT, D)
        coef = self.embedding_coef.weight[:C]
        basis_weight = self.embedding_basis.weight
        down_weight = self.predictor_down.weight
        up_weight = self.predictor_up.weight[:C]
        # Compress hidden states once: (BT, D) @ (D, r) -> (BT, r)
        h_compressed = flat_h @ down_weight.T
        # fp32 accumulator; cast back to the input dtype at return.
        output = torch.zeros(BT, D, dtype=torch.float32, device=device)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            up_chunk = up_weight[start:end]
            coef_chunk = coef[start:end]
            logits_chunk = h_compressed.float() @ up_chunk.T.float()
            logits_chunk = logits_chunk.clamp(-15, 15)
            weights_chunk = torch.sigmoid(logits_chunk)
            weighted_coef = weights_chunk @ coef_chunk.float()
            features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
            output.add_(features_chunk)
        return output.reshape(B, T, D).to(hidden.dtype)

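    # Shape walk for the linear factorized path (r = factorize_rank, Cb = block width):
    #   h_compressed = flat_h @ down_weight.T          # (BT, D) @ (D, r) -> (BT, r)
    #   logits_chunk = h_compressed @ up_chunk.T       # (BT, r) @ (r, Cb) -> (BT, Cb)
    #   features     = sigmoid(logits) @ coef_chunk    # (BT, Cb) @ (Cb, r) -> (BT, r)
    #                  @ basis_weight.T                # (BT, r) @ (r, D) -> (BT, D)
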
    def linear_features_topk_factorized(self, hidden: Tensor, k: int, block_size: int = 4096) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient factorized linear with streaming top-k.

        Args:
            hidden: (B, T, D) hidden states
            k: Number of top concepts per token
            block_size: Concepts per block

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) top-k concept indices
            topk_logits: (B, T, k) logits for top-k concepts
        """
        assert self.factorize, 'Only valid for factorized head'
        assert self.predictor_down is not None, 'Linear path requires predictor'
        B, T, D = hidden.shape
        BT = B * T
        C = self.n_concepts
        device = hidden.device
        k = min(k, C)
        flat_h = hidden.reshape(BT, D)
        down_weight = self.predictor_down.weight
        up_weight = self.predictor_up.weight[:C]
        basis_weight = self.embedding_basis.weight
        h_compressed = flat_h @ down_weight.T
        # Running top-k state, kept in fp32 to match the fp32 block logits.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=torch.float32)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            up_chunk = up_weight[start:end]
            logits_chunk = h_compressed.float() @ up_chunk.T.float()
            logits_chunk = logits_chunk.clamp(-15, 15)
            blk_k = min(k, end - start)
            v_chunk, idx_chunk = torch.topk(logits_chunk, blk_k, dim=1)
            i_chunk = idx_chunk + start
            if blk_k < k:
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_chunk = torch.cat([v_chunk, pad_v], dim=1)
                i_chunk = torch.cat([i_chunk, pad_i], dim=1)
            topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
        # Pass 2: gather the selected rows and recompute exact logits and features.
        coef_sel = self.embedding_coef(topi)
        up_sel = self._safe_index(self.predictor_up.weight[:C], topi)
        logits_sel = torch.einsum('br,bkr->bk', h_compressed.float(), up_sel.float())
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
        features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))

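    # Usage sketch (hypothetical shapes; `head` is a factorized linear ConceptHead):
    #   feats, idx, lg = head.linear_features_topk_factorized(hidden, k=64)
    #   feats: (B, T, D) reconstruction; idx, lg: (B, T, 64) winning concepts and logits.
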
    def compute_logits_for_indices(self, hidden: Tensor, indices: Tensor) -> Tensor:
        """
        Compute logits for specific concept indices only (sparse).

        Supports both dense and factorized heads.

        IMPORTANT: This function materializes (M, K, D) where M is the number of
        tokens in hidden. Only call this with small M (e.g., masked tokens only).

        Args:
            hidden: (M, D) or (B, T, D) hidden states
            indices: (M, K) or (B, T, K) concept indices

        Returns:
            logits: Same shape as indices
        """
        if hidden.dim() == 2:
            M, D = hidden.shape
            K = indices.size(-1)
            flat_h = hidden
            flat_idx = indices
            output_shape = indices.shape
        else:
            B, T, D = hidden.shape
            K = indices.size(-1)
            M = B * T
            flat_h = hidden.reshape(M, D)
            flat_idx = indices.reshape(M, K)
            output_shape = indices.shape
        # Rough size of the gathered (M, K, D) tensor, assuming ~2 bytes/element.
        estimated_bytes = M * K * D * 2
        if estimated_bytes > 1e9:
            warnings.warn(f'compute_logits_for_indices will allocate ~{estimated_bytes / 1e9:.1f} GB. Consider reducing M={M} (use masked tokens only) or K={K}.')
        n_valid = self.n_concepts
        indices_safe = flat_idx.clamp(0, n_valid - 1)
        if self.use_attention:
            query = self.concept_query_projection(flat_h.unsqueeze(0)).squeeze(0)
            scale = 1.0 / math.sqrt(self.concept_dim)
            E_sel = self._get_embedding(indices_safe)
            logits = torch.einsum('md,mkd->mk', query.float(), E_sel.float()) * scale
        else:
            if self.factorize:
                W = self._get_predictor_weight()[:n_valid]
            else:
                W = self.concept_predictor.weight[:n_valid]
            W_sel = self._safe_index(W, indices_safe)
            logits = torch.einsum('md,mkd->mk', flat_h.float(), W_sel.float())
        return logits.clamp(-15, 15).reshape(output_shape)

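    # Worked example for the warning above (illustrative numbers): M = 8192 tokens,
    # K = 64 indices, D = 768 dims gives 8192 * 64 * 768 = 402,653,184 elements,
    # i.e. ~0.8 GB at 2 bytes/element, just under the 1e9-byte warning threshold.
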
    def get_concept_weights(self, hidden: Tensor, concept_ids: Tensor) -> Tensor:
        """
        Get sigmoid weights for specific concepts (for attribution).

        Args:
            hidden: (B, T, D) or (M, D) hidden states
            concept_ids: (B, T, K) or (M, K) or (K,) concept indices

        Returns:
            weights: Same shape as concept_ids, values in [0, 1]
        """
        if concept_ids.dim() == 1:
            if hidden.dim() == 2:
                M = hidden.size(0)
                concept_ids = concept_ids.unsqueeze(0).expand(M, -1)
            else:
                B, T, _ = hidden.shape
                concept_ids = concept_ids.unsqueeze(0).unsqueeze(0).expand(B, T, -1)
        logits = self.compute_logits_for_indices(hidden, concept_ids)
        return torch.sigmoid(logits)

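    # Attribution sketch (concept ids are hypothetical):
    #   ids = torch.tensor([12, 907])                  # (K,) concepts to probe
    #   w = head.get_concept_weights(hidden, ids)      # (B, T, 2), values in [0, 1]
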
    @staticmethod
    def blocked_logits(query: Tensor, embeddings: Tensor, block_size: int = 8192, out_device: torch.device | None = None, out_dtype: torch.dtype = torch.float32) -> Tensor:
        """
        Compute concept logits in column blocks for memory efficiency.

        logits = query @ embeddings.T / sqrt(D)
        """
        B, T, D = query.shape
        C = embeddings.size(0)
        scale = 1.0 / math.sqrt(D)
        dev = query.device if out_device is None else out_device
        logits = torch.empty(B, T, C, device=dev, dtype=out_dtype)
        q = query.reshape(-1, D).to(torch.float32)
        Et = embeddings.t().contiguous().to(torch.float32)
        for s in range(0, C, block_size):
            e = min(s + block_size, C)
            scores = (q @ Et[:, s:e]) * scale
            scores = scores.clamp(-15, 15)
            logits[:, :, s:e] = scores.reshape(B, T, e - s).to(out_dtype)
        return logits

    @staticmethod
    def blocked_mix(weights: Tensor, embeddings: Tensor, block_size: int = 8192) -> Tensor:
        """
        Compute weighted sum of embeddings in column blocks.

        output = weights @ embeddings
        """
        B, T, C = weights.shape
        D = embeddings.size(1)
        # fp32 accumulator; cast back to the weights dtype at return.
        out = torch.zeros(B, T, D, device=weights.device, dtype=torch.float32)
        for s in range(0, C, block_size):
            e = min(s + block_size, C)
            w_blk = weights[:, :, s:e].to(torch.float32)
            V_blk = embeddings[s:e].to(w_blk.dtype)
            out.add_(w_blk @ V_blk)
        return out.to(weights.dtype)

    @staticmethod
    def sigmoid_block_attention(query: Tensor, embeddings: Tensor, block_size: int = 8192, return_logits: bool = False) -> Tensor | tuple[Tensor, Tensor]:
        """Memory-efficient sigmoid attention using block processing."""
        B, T, D = query.shape
        C = embeddings.shape[0]
        scale = 1.0 / math.sqrt(D)
        flat_q = query.reshape(-1, D)
        emb_T = embeddings.t().contiguous()
        # fp32 accumulator; cast back to the query dtype at the end.
        output = torch.zeros(B * T, D, dtype=torch.float32, device=query.device)
        logits: Tensor | None = None
        if return_logits:
            logits = torch.empty(B, T, C, dtype=torch.float32, device=query.device)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
            scores = scores.clamp(-15, 15)
            if logits is not None:
                logits[:, :, start:end] = scores.reshape(B, T, end - start)
            weights = torch.sigmoid(scores)
            output.add_(weights @ embeddings[start:end].to(weights.dtype))
        output = output.reshape(B, T, D).to(query.dtype)
        if return_logits:
            assert logits is not None
            return (output, logits)
        return output

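    # Usage sketch: dense sigmoid attention with logits kept for inspection.
    #   feats, lg = ConceptHead.sigmoid_block_attention(query, E, return_logits=True)
    #   feats: (B, T, D); lg: (B, T, C) fp32 logits, so memory scales with C.
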
    def _apply_sparse_interventions(self, features: Tensor, hidden: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
        """
        Apply sparse interventions matching original dense behavior.

        Original dense behavior:
            weights = sigmoid(logits)   # (B, T, C)
            weights[..., c] = new_val   # Override
            features = weights @ embeddings

        Sparse equivalent:
            features += (new_val - current_weight) * embedding[c]
        """
        B, T, D = features.shape
        valid = intervene_ids != -1
        if not valid.any():
            return features
        ids_safe = intervene_ids.clamp(0, self.n_concepts - 1)
        current_logits = self.compute_logits_for_indices(hidden, ids_safe)
        current_weights = torch.sigmoid(current_logits)
        emb = self._get_embedding(ids_safe)
        # Zero out the delta wherever the id slot is -1 (no-op intervention).
        delta = (intervene_vals - current_weights) * valid.float()
        correction = (delta.unsqueeze(-1) * emb).sum(dim=2)
        return features + correction

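    # Why the sparse form is exact: overriding a single weight w_c with v turns
    #   sum_j w_j * E_j   into   sum_j w_j * E_j + (v - w_c) * E_c,
    # so adding (new_val - current_weight) * embedding[c] to the already-mixed
    # features reproduces the dense override without forming (B, T, C) weights.
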
    def _apply_dense_interventions(self, concept_weight: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
        """Apply interventions by overriding concept weights (dense path)."""
        n_valid = min(self.n_concepts, concept_weight.size(-1))
        valid_edit = intervene_ids != -1
        ids = intervene_ids.clamp(0, n_valid - 1).long()
        vals = intervene_vals.to(concept_weight.dtype)
        updates = torch.zeros_like(concept_weight)
        updates.scatter_add_(2, ids, torch.where(valid_edit, vals, torch.zeros_like(vals)))
        set_mask = torch.zeros_like(concept_weight, dtype=torch.bool)
        set_mask.scatter_(2, ids, valid_edit)
        return torch.where(set_mask, updates, concept_weight)

    def topk_with_cutoff(self, tensor: Tensor, dim: int = -1) -> Tensor:
        """
        Apply top-k sparsity, zeroing out all but top-k values.

        Args:
            tensor: Input tensor, typically (B, T, C)
            dim: Dimension to apply top-k (default: last)

        Returns:
            Sparse tensor with only top-k values preserved
        """
        assert dim == -1 or dim == tensor.dim() - 1
        if self.topk is None:
            return tensor
        padded = tensor.size(dim)
        n_valid = min(self.n_concepts, padded)
        if n_valid <= 0:
            return torch.zeros_like(tensor)
        x = tensor.narrow(dim, 0, n_valid)
        kk = min(self.topk, n_valid)
        topv, topi = torch.topk(x, kk, dim=dim)
        out = torch.zeros_like(x)
        out.scatter_(dim, topi, topv)
        if n_valid < padded:
            pad_shape = list(out.shape)
            pad_shape[dim] = padded - n_valid
            pad_zeros = out.new_zeros(pad_shape)
            out = torch.cat([out, pad_zeros], dim=dim)
        return out

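    # Example (illustrative): with topk=2, a row [0.9, 0.1, 0.7, 0.3] becomes
    # [0.9, 0.0, 0.7, 0.0]; values outside the top-k are zeroed, shape unchanged.
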
    def _compute_weights(self, concept_logits: Tensor, E: Tensor) -> Tensor:
        """Compute concept weights from logits, with optional top-k sparsity."""
        apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
        if apply_topk and self.topk_on_logits:
            # Top-k on logits: pruned logits pass through sigmoid (zeroed slots become 0.5).
            logits_for_weights = self.topk_with_cutoff(concept_logits)
            return torch.sigmoid(logits_for_weights).to(E.dtype)
        weights = torch.sigmoid(concept_logits).to(E.dtype)
        if apply_topk and (not self.topk_on_logits):
            weights = self.topk_with_cutoff(weights)
        return weights

    @torch.compiler.disable
    def forward(self, hidden: Tensor, intervene_ids: Tensor | None = None, intervene_vals: Tensor | None = None, return_logits: bool = False, store_hidden: bool = False) -> ConceptHeadOutput:
        """
        Forward pass for concept decomposition (inference only, no teacher forcing).

        Args:
            hidden: Transformer hidden states (B, T, n_embd)
            intervene_ids: Concept IDs to intervene on (B, T, K_int), -1 = skip
            intervene_vals: Intervention strength values (B, T, K_int)
            return_logits: If True, compute full (B, T, C) logits. Forbidden for large heads.
            store_hidden: If True, store hidden in output for later attribution.

        Returns:
            ConceptHeadOutput with features, predicted, topk_indices, topk_logits
        """
        B, T, _ = hidden.shape
        has_interventions = intervene_ids is not None and intervene_vals is not None
        if return_logits:
            self._check_dense_allowed('return_logits=True')
        n_valid = self.n_concepts
        concept_logits: Tensor | None = None
        concept_weight: Tensor | None = None
        predicted: Tensor
        topk_indices: Tensor | None = None
        topk_logits: Tensor | None = None
        apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
        k_features = self.topk_features if self.topk_features is not None else self.topk
        use_dense_intervention = has_interventions and (not self._is_large)
        if use_dense_intervention:
            # Dense path: materialize full (B, T, C) weights so ids can be overridden in place.
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
            assert intervene_ids is not None and intervene_vals is not None
            concept_weight = self._apply_dense_interventions(concept_weight, intervene_ids, intervene_vals)
            predicted = self.blocked_mix(concept_weight, E, block_size=self.block_size)
        elif self.factorize:
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                if apply_topk:
                    predicted, topk_indices, topk_logits = self.attention_features_topk_factorized(query, k=k_features, block_size=self.block_size)
                else:
                    predicted = self.attention_block_features_factorized(query, block_size=self.block_size)
            elif apply_topk:
                predicted, topk_indices, topk_logits = self.linear_features_topk_factorized(hidden, k=k_features, block_size=self.block_size)
            else:
                predicted = self.linear_block_features_factorized(hidden, block_size=self.block_size)
        elif apply_topk:
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted, topk_indices, topk_logits = self.attention_features_topk_streaming(query, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted, topk_indices, topk_logits = self.linear_features_topk_streaming(hidden, W, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
        else:
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted = self.attention_block_features(query, E, block_size=self.block_size)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted = self.linear_block_features(hidden, W, E, block_size=self.block_size)
        # If more features than loss slots were gathered, keep only the top `topk` for the loss.
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            _, rerank_idx = torch.topk(topk_logits, self.topk, dim=-1)
            topk_indices = torch.gather(topk_indices, -1, rerank_idx)
            topk_logits = torch.gather(topk_logits, -1, rerank_idx)
        if return_logits and (not use_dense_intervention):
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
        if not hasattr(self, '_logged_forward_path'):
            self._logged_forward_path = True
            path = (
                'dense_intervention' if use_dense_intervention
                else 'factorized_topk' if self.factorize and apply_topk
                else 'factorized_all' if self.factorize
                else 'streaming_topk' if apply_topk
                else 'dense_all'
            )
            logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: path={path}, topk={self.topk}, topk_features={self.topk_features}, n_concepts={self.n_concepts}, factorize={self.factorize}, apply_topk={apply_topk}")
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            if not hasattr(self, '_logged_topk_slice'):
                self._logged_topk_slice = True
                logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: Sliced topk: {self.topk_features} features -> {self.topk} for loss")
        if has_interventions and (not use_dense_intervention):
            assert intervene_ids is not None and intervene_vals is not None
            predicted = self._apply_sparse_interventions(predicted, hidden, intervene_ids, intervene_vals)
        return ConceptHeadOutput(
            features=predicted,
            gt_features=None,
            logits=concept_logits,
            predicted=predicted,
            weights=concept_weight,
            topk_indices=topk_indices,
            topk_logits=topk_logits,
            hidden=hidden.detach() if store_hidden else None,
        )

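    # Usage sketch (hypothetical id and shapes): pin concept 12 to weight 1.0.
    #   ids = torch.full((B, T, 1), 12, dtype=torch.long, device=hidden.device)
    #   vals = torch.ones(B, T, 1, device=hidden.device)
    #   out = head(hidden, intervene_ids=ids, intervene_vals=vals)
    #   out.features   # (B, T, D) reconstruction with the override applied
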
# ======================================================================
# steerling/models/interpretable/interpretable_causal_diffusion.py
# ======================================================================

logger = logging.getLogger(__name__)

class InterpretableCausalDiffusionLM(nn.Module):
    """
    Interpretable CausalDiffusionLM with concept decomposition heads.

    Wraps a CausalDiffusionLM and adds:
    - Known concept head: predicts known concepts from hidden states
    - Unknown concept head: captures residual features (optional)
    - Steering via concept interventions

    Args:
        config: CausalDiffusionConfig (model architecture)
        concept_config: ConceptConfig (concept decomposition)
        vocab_size: Vocabulary size
    """

    def __init__(self, config: CausalDiffusionConfig, concept_config: ConceptConfig, vocab_size: int):
        super().__init__()
        self.config = config
        self.concept_config = concept_config
        self.vocab_size = vocab_size
        self.transformer = CausalDiffusionLM(config, vocab_size)
        self.known_head = ConceptHead(
            n_concepts=concept_config.n_concepts,
            concept_dim=concept_config.concept_dim,
            n_embd=config.n_embd,
            is_unknown=False,
            use_attention=concept_config.use_attention_known,
            topk=concept_config.topk_known,
            topk_features=concept_config.topk_known_features,
            block_size=concept_config.block_size,
            pad_multiple=concept_config.pad_multiple,
            store_unknown_weights=False,
            apply_topk_to_unknown=False,
            topk_on_logits=concept_config.topk_on_logits,
        )
        if concept_config.use_unknown:
            if concept_config.n_unknown_concepts is None:
                raise ValueError('n_unknown_concepts must be set when use_unknown=True')
            self.unknown_head: ConceptHead | None = ConceptHead(
                n_concepts=concept_config.n_unknown_concepts,
                concept_dim=concept_config.concept_dim,
                n_embd=config.n_embd,
                is_unknown=True,
                use_attention=concept_config.use_attention_unknown,
                topk=concept_config.unknown_topk,
                block_size=concept_config.block_size,
                pad_multiple=concept_config.pad_multiple,
                store_unknown_weights=False,
                apply_topk_to_unknown=concept_config.apply_topk_to_unknown,
                topk_on_logits=concept_config.topk_on_logits,
                factorize=concept_config.factorize_unknown,
                factorize_rank=concept_config.factorize_rank,
            )
        else:
            self.unknown_head = None

    def forward(
        self,
        input_ids: Tensor,
        *,
        input_embeds: Tensor | None = None,
        intervene_known_ids: Tensor | None = None,
        intervene_known_vals: Tensor | None = None,
        intervene_unknown_ids: Tensor | None = None,
        intervene_unknown_vals: Tensor | None = None,
        minimal_output: bool = False,
        position_injection: Tensor | None = None,
        steering_inject_layer: int | None = None,
        steering_inject_alpha: float = 1.0,
        unknown_topk: int = 64,
    ) -> tuple[Tensor, InterpretableOutput]:
        """
        Forward pass with concept decomposition.

        Args:
            input_ids: Token IDs (B, T). May contain mask tokens.
            input_embeds: Pre-computed embeddings (B, T, D). Overrides input_ids.
            intervene_known_ids: Known concept IDs to intervene (B, T, K_int)
            intervene_known_vals: Intervention values for known (B, T, K_int)
            intervene_unknown_ids: Unknown concept IDs to intervene (B, T, K_int)
            intervene_unknown_vals: Intervention values for unknown (B, T, K_int)
            minimal_output: If True, skip some expensive computations
            position_injection: Per-position steering injection (B, T, D)
            steering_inject_layer: Inject at layers >= this
            steering_inject_alpha: Injection strength
            unknown_topk: Top-k for unknown head attribution

        Returns:
            logits: LM logits (B, T, V)
            outputs: InterpretableOutput with all decomposition components
        """
        need_dense_logits = not minimal_output
        if position_injection is not None and steering_inject_layer is not None:
            hidden = self._forward_with_injection(input_ids, input_embeds, position_injection, steering_inject_layer, steering_inject_alpha)
        else:
            hidden = self.transformer(input_ids, input_embeds=input_embeds, return_hidden=True)
        known_out: ConceptHeadOutput = self.known_head(hidden, intervene_ids=intervene_known_ids, intervene_vals=intervene_known_vals, return_logits=need_dense_logits)
        known_features = known_out.features.to(hidden.dtype)
        # Residual the known head cannot explain; the unknown head models this part.
        unk = hidden - known_features.detach()
        unk_for_lm: Tensor = unk
        unknown_out: ConceptHeadOutput | None = None
        unk_hat: Tensor | None = None
        if self.unknown_head is not None:
            unknown_out = self.unknown_head(hidden.detach(), intervene_ids=intervene_unknown_ids, intervene_vals=intervene_unknown_vals, return_logits=not minimal_output and (not self.unknown_head._is_large))
            assert unknown_out is not None
            unk_hat = unknown_out.features.to(hidden.dtype)
            unk_for_lm = unk_hat.detach()
        epsilon_true = None
        if self.unknown_head is not None and unk_hat is not None:
            epsilon_true = hidden.detach() - (known_out.predicted + unk_hat)
        epsilon = None
        if self.concept_config.use_epsilon_correction and intervene_known_ids is None:
            # Epsilon correction makes the composition exact when no known-concept
            # intervention is active: composed == hidden by construction.
            epsilon = hidden - (unk_for_lm + known_features)
            unk_for_lm = unk_for_lm + epsilon
        composed = unk_for_lm + known_features
        logits = self.transformer.lm_head(composed)
        _unk_topk_indices = unknown_out.topk_indices if unknown_out else None
        _unk_topk_logits = unknown_out.topk_logits if unknown_out else None
        if not minimal_output and self.unknown_head is not None and (unknown_out is not None) and (_unk_topk_indices is None) and (unknown_topk > 0):
            with torch.no_grad():
                _unk_topk_indices, _unk_topk_logits = self._compute_unknown_topk(hidden, unknown_topk)
        outputs = InterpretableOutput(
            hidden=hidden,
            known_features=known_features,
            known_logits=known_out.logits,
            known_gt_features=known_out.gt_features,
            known_predicted=known_out.predicted,
            known_weights=known_out.weights,
            known_topk_indices=known_out.topk_indices,
            known_topk_logits=known_out.topk_logits,
            unk=unk,
            unk_hat=unk_hat,
            unk_for_lm=unk_for_lm,
            unknown_logits=unknown_out.logits if unknown_out else None,
            unknown_weights=unknown_out.weights if unknown_out else None,
            unknown_topk_indices=_unk_topk_indices,
            unknown_topk_logits=_unk_topk_logits,
            composed=composed,
            epsilon=epsilon,
            epsilon_true=epsilon_true,
        )
        return (logits, outputs)

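    # Decomposition bookkeeping: with an unknown head present,
    #   hidden = known_predicted + unk_hat + epsilon_true,
    # and the LM consumes composed = known_features + unk_for_lm, which equals
    # hidden exactly when epsilon correction is on and no intervention is active.
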
    def _compute_unknown_topk(self, hidden: Tensor, unknown_topk: int) -> tuple[Tensor | None, Tensor | None]:
        """Compute unknown head top-k indices for attribution."""
        assert self.unknown_head is not None
        if self.unknown_head.factorize:
            if self.unknown_head.use_attention:
                _query = self.unknown_head.concept_query_projection(hidden.detach())
                _, indices, logits = self.unknown_head.attention_features_topk_factorized(_query, k=unknown_topk, block_size=self.unknown_head.block_size)
            else:
                _, indices, logits = self.unknown_head.linear_features_topk_factorized(hidden.detach(), k=unknown_topk, block_size=self.unknown_head.block_size)
        else:
            _E = self.unknown_head._get_embedding_weight()[:self.unknown_head.n_concepts]
            if self.unknown_head.use_attention:
                _query = self.unknown_head.concept_query_projection(hidden.detach())
                _, indices, logits = self.unknown_head.attention_features_topk_streaming(_query, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
            else:
                _W = self.unknown_head.concept_predictor.weight[:self.unknown_head.n_concepts]
                _, indices, logits = self.unknown_head.linear_features_topk_streaming(hidden.detach(), _W, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
        return (indices, logits)

    def _forward_with_injection(self, input_ids: Tensor, input_embeds: Tensor | None, position_injection: Tensor, inject_layer: int, inject_alpha: float) -> Tensor:
        """Forward through transformer with steering injection at specified layers."""
        x = input_embeds if input_embeds is not None else self.transformer.tok_emb(input_ids)
        for i, block in enumerate(self.transformer.blocks):
            x = block(x)
            # Add the steering vector after every block from inject_layer onward.
            if i + 1 >= inject_layer:
                x = x + inject_alpha * position_injection
        x = self.transformer.ln_f(x)
        return x

    @torch.no_grad()
    def intervene(self, input_ids: Tensor, known: dict[int, float] | None = None, unknown: dict[int, float] | None = None, positions: Tensor | None = None) -> tuple[Tensor, InterpretableOutput]:
        """
        Run inference with concept interventions.

        Args:
            input_ids: Input token IDs (B, T)
            known: Dict mapping known concept IDs to intervention strengths
            unknown: Dict mapping unknown concept IDs to intervention strengths
            positions: Bool mask of positions to intervene (B, T). Default: all.

        Returns:
            logits: LM logits (B, T, V)
            outputs: InterpretableOutput
        """
        B, T = input_ids.shape
        device = input_ids.device
        if positions is None:
            positions = torch.ones(B, T, dtype=torch.bool, device=device)
        int_known_ids, int_known_vals = (None, None)
        if known is not None and len(known) > 0:
            int_known_ids, int_known_vals = self._build_intervention_tensors(known, B, T, positions, device)
        int_unknown_ids, int_unknown_vals = (None, None)
        if unknown is not None and len(unknown) > 0:
            int_unknown_ids, int_unknown_vals = self._build_intervention_tensors(unknown, B, T, positions, device)
        return self(input_ids, intervene_known_ids=int_known_ids, intervene_known_vals=int_known_vals, intervene_unknown_ids=int_unknown_ids, intervene_unknown_vals=int_unknown_vals, minimal_output=False)

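    # Usage sketch (concept ids and strengths are hypothetical):
    #   logits, out = model.intervene(input_ids, known={12: 1.0, 907: 0.0})
    # turns concept 12 fully on and concept 907 fully off at every position.
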
    @staticmethod
    def _build_intervention_tensors(interventions: dict[int, float], B: int, T: int, positions: Tensor, device: torch.device) -> tuple[Tensor, Tensor]:
        """Build intervention tensors for concept steering."""
        K = len(interventions)
        concept_ids = list(interventions.keys())
        directions = list(interventions.values())
        # -1 marks "no intervention" slots; only masked positions get real ids.
        ids = torch.full((B, T, K), -1, dtype=torch.long, device=device)
        vals = torch.zeros((B, T, K), dtype=torch.float32, device=device)
        concept_tensor = torch.tensor(concept_ids, device=device)
        direction_tensor = torch.tensor(directions, dtype=torch.float32, device=device)
        n_active = int(positions.sum().item())
        ids[positions] = concept_tensor.unsqueeze(0).expand(n_active, -1)
        vals[positions] = direction_tensor.unsqueeze(0).expand(n_active, -1)
        return (ids, vals)

    def get_num_params(self, non_embedding: bool = True) -> int:
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self.transformer, 'tok_emb'):
            n_params -= self.transformer.tok_emb.weight.numel()
        return n_params

from transformers import PreTrainedModel

from .configuration_steerling import SteerlingConfig


# CausalDiffusionLM is the backbone; alias to an HF-friendly name
SteerlingBackbone = CausalDiffusionLM

class SteerlingForCausalLM(PreTrainedModel):
    config_class = SteerlingConfig
    supports_gradient_checkpointing = False
    _tied_weights_keys = ["transformer.lm_head.weight"]

    def __init__(self, config: SteerlingConfig):
        super().__init__(config)
        # SteerlingConfig has all fields from both arch and concept configs
        self.concept_config = config
        self.transformer = SteerlingBackbone(config, config.vocab_size)
        self.known_head = ConceptHead(
            n_concepts=config.n_concepts,
            concept_dim=config.concept_dim,
            n_embd=config.n_embd,
            is_unknown=False,
            use_attention=config.use_attention_known,
            topk=config.topk_known,
            topk_features=config.topk_known_features,
            block_size=config.concept_block_size,
            pad_multiple=config.pad_multiple,
            store_unknown_weights=False,
            apply_topk_to_unknown=False,
            topk_on_logits=config.topk_on_logits,
            factorize=False,
        )
        if config.use_unknown:
            self.unknown_head = ConceptHead(
                n_concepts=config.n_unknown_concepts,
                concept_dim=config.concept_dim,
                n_embd=config.n_embd,
                is_unknown=True,
                use_attention=config.use_attention_unknown,
                topk=config.unknown_topk,
                block_size=config.concept_block_size,
                pad_multiple=config.pad_multiple,
                store_unknown_weights=config.store_unknown_weights,
                apply_topk_to_unknown=config.apply_topk_to_unknown,
                topk_on_logits=config.topk_on_logits,
                factorize=config.factorize_unknown,
                factorize_rank=config.factorize_rank,
            )
        else:
            self.unknown_head = None
        self.post_init()

    def _init_weights(self, module):
        # Intentionally a no-op; initialization is handled elsewhere.
        pass

    def _tie_weights(self):
        if self.config.weight_sharing:
            self.transformer.lm_head.weight = self.transformer.tok_emb.weight

    def forward(self, input_ids=None, **kwargs):
        # Delegate to the matching forward via an unbound call; this class mirrors
        # the attribute layout of both wrappers, so `self` is compatible with either.
        if self.config.interpretable:
            return InterpretableCausalDiffusionLM.forward(self, input_ids, **kwargs)
        else:
            kwargs.pop('minimal_output', None)
            return CausalDiffusionLM.forward(self, input_ids, **kwargs)
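
# Usage sketch (assumes this file ships as custom code alongside a checkpoint;
# "<repo-id>" is a placeholder, not a real model id):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
#   logits, out = model(input_ids)   # interpretable config: (LM logits, InterpretableOutput)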