Add n_tenstorrent_port.py — Tenstorrent N300s training port
Browse files- n_tenstorrent_port.py +1754 -0
n_tenstorrent_port.py
ADDED
|
@@ -0,0 +1,1754 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
n_tenstorrent_port.py
|
| 4 |
+
|
| 5 |
+
Training-first port of the user's joint AR+SAT trainer to support:
|
| 6 |
+
- Tenstorrent via TT-XLA / PJRT (`--backend tt`)
|
| 7 |
+
- NVIDIA CUDA (`--backend cuda`)
|
| 8 |
+
- CPU fallback (`--backend cpu`)
|
| 9 |
+
|
| 10 |
+
Design goals:
|
| 11 |
+
- Keep checkpoint format PyTorch-native and cross-device loadable.
|
| 12 |
+
- Prioritize stable training on TT over aggressive graph tricks.
|
| 13 |
+
- Preserve NVIDIA-trained checkpoint compatibility for inference.
|
| 14 |
+
- Stay as close as practical to the original single-file workflow.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import pathlib
|
| 24 |
+
import time
|
| 25 |
+
from contextlib import nullcontext
|
| 26 |
+
from dataclasses import dataclass
|
| 27 |
+
from datetime import datetime, timedelta, timezone
|
| 28 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
from datasets import DownloadConfig, load_dataset
|
| 34 |
+
from transformers import AutoTokenizer, logging as hf_log
|
| 35 |
+
|
| 36 |
+
STATUS_FILE = "/workspace/status.json"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ───────────────────────── Status helpers ─────────────────────────
|
| 40 |
+
def write_status(step, seen_tok, loss, batch, block, tok_per_sec, phase):
|
| 41 |
+
try:
|
| 42 |
+
with open(STATUS_FILE, "w") as f:
|
| 43 |
+
json.dump(
|
| 44 |
+
{
|
| 45 |
+
"step": step,
|
| 46 |
+
"seen_tok": seen_tok,
|
| 47 |
+
"loss": float(loss) if loss is not None else None,
|
| 48 |
+
"batch": batch,
|
| 49 |
+
"block": block,
|
| 50 |
+
"tok_per_sec": tok_per_sec,
|
| 51 |
+
"phase": phase,
|
| 52 |
+
"updated": time.time(),
|
| 53 |
+
"target_tok": 35737600000,
|
| 54 |
+
},
|
| 55 |
+
f,
|
| 56 |
+
)
|
| 57 |
+
except Exception:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def show_status():
|
| 62 |
+
try:
|
| 63 |
+
with open(STATUS_FILE) as f:
|
| 64 |
+
s = json.load(f)
|
| 65 |
+
age = time.time() - s.get("updated", 0)
|
| 66 |
+
target = s.get("target_tok") or 35737600000
|
| 67 |
+
remaining = target - s.get("seen_tok", 0)
|
| 68 |
+
eta_sec = remaining / max(s.get("tok_per_sec", 1), 1)
|
| 69 |
+
eta_days = eta_sec / 86400
|
| 70 |
+
print(
|
| 71 |
+
f"Step: {s.get('step', '?'):,} | Tokens: {s.get('seen_tok', 0)/1e9:.2f}B / {target/1e9:.1f}B | Loss: {s.get('loss', 0):.4f}"
|
| 72 |
+
)
|
| 73 |
+
print(
|
| 74 |
+
f"Speed: {s.get('tok_per_sec', 0):.0f} tok/s | B={s.get('batch')} L={s.get('block')} | ETA: {eta_days:.1f} days | {age:.0f}s ago"
|
| 75 |
+
)
|
| 76 |
+
except FileNotFoundError:
|
| 77 |
+
print("No status file. Training not running?")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"Error: {e}")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ───────────────────────── Safe progress ─────────────────────────
|
| 83 |
+
class SafeProgress:
|
| 84 |
+
def __init__(self, total, initial=0, unit="tok"):
|
| 85 |
+
self.total = total
|
| 86 |
+
self.n = initial
|
| 87 |
+
self.unit = unit
|
| 88 |
+
self.last_print = initial
|
| 89 |
+
self.postfix = {}
|
| 90 |
+
self.start_time = time.time()
|
| 91 |
+
|
| 92 |
+
def update(self, n=1):
|
| 93 |
+
self.n += n
|
| 94 |
+
if self.n - self.last_print >= 1_000_000:
|
| 95 |
+
self._print()
|
| 96 |
+
self.last_print = self.n
|
| 97 |
+
|
| 98 |
+
def set_postfix(self, **kwargs):
|
| 99 |
+
self.postfix = kwargs
|
| 100 |
+
|
| 101 |
+
def _print(self):
|
| 102 |
+
elapsed = time.time() - self.start_time
|
| 103 |
+
rate = self.n / elapsed if elapsed > 0 else 0
|
| 104 |
+
pct = 100 * self.n / self.total if self.total > 0 else 0
|
| 105 |
+
pf = " ".join(f"{k}={v}" for k, v in self.postfix.items())
|
| 106 |
+
print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")
|
| 107 |
+
|
| 108 |
+
def close(self):
|
| 109 |
+
self._print()
|
| 110 |
+
print("Done.")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ───────────────────────── ANSI colors ─────────────────────────
|
| 114 |
+
class Colors:
|
| 115 |
+
RESET = "\033[0m"
|
| 116 |
+
BOLD = "\033[1m"
|
| 117 |
+
PROMPT = "\033[36m"
|
| 118 |
+
GEN = "\033[0m"
|
| 119 |
+
INFO = "\033[90m"
|
| 120 |
+
WARN = "\033[93m"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
hf_log.set_verbosity_error()
|
| 124 |
+
|
| 125 |
+
if torch.cuda.is_available():
|
| 126 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 127 |
+
try:
|
| 128 |
+
torch.set_float32_matmul_precision("high")
|
| 129 |
+
except Exception:
|
| 130 |
+
pass
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ───────────────────────── Runtime backend ─────────────────────────
|
| 134 |
+
@dataclass
|
| 135 |
+
class BackendRuntime:
|
| 136 |
+
backend: str
|
| 137 |
+
device: torch.device
|
| 138 |
+
is_cuda: bool = False
|
| 139 |
+
is_tt: bool = False
|
| 140 |
+
is_xla: bool = False
|
| 141 |
+
dtype: torch.dtype = torch.float32
|
| 142 |
+
xm: Any = None
|
| 143 |
+
xr: Any = None
|
| 144 |
+
xs: Any = None
|
| 145 |
+
mesh: Any = None
|
| 146 |
+
spmd: bool = False
|
| 147 |
+
compile_options: Optional[Dict[str, str]] = None
|
| 148 |
+
num_devices: int = 1
|
| 149 |
+
|
| 150 |
+
def sync(self, wait: bool = False) -> None:
|
| 151 |
+
if self.is_cuda:
|
| 152 |
+
torch.cuda.synchronize(self.device)
|
| 153 |
+
return
|
| 154 |
+
if self.is_tt:
|
| 155 |
+
try:
|
| 156 |
+
import torch_xla
|
| 157 |
+
|
| 158 |
+
torch_xla.sync(wait=wait)
|
| 159 |
+
return
|
| 160 |
+
except Exception:
|
| 161 |
+
pass
|
| 162 |
+
if self.xm is not None:
|
| 163 |
+
try:
|
| 164 |
+
self.xm.mark_step()
|
| 165 |
+
except Exception:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
def optimizer_step(self, optimizer: torch.optim.Optimizer) -> None:
|
| 169 |
+
if self.is_tt and self.xm is not None:
|
| 170 |
+
try:
|
| 171 |
+
self.xm.optimizer_step(optimizer, barrier=True)
|
| 172 |
+
except TypeError:
|
| 173 |
+
self.xm.optimizer_step(optimizer)
|
| 174 |
+
else:
|
| 175 |
+
optimizer.step()
|
| 176 |
+
|
| 177 |
+
def maybe_mark_batch_sharding(self, *tensors: torch.Tensor) -> None:
|
| 178 |
+
if not (self.is_tt and self.spmd and self.xs is not None and self.mesh is not None):
|
| 179 |
+
return
|
| 180 |
+
for tensor in tensors:
|
| 181 |
+
if tensor is None:
|
| 182 |
+
continue
|
| 183 |
+
try:
|
| 184 |
+
if tensor.ndim == 1:
|
| 185 |
+
self.xs.mark_sharding(tensor, self.mesh, ("batch",))
|
| 186 |
+
elif tensor.ndim >= 2:
|
| 187 |
+
spec = ["batch"] + [None] * (tensor.ndim - 1)
|
| 188 |
+
self.xs.mark_sharding(tensor, self.mesh, tuple(spec))
|
| 189 |
+
except Exception:
|
| 190 |
+
# Sharding is best-effort and still fairly sharp-edged.
|
| 191 |
+
pass
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
RUNTIME = BackendRuntime(backend="cpu", device=torch.device("cpu"))
|
| 195 |
+
DEV = RUNTIME.device
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def setup_runtime(args) -> BackendRuntime:
|
| 199 |
+
global RUNTIME, DEV
|
| 200 |
+
|
| 201 |
+
if getattr(args, "backend", "auto") == "tt" and (
|
| 202 |
+
getattr(args, "tt_bfp8", False) or getattr(args, "tt_weight_bfp8", False)
|
| 203 |
+
) and getattr(args, "tt_dtype", "bf16") != "bf16":
|
| 204 |
+
print("[tt-xla] forcing --tt_dtype bf16 because bfp8 conversion requires a bf16 model input dtype")
|
| 205 |
+
args.tt_dtype = "bf16"
|
| 206 |
+
|
| 207 |
+
requested = getattr(args, "backend", "auto")
|
| 208 |
+
if requested == "auto":
|
| 209 |
+
if os.environ.get("PJRT_DEVICE", "").upper() == "TT":
|
| 210 |
+
requested = "tt"
|
| 211 |
+
elif torch.cuda.is_available():
|
| 212 |
+
requested = "cuda"
|
| 213 |
+
else:
|
| 214 |
+
requested = "cpu"
|
| 215 |
+
|
| 216 |
+
if requested == "cuda":
|
| 217 |
+
runtime = BackendRuntime(
|
| 218 |
+
backend="cuda",
|
| 219 |
+
device=torch.device("cuda"),
|
| 220 |
+
is_cuda=True,
|
| 221 |
+
dtype=torch.float32,
|
| 222 |
+
)
|
| 223 |
+
RUNTIME = runtime
|
| 224 |
+
DEV = runtime.device
|
| 225 |
+
return runtime
|
| 226 |
+
|
| 227 |
+
if requested == "tt":
|
| 228 |
+
os.environ.setdefault("PJRT_DEVICE", "TT")
|
| 229 |
+
os.environ.setdefault("XLA_STABLEHLO_COMPILE", "1")
|
| 230 |
+
if getattr(args, "tt_spmd", False):
|
| 231 |
+
os.environ.setdefault("XLA_ALWAYS_ALLREDUCE", "1")
|
| 232 |
+
os.environ.setdefault("CONVERT_SHLO_TO_SHARDY", "1")
|
| 233 |
+
if getattr(args, "tt_trace", False):
|
| 234 |
+
os.environ.setdefault(
|
| 235 |
+
"TT_RUNTIME_TRACE_REGION_SIZE",
|
| 236 |
+
str(getattr(args, "tt_trace_region_size", 10_000_000)),
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
import numpy as np # local import to avoid dependency unless needed
|
| 240 |
+
import torch_xla
|
| 241 |
+
import torch_xla.core.xla_model as xm
|
| 242 |
+
import torch_xla.runtime as xr
|
| 243 |
+
|
| 244 |
+
xr.set_device_type("TT")
|
| 245 |
+
compile_options = {
|
| 246 |
+
"optimization_level": str(getattr(args, "tt_optimization_level", 1)),
|
| 247 |
+
}
|
| 248 |
+
if getattr(args, "tt_bfp8", False):
|
| 249 |
+
compile_options["enable_bfp8_conversion"] = "true"
|
| 250 |
+
if getattr(args, "tt_weight_bfp8", False):
|
| 251 |
+
compile_options["experimental_enable_weight_bfp8_conversion"] = "true"
|
| 252 |
+
if getattr(args, "tt_trace", False):
|
| 253 |
+
compile_options["enable_trace"] = "true"
|
| 254 |
+
torch_xla.set_custom_compile_options(compile_options)
|
| 255 |
+
|
| 256 |
+
xs = None
|
| 257 |
+
mesh = None
|
| 258 |
+
num_devices = 1
|
| 259 |
+
if getattr(args, "tt_spmd", False):
|
| 260 |
+
try:
|
| 261 |
+
import torch_xla.distributed.spmd as xs
|
| 262 |
+
from torch_xla.distributed.spmd import Mesh
|
| 263 |
+
|
| 264 |
+
xr.use_spmd()
|
| 265 |
+
num_devices = xr.global_runtime_device_count()
|
| 266 |
+
mesh = Mesh(
|
| 267 |
+
device_ids=np.arange(num_devices),
|
| 268 |
+
mesh_shape=(1, num_devices),
|
| 269 |
+
axis_names=("batch", "model"),
|
| 270 |
+
)
|
| 271 |
+
except Exception as e:
|
| 272 |
+
print(f"[tt-spmd] disabled due to setup failure: {e}")
|
| 273 |
+
xs = None
|
| 274 |
+
mesh = None
|
| 275 |
+
num_devices = 1
|
| 276 |
+
|
| 277 |
+
runtime = BackendRuntime(
|
| 278 |
+
backend="tt",
|
| 279 |
+
device=xm.xla_device(),
|
| 280 |
+
is_tt=True,
|
| 281 |
+
is_xla=True,
|
| 282 |
+
dtype=torch.bfloat16 if getattr(args, "tt_dtype", "bf16") == "bf16" else torch.float32,
|
| 283 |
+
xm=xm,
|
| 284 |
+
xr=xr,
|
| 285 |
+
xs=xs,
|
| 286 |
+
mesh=mesh,
|
| 287 |
+
spmd=bool(mesh is not None),
|
| 288 |
+
compile_options=compile_options,
|
| 289 |
+
num_devices=num_devices,
|
| 290 |
+
)
|
| 291 |
+
RUNTIME = runtime
|
| 292 |
+
DEV = runtime.device
|
| 293 |
+
return runtime
|
| 294 |
+
|
| 295 |
+
runtime = BackendRuntime(backend="cpu", device=torch.device("cpu"), dtype=torch.float32)
|
| 296 |
+
RUNTIME = runtime
|
| 297 |
+
DEV = runtime.device
|
| 298 |
+
return runtime
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# ───────────────────────── Tokenizer / vocab ─────────────────────────
|
| 302 |
+
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2")
|
| 303 |
+
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
|
| 304 |
+
if tok.pad_token is None:
|
| 305 |
+
tok.add_special_tokens({"pad_token": "<|pad|>"})
|
| 306 |
+
VOCAB = max(tok.get_vocab().values()) + 1
|
| 307 |
+
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
|
| 308 |
+
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else (EOS if EOS is not None else 0)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ───────────────────────── Presets / defaults ─────────────────────────
|
| 312 |
+
PRESETS: Dict[str, Dict[str, int]] = {
|
| 313 |
+
"femto_1x": dict(d=16, layers=1, heads=1, rank=16),
|
| 314 |
+
"femto_12x": dict(d=16, layers=1, heads=1, rank=192),
|
| 315 |
+
"femto_24x": dict(d=16, layers=1, heads=1, rank=384),
|
| 316 |
+
"pico_1x": dict(d=32, layers=1, heads=2, rank=16),
|
| 317 |
+
"pico_3x": dict(d=32, layers=1, heads=2, rank=48),
|
| 318 |
+
"pico_6x": dict(d=32, layers=1, heads=2, rank=96),
|
| 319 |
+
"pico_12x": dict(d=32, layers=1, heads=2, rank=192),
|
| 320 |
+
"pico_24x": dict(d=32, layers=1, heads=2, rank=384),
|
| 321 |
+
"pico_48x": dict(d=32, layers=1, heads=2, rank=768),
|
| 322 |
+
"nano_1x": dict(d=64, layers=2, heads=4, rank=16),
|
| 323 |
+
"nano_3x": dict(d=64, layers=2, heads=4, rank=48),
|
| 324 |
+
"nano_6x": dict(d=64, layers=2, heads=4, rank=96),
|
| 325 |
+
"nano_12x": dict(d=64, layers=2, heads=4, rank=192),
|
| 326 |
+
"nano_24x": dict(d=64, layers=2, heads=4, rank=384),
|
| 327 |
+
"nano_48x": dict(d=64, layers=2, heads=4, rank=768),
|
| 328 |
+
"nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
|
| 329 |
+
"micro_3x": dict(d=128, layers=4, heads=8, rank=48),
|
| 330 |
+
"micro_6x": dict(d=128, layers=4, heads=8, rank=96),
|
| 331 |
+
"micro_12x": dict(d=128, layers=4, heads=8, rank=192),
|
| 332 |
+
"micro_24x": dict(d=128, layers=4, heads=8, rank=384),
|
| 333 |
+
"small": dict(d=512, layers=8, heads=16, rank=64),
|
| 334 |
+
"smallx2": dict(d=512, layers=16, heads=16, rank=64),
|
| 335 |
+
"base": dict(d=768, layers=12, heads=24, rank=96),
|
| 336 |
+
"base18": dict(d=768, layers=18, heads=24, rank=96),
|
| 337 |
+
"large": dict(d=1024, layers=24, heads=16, rank=128),
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
DEFAULT_BLOCK = 1122
|
| 341 |
+
DEFAULT_BATCH = 1
|
| 342 |
+
SAT_BLOCK = 2
|
| 343 |
+
LR_CORE, LR_HEAD = 5e-5, 2e-4
|
| 344 |
+
EMIT_LAMBDA = 0.1
|
| 345 |
+
DEFAULT_SAVE_SEC = 24 * 3600
|
| 346 |
+
CKDIR = pathlib.Path("ckpts_expansion")
|
| 347 |
+
|
| 348 |
+
DEFAULT_PRETRAIN_SOURCES = (
|
| 349 |
+
"OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,"
|
| 350 |
+
"OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,"
|
| 351 |
+
"OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,"
|
| 352 |
+
"OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1"
|
| 353 |
+
)
|
| 354 |
+
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
|
| 355 |
+
DEFAULT_AFTER_SFT_BLOCK = 1122
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# ───────────────────────── Utilities ─────────────────────────
|
| 359 |
+
def get_uk_time() -> str:
|
| 360 |
+
utc_now = datetime.now(timezone.utc)
|
| 361 |
+
year = utc_now.year
|
| 362 |
+
march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc)
|
| 363 |
+
while march_last.weekday() != 6:
|
| 364 |
+
march_last = march_last.replace(day=march_last.day - 1)
|
| 365 |
+
oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc)
|
| 366 |
+
while oct_last.weekday() != 6:
|
| 367 |
+
oct_last = oct_last.replace(day=oct_last.day - 1)
|
| 368 |
+
if march_last <= utc_now < oct_last:
|
| 369 |
+
uk_offset = 1
|
| 370 |
+
tz_name = "BST"
|
| 371 |
+
else:
|
| 372 |
+
uk_offset = 0
|
| 373 |
+
tz_name = "GMT"
|
| 374 |
+
uk_time = utc_now + timedelta(hours=uk_offset)
|
| 375 |
+
return uk_time.strftime(f"%Y-%m-%d %H:%M:%S {tz_name}")
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _is_probably_ckpt(path: pathlib.Path) -> bool:
|
| 379 |
+
try:
|
| 380 |
+
return (
|
| 381 |
+
path.is_file()
|
| 382 |
+
and path.suffix == ".pt"
|
| 383 |
+
and not path.name.endswith(".pt.tmp")
|
| 384 |
+
and path.stat().st_size > (1 << 20)
|
| 385 |
+
)
|
| 386 |
+
except Exception:
|
| 387 |
+
return False
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def _resolve_ckpt(path: pathlib.Path) -> Optional[pathlib.Path]:
|
| 391 |
+
try:
|
| 392 |
+
if path.is_dir():
|
| 393 |
+
cands = sorted(
|
| 394 |
+
[p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
|
| 395 |
+
key=lambda p: p.stat().st_mtime,
|
| 396 |
+
reverse=True,
|
| 397 |
+
)
|
| 398 |
+
return cands[0] if cands else None
|
| 399 |
+
if path.suffix == ".tmp":
|
| 400 |
+
solid = path.with_suffix("")
|
| 401 |
+
return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
|
| 402 |
+
return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
|
| 403 |
+
except Exception:
|
| 404 |
+
return None
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _try_load(path: pathlib.Path, map_location="cpu"):
|
| 408 |
+
try:
|
| 409 |
+
return torch.load(path, map_location=map_location)
|
| 410 |
+
except Exception as e:
|
| 411 |
+
print(f"[ckpt-skip] {path} not usable: {e}")
|
| 412 |
+
return None
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _strip_compiled_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 416 |
+
return {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _tree_to_cpu(obj: Any) -> Any:
|
| 420 |
+
if torch.is_tensor(obj):
|
| 421 |
+
return obj.detach().cpu()
|
| 422 |
+
if isinstance(obj, dict):
|
| 423 |
+
return {k: _tree_to_cpu(v) for k, v in obj.items()}
|
| 424 |
+
if isinstance(obj, list):
|
| 425 |
+
return [_tree_to_cpu(v) for v in obj]
|
| 426 |
+
if isinstance(obj, tuple):
|
| 427 |
+
return tuple(_tree_to_cpu(v) for v in obj)
|
| 428 |
+
return obj
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def optimizer_to(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
|
| 432 |
+
for state in optimizer.state.values():
|
| 433 |
+
if not isinstance(state, dict):
|
| 434 |
+
continue
|
| 435 |
+
for k, v in list(state.items()):
|
| 436 |
+
if torch.is_tensor(v):
|
| 437 |
+
state[k] = v.to(device)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: Optional[int]):
|
| 441 |
+
if max_ckpts is None or max_ckpts <= 0:
|
| 442 |
+
return
|
| 443 |
+
try:
|
| 444 |
+
for tmp in save_dir.glob("*.pt.tmp"):
|
| 445 |
+
try:
|
| 446 |
+
tmp.unlink()
|
| 447 |
+
print(f" [prune] cleaned stale tmp {tmp.name}")
|
| 448 |
+
except Exception:
|
| 449 |
+
pass
|
| 450 |
+
pattern = f"{phase_name}_step*.pt"
|
| 451 |
+
ckpts = sorted(
|
| 452 |
+
[p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
|
| 453 |
+
key=lambda p: p.stat().st_mtime,
|
| 454 |
+
)
|
| 455 |
+
excess = len(ckpts) - max_ckpts
|
| 456 |
+
if excess > 0:
|
| 457 |
+
for p in ckpts[:excess]:
|
| 458 |
+
try:
|
| 459 |
+
p.unlink()
|
| 460 |
+
print(f" [prune] deleted old {p.name}")
|
| 461 |
+
except Exception:
|
| 462 |
+
pass
|
| 463 |
+
except Exception as e:
|
| 464 |
+
print(f"[ckpt-prune] error: {e}")
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def print_expansion_info(cfg: dict, tie_weights: bool = False):
|
| 468 |
+
d_k = cfg["d"] // cfg["heads"]
|
| 469 |
+
rank = cfg["rank"]
|
| 470 |
+
ratio = rank / d_k
|
| 471 |
+
regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION")
|
| 472 |
+
tie_str = "YES" if tie_weights else "NO"
|
| 473 |
+
print("┌─────────────────────────────────────────┐")
|
| 474 |
+
print("│ TUNEABLE ATTENTION CONFIG │")
|
| 475 |
+
print("├─────────────────────────────────────────┤")
|
| 476 |
+
print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │")
|
| 477 |
+
print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │")
|
| 478 |
+
print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │")
|
| 479 |
+
print("└─────────────────────────────────────────┘")
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _parse_grow_plan(s: str) -> List[int]:
|
| 483 |
+
return sorted(set(int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128))
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _count_enabled_params(*modules) -> int:
|
| 487 |
+
seen_data_ptrs = set()
|
| 488 |
+
total = 0
|
| 489 |
+
for m in modules:
|
| 490 |
+
if m is None:
|
| 491 |
+
continue
|
| 492 |
+
for p in m.parameters():
|
| 493 |
+
if p.data_ptr() not in seen_data_ptrs:
|
| 494 |
+
seen_data_ptrs.add(p.data_ptr())
|
| 495 |
+
total += p.numel()
|
| 496 |
+
return total
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
|
| 500 |
+
for p in core.parameters():
|
| 501 |
+
p.requires_grad = not freeze_core
|
| 502 |
+
if freeze_core:
|
| 503 |
+
if unfreeze_ln:
|
| 504 |
+
for blk in core.blocks:
|
| 505 |
+
for p in blk.ln1.parameters():
|
| 506 |
+
p.requires_grad = True
|
| 507 |
+
for p in blk.ln2.parameters():
|
| 508 |
+
p.requires_grad = True
|
| 509 |
+
for p in core.ln.parameters():
|
| 510 |
+
p.requires_grad = True
|
| 511 |
+
if train_emb:
|
| 512 |
+
for p in core.emb.parameters():
|
| 513 |
+
p.requires_grad = True
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def retie_weights(core: nn.Module, ar_h: nn.Module, tie_weights: bool) -> None:
|
| 517 |
+
if tie_weights:
|
| 518 |
+
ar_h.proj.weight = core.emb.weight
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# ───────────────────────── AMP helper ─────────────────────────
|
| 522 |
+
try:
|
| 523 |
+
from torch.amp import GradScaler, autocast as _ac
|
| 524 |
+
except ImportError:
|
| 525 |
+
from torch.cuda.amp import GradScaler, autocast as _ac
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def _auto_amp_dtype():
|
| 529 |
+
if DEV.type == "cuda":
|
| 530 |
+
try:
|
| 531 |
+
if torch.cuda.is_bf16_supported():
|
| 532 |
+
return torch.bfloat16
|
| 533 |
+
return torch.float16
|
| 534 |
+
except Exception:
|
| 535 |
+
return torch.float16
|
| 536 |
+
return torch.float32
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def amp(enabled: bool):
|
| 540 |
+
if not (enabled and DEV.type == "cuda"):
|
| 541 |
+
return nullcontext()
|
| 542 |
+
try:
|
| 543 |
+
return _ac(device_type="cuda", dtype=_auto_amp_dtype())
|
| 544 |
+
except TypeError:
|
| 545 |
+
return _ac(dtype=_auto_amp_dtype())
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# ───────────────────────── Chat & data stream ─────────────────────────
|
| 549 |
+
def _coerce_role(r: str) -> str:
|
| 550 |
+
r = (r or "").lower()
|
| 551 |
+
if r in {"user", "human", "customer"}:
|
| 552 |
+
return "user"
|
| 553 |
+
if r in {"assistant", "gpt", "bot"}:
|
| 554 |
+
return "assistant"
|
| 555 |
+
if r in {"system", "context"}:
|
| 556 |
+
return "system"
|
| 557 |
+
return r or "user"
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
|
| 561 |
+
msgs = ex.get(messages_key)
|
| 562 |
+
if msgs is None:
|
| 563 |
+
for alt in ("conversations", "dialog", "turns"):
|
| 564 |
+
if isinstance(ex.get(alt), list):
|
| 565 |
+
msgs = ex[alt]
|
| 566 |
+
break
|
| 567 |
+
if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
|
| 568 |
+
try:
|
| 569 |
+
norm = []
|
| 570 |
+
for m in msgs:
|
| 571 |
+
role = _coerce_role(m.get("role", ""))
|
| 572 |
+
content = m.get("content", m.get("text", ""))
|
| 573 |
+
if not isinstance(content, str):
|
| 574 |
+
continue
|
| 575 |
+
norm.append({"role": role, "content": content})
|
| 576 |
+
if not norm:
|
| 577 |
+
return None
|
| 578 |
+
return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
|
| 579 |
+
except Exception:
|
| 580 |
+
return None
|
| 581 |
+
for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
|
| 582 |
+
if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
|
| 583 |
+
return f"User: {ex[a]}\nAssistant: {ex[b]}"
|
| 584 |
+
return None
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
|
| 588 |
+
dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
|
| 589 |
+
if ":" in ds_name:
|
| 590 |
+
base, config = ds_name.split(":", 1)
|
| 591 |
+
else:
|
| 592 |
+
base, config = ds_name, None
|
| 593 |
+
if not streaming:
|
| 594 |
+
print(f"[download] Downloading {ds_name} (non-streaming)...")
|
| 595 |
+
if base == "json":
|
| 596 |
+
data_files = {"train": config}
|
| 597 |
+
ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc)
|
| 598 |
+
else:
|
| 599 |
+
ds = (
|
| 600 |
+
load_dataset(base, config, split="train", streaming=streaming, download_config=dc)
|
| 601 |
+
if config
|
| 602 |
+
else load_dataset(base, split="train", streaming=streaming, download_config=dc)
|
| 603 |
+
)
|
| 604 |
+
if streaming:
|
| 605 |
+
return iter(ds.shuffle(buffer_size=1000, seed=seed))
|
| 606 |
+
print(f"[download] Got {len(ds):,} examples. Shuffling...")
|
| 607 |
+
ds = ds.shuffle(seed=seed)
|
| 608 |
+
return iter(ds)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
_HOT_CFG_PATH = pathlib.Path("/workspace/hot_config.json")
|
| 612 |
+
_hot_cache = {"mtime": 0, "data": {}}
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def get_hot_datasets(default):
|
| 616 |
+
try:
|
| 617 |
+
if _HOT_CFG_PATH.exists():
|
| 618 |
+
mt = _HOT_CFG_PATH.stat().st_mtime
|
| 619 |
+
if mt > _hot_cache["mtime"]:
|
| 620 |
+
_hot_cache["data"] = json.loads(_HOT_CFG_PATH.read_text())
|
| 621 |
+
_hot_cache["mtime"] = mt
|
| 622 |
+
cfg = _hot_cache["data"]
|
| 623 |
+
if "datasets" in cfg:
|
| 624 |
+
ds = cfg["datasets"]
|
| 625 |
+
if isinstance(ds, list):
|
| 626 |
+
ds = ",".join(ds)
|
| 627 |
+
print(f"[HOT] Using: {ds[:60]}...")
|
| 628 |
+
return ds
|
| 629 |
+
except Exception as e:
|
| 630 |
+
print(f"[HOT] Error: {e}")
|
| 631 |
+
return default
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def token_stream(
|
| 635 |
+
ds_names: str,
|
| 636 |
+
target: int,
|
| 637 |
+
seed: int = 42,
|
| 638 |
+
chat: bool = False,
|
| 639 |
+
chat_messages_key: str = "messages",
|
| 640 |
+
sft_add_generation_prompt: bool = False,
|
| 641 |
+
dataset_field_text: str = "text",
|
| 642 |
+
streaming: bool = True,
|
| 643 |
+
):
|
| 644 |
+
ds_names = get_hot_datasets(ds_names)
|
| 645 |
+
sources = [s.strip() for s in ds_names.split(",") if s.strip()]
|
| 646 |
+
if not sources:
|
| 647 |
+
return
|
| 648 |
+
src_idx = 0
|
| 649 |
+
emitted = 0
|
| 650 |
+
it = None
|
| 651 |
+
attempts = 0
|
| 652 |
+
backoff_base = 2.0
|
| 653 |
+
while emitted < target:
|
| 654 |
+
try:
|
| 655 |
+
if it is None:
|
| 656 |
+
it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
|
| 657 |
+
ex = next(it)
|
| 658 |
+
text = None
|
| 659 |
+
if isinstance(ex, dict):
|
| 660 |
+
if chat:
|
| 661 |
+
text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
|
| 662 |
+
if text is None:
|
| 663 |
+
if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
|
| 664 |
+
text = ex[dataset_field_text]
|
| 665 |
+
elif isinstance(ex.get("text"), str):
|
| 666 |
+
text = ex["text"]
|
| 667 |
+
if not isinstance(text, str):
|
| 668 |
+
attempts = 0
|
| 669 |
+
continue
|
| 670 |
+
enc = tok.encode(text)
|
| 671 |
+
if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
|
| 672 |
+
enc = enc + [EOS]
|
| 673 |
+
for t in enc:
|
| 674 |
+
yield t
|
| 675 |
+
emitted += 1
|
| 676 |
+
if emitted >= target:
|
| 677 |
+
return
|
| 678 |
+
attempts = 0
|
| 679 |
+
except StopIteration:
|
| 680 |
+
it = None
|
| 681 |
+
src_idx = (src_idx + 1) % len(sources)
|
| 682 |
+
except Exception as e:
|
| 683 |
+
attempts += 1
|
| 684 |
+
sleep_s = min(60.0, backoff_base ** min(attempts, 6))
|
| 685 |
+
print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
|
| 686 |
+
time.sleep(sleep_s)
|
| 687 |
+
it = None
|
| 688 |
+
if attempts % 2 == 0 and len(sources) > 1:
|
| 689 |
+
src_idx = (src_idx + 1) % len(sources)
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
# ───────────────────────── ALiBi ─────────────────────────
|
| 693 |
+
@torch._dynamo.disable
|
| 694 |
+
def _alibi_slopes(n_heads: int):
|
| 695 |
+
def pow2slopes(n):
|
| 696 |
+
start = 2 ** (-2 ** -(math.log2(n) - 3))
|
| 697 |
+
ratio = start
|
| 698 |
+
return [start * (ratio**i) for i in range(n)]
|
| 699 |
+
|
| 700 |
+
if math.log2(n_heads).is_integer():
|
| 701 |
+
vals = pow2slopes(n_heads)
|
| 702 |
+
else:
|
| 703 |
+
closest = 2 ** math.floor(math.log2(n_heads))
|
| 704 |
+
vals = pow2slopes(closest)
|
| 705 |
+
extra = pow2slopes(2 * closest)
|
| 706 |
+
vals += extra[0::2][: n_heads - closest]
|
| 707 |
+
return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
@torch._dynamo.disable
|
| 711 |
+
def alibi_bias(n_heads: int, n_tokens: int):
|
| 712 |
+
i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
|
| 713 |
+
j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
|
| 714 |
+
dist = (j - i).clamp_min(0)
|
| 715 |
+
return -_alibi_slopes(n_heads) * dist
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# ───────────────────────── Model components ─────────────────────────
|
| 719 |
+
class TuneableAttentionMHA(nn.Module):
|
| 720 |
+
def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
|
| 721 |
+
super().__init__()
|
| 722 |
+
assert d % h == 0
|
| 723 |
+
self.h, self.dk, self.r = h, d // h, r
|
| 724 |
+
self.use_relpos = use_relpos
|
| 725 |
+
self.q = nn.Linear(d, d, bias=False)
|
| 726 |
+
self.k = nn.Linear(d, d, bias=False)
|
| 727 |
+
self.v = nn.Linear(d, d, bias=False)
|
| 728 |
+
self.U = nn.Parameter(torch.randn(self.dk, r))
|
| 729 |
+
nn.init.orthogonal_(self.U)
|
| 730 |
+
self.proj = nn.Linear(h * self.dk, d, bias=False)
|
| 731 |
+
self.drop = nn.Dropout(0.1)
|
| 732 |
+
|
| 733 |
+
def _proj_qk(self, x):
|
| 734 |
+
B, N, _ = x.shape
|
| 735 |
+
return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
|
| 736 |
+
|
| 737 |
+
def _reshape_v(self, x):
|
| 738 |
+
B, N, _ = x.shape
|
| 739 |
+
return x.view(B, N, self.h, self.dk).transpose(1, 2)
|
| 740 |
+
|
| 741 |
+
def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
|
| 742 |
+
q = self._proj_qk(self.q(x))
|
| 743 |
+
k_new = self._proj_qk(self.k(x))
|
| 744 |
+
v_new = self._reshape_v(self.v(x))
|
| 745 |
+
if kv_cache is None:
|
| 746 |
+
k, v = k_new, v_new
|
| 747 |
+
else:
|
| 748 |
+
k_cached, v_cached = kv_cache
|
| 749 |
+
if use_cache:
|
| 750 |
+
k = torch.cat([k_cached, k_new], dim=2)
|
| 751 |
+
v = torch.cat([v_cached, v_new], dim=2)
|
| 752 |
+
else:
|
| 753 |
+
k, v = k_new, v_new
|
| 754 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
|
| 755 |
+
if self.use_relpos and rel_bias_tokens is not None:
|
| 756 |
+
att = att + alibi_bias(self.h, rel_bias_tokens).to(att.dtype)[:, :, -q.size(2) :, :]
|
| 757 |
+
if mask is not None:
|
| 758 |
+
att = att + mask.to(att.dtype)
|
| 759 |
+
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
|
| 760 |
+
out = self.drop(self.proj(z))
|
| 761 |
+
return (out, (k, v)) if use_cache else out
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class Block(nn.Module):
|
| 765 |
+
def __init__(self, d: int, h: int, r: int):
|
| 766 |
+
super().__init__()
|
| 767 |
+
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
|
| 768 |
+
self.mha = TuneableAttentionMHA(d, h, r)
|
| 769 |
+
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
|
| 770 |
+
|
| 771 |
+
def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
|
| 772 |
+
if use_cache:
|
| 773 |
+
y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True)
|
| 774 |
+
x = x + y + self.ff(self.ln2(x + y))
|
| 775 |
+
return x, new_kv
|
| 776 |
+
n = x.size(1)
|
| 777 |
+
x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
|
| 778 |
+
return x + self.ff(self.ln2(x))
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
class Encoder(nn.Module):
|
| 782 |
+
def __init__(self, cfg, tie_weights: bool = False):
|
| 783 |
+
super().__init__()
|
| 784 |
+
d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
|
| 785 |
+
self.emb = nn.Embedding(VOCAB, d)
|
| 786 |
+
self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
|
| 787 |
+
self.ln = nn.LayerNorm(d)
|
| 788 |
+
self.tie_weights = tie_weights
|
| 789 |
+
|
| 790 |
+
def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
|
| 791 |
+
x = self.emb(ids)
|
| 792 |
+
if not use_cache:
|
| 793 |
+
for blk in self.blocks:
|
| 794 |
+
x = blk(x, mask)
|
| 795 |
+
return self.ln(x)
|
| 796 |
+
new_kvs = []
|
| 797 |
+
for i, blk in enumerate(self.blocks):
|
| 798 |
+
kv = kv_caches[i] if kv_caches else None
|
| 799 |
+
x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len)
|
| 800 |
+
new_kvs.append(kv_out)
|
| 801 |
+
return self.ln(x), new_kvs
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class ARHead(nn.Module):
|
| 805 |
+
def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
|
| 806 |
+
super().__init__()
|
| 807 |
+
self.tie_weights = tie_weights
|
| 808 |
+
if tie_weights and embedding_weight is not None:
|
| 809 |
+
self.proj = nn.Linear(d, VOCAB, bias=False)
|
| 810 |
+
self.proj.weight = embedding_weight
|
| 811 |
+
else:
|
| 812 |
+
self.proj = nn.Linear(d, VOCAB)
|
| 813 |
+
|
| 814 |
+
def forward(self, h):
|
| 815 |
+
return self.proj(h)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
class SATHead(nn.Module):
|
| 819 |
+
def __init__(self, d, mode="var"):
|
| 820 |
+
super().__init__()
|
| 821 |
+
self.proj = nn.Linear(d, VOCAB)
|
| 822 |
+
self.gate = nn.Linear(d, 2) if mode == "var" else None
|
| 823 |
+
|
| 824 |
+
def forward(self, h_last):
|
| 825 |
+
return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
# ───────────────────────── Masks ─────────────────────────
|
| 829 |
+
def causal_mask(n):
|
| 830 |
+
return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
def sat_mask(n, block=SAT_BLOCK):
|
| 834 |
+
idx = torch.arange(n, device=DEV)
|
| 835 |
+
grp = idx.unsqueeze(0) // block
|
| 836 |
+
allow = (grp.T == grp) | (grp.T > grp)
|
| 837 |
+
return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
|
| 841 |
+
total_len = cached_len + new_len
|
| 842 |
+
return torch.zeros((1, 1, new_len, total_len), device=DEV)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def causal_padded_mask(total_len: int, valid_len: int):
|
| 846 |
+
mask = causal_mask(total_len)
|
| 847 |
+
if valid_len < total_len:
|
| 848 |
+
mask[:, :, :, valid_len:] = float("-inf")
|
| 849 |
+
mask[:, :, valid_len:, :] = float("-inf")
|
| 850 |
+
return mask
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def sat_padded_mask(total_len: int, valid_len: int):
|
| 854 |
+
mask = sat_mask(total_len)
|
| 855 |
+
if valid_len < total_len:
|
| 856 |
+
mask[:, :, :, valid_len:] = float("-inf")
|
| 857 |
+
mask[:, :, valid_len:, :] = float("-inf")
|
| 858 |
+
return mask
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
# ───────────────────────── Checkpoint helpers ─────────────────────────
|
| 862 |
+
def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
|
| 863 |
+
if RUNTIME.is_tt:
|
| 864 |
+
RUNTIME.sync(wait=True)
|
| 865 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
| 866 |
+
tmp = path.with_suffix(path.suffix + ".tmp")
|
| 867 |
+
state = {
|
| 868 |
+
"core": _tree_to_cpu(_strip_compiled_prefix(core.state_dict())),
|
| 869 |
+
"ar": _tree_to_cpu(_strip_compiled_prefix(ar_h.state_dict())),
|
| 870 |
+
"sat": _tree_to_cpu(_strip_compiled_prefix(sat_h.state_dict())),
|
| 871 |
+
"opt": _tree_to_cpu(opt.state_dict()),
|
| 872 |
+
"scaler": _tree_to_cpu(scaler.state_dict()),
|
| 873 |
+
"cfg": meta.get("cfg"),
|
| 874 |
+
"tokenizer_id": TOKENIZER_ID,
|
| 875 |
+
"tie_weights": meta.get("tie_weights", False),
|
| 876 |
+
**{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")},
|
| 877 |
+
}
|
| 878 |
+
torch.save(state, tmp, _use_new_zipfile_serialization=False)
|
| 879 |
+
tmp.replace(path)
|
| 880 |
+
(path.parent / "latest.json").write_text(
|
| 881 |
+
json.dumps(
|
| 882 |
+
{
|
| 883 |
+
"path": str(path),
|
| 884 |
+
"step": meta["step"],
|
| 885 |
+
"block_size": meta.get("block_size"),
|
| 886 |
+
"batch_size": meta.get("batch_size"),
|
| 887 |
+
"seen_tok": meta.get("seen_tok"),
|
| 888 |
+
}
|
| 889 |
+
)
|
| 890 |
+
)
|
| 891 |
+
print(f"\n✓ saved checkpoint {path.name}")
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
|
| 896 |
+
p = _resolve_ckpt(path) or path
|
| 897 |
+
ck = _try_load(p, map_location="cpu")
|
| 898 |
+
if ck is None:
|
| 899 |
+
raise FileNotFoundError(f"No valid checkpoint at {p}")
|
| 900 |
+
core.load_state_dict(_strip_compiled_prefix(ck["core"]))
|
| 901 |
+
ar_h.load_state_dict(_strip_compiled_prefix(ck["ar"]))
|
| 902 |
+
sat_h.load_state_dict(_strip_compiled_prefix(ck["sat"]))
|
| 903 |
+
try:
|
| 904 |
+
opt.load_state_dict(ck["opt"])
|
| 905 |
+
optimizer_to(opt, DEV)
|
| 906 |
+
except Exception as e:
|
| 907 |
+
print(f"[resume] optimizer state skipped: {e}")
|
| 908 |
+
if ck.get("scaler"):
|
| 909 |
+
try:
|
| 910 |
+
scaler.load_state_dict(ck["scaler"])
|
| 911 |
+
except Exception:
|
| 912 |
+
pass
|
| 913 |
+
return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time()), ck.get("block_size")
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None) -> int:
|
| 918 |
+
p = _resolve_ckpt(path) or path
|
| 919 |
+
if not p.exists():
|
| 920 |
+
return 0
|
| 921 |
+
ck = _try_load(p, map_location="cpu")
|
| 922 |
+
if ck is None:
|
| 923 |
+
return 0
|
| 924 |
+
sd = ck.get(key, ck) if key else ck
|
| 925 |
+
if isinstance(sd, dict) and "state_dict" in sd:
|
| 926 |
+
sd = sd["state_dict"]
|
| 927 |
+
tgt_sd = tgt.state_dict()
|
| 928 |
+
filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
|
| 929 |
+
if filt:
|
| 930 |
+
tgt.load_state_dict(filt, strict=False)
|
| 931 |
+
return len(filt)
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def infer_cfg_from_ckpt(path: pathlib.Path):
|
| 936 |
+
p = _resolve_ckpt(path) or path
|
| 937 |
+
if not p.exists():
|
| 938 |
+
return None
|
| 939 |
+
sd = _try_load(p, map_location="cpu")
|
| 940 |
+
if sd is None:
|
| 941 |
+
return None
|
| 942 |
+
if "cfg" in sd:
|
| 943 |
+
return dict(sd["cfg"])
|
| 944 |
+
return None
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
# ───────────────────────── Training logic ─────────────────────────
|
| 948 |
+
def _loss_float(x: torch.Tensor) -> float:
|
| 949 |
+
try:
|
| 950 |
+
return float(x.detach().float().cpu().item())
|
| 951 |
+
except Exception:
|
| 952 |
+
return float(x.item())
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def _forward_train_losses(args, core, ar_h, sat_h, ids, ce_tok, ce_gate):
|
| 957 |
+
h_ar = core(ids, causal_mask(ids.size(1)))
|
| 958 |
+
logits_ar = ar_h(h_ar)[:, :-1]
|
| 959 |
+
loss_ar = ce_tok(logits_ar.float().reshape(-1, VOCAB), ids[:, 1:].reshape(-1))
|
| 960 |
+
if args.ar_only:
|
| 961 |
+
return loss_ar
|
| 962 |
+
h_sat = core(ids, sat_mask(ids.size(1)))
|
| 963 |
+
logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
|
| 964 |
+
tgt_sat = ids[:, 1 : SAT_BLOCK + 1]
|
| 965 |
+
loss_sat = ce_tok(logits_sat.float().reshape(-1, VOCAB), tgt_sat.reshape(-1))
|
| 966 |
+
if gate is not None:
|
| 967 |
+
loss_sat += EMIT_LAMBDA * ce_gate(gate.float(), torch.ones(ids.size(0), device=DEV, dtype=torch.long))
|
| 968 |
+
return loss_ar + loss_sat
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
def _run_optimizer_step(args, opt, scaler, loss, trainable_params: Iterable[torch.nn.Parameter]):
|
| 973 |
+
trainable_params = list(trainable_params)
|
| 974 |
+
if args.amp and DEV.type == "cuda":
|
| 975 |
+
scaler.scale(loss).backward()
|
| 976 |
+
scaler.unscale_(opt)
|
| 977 |
+
if trainable_params:
|
| 978 |
+
nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
| 979 |
+
scaler.step(opt)
|
| 980 |
+
scaler.update()
|
| 981 |
+
return
|
| 982 |
+
|
| 983 |
+
loss.backward()
|
| 984 |
+
if trainable_params:
|
| 985 |
+
nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
| 986 |
+
RUNTIME.optimizer_step(opt)
|
| 987 |
+
if RUNTIME.is_tt:
|
| 988 |
+
RUNTIME.sync(wait=True)
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
def _maybe_handle_oom(e: RuntimeError) -> bool:
|
| 993 |
+
msg = str(e).lower()
|
| 994 |
+
return (
|
| 995 |
+
"out of memory" in msg
|
| 996 |
+
or "cuda out of memory" in msg
|
| 997 |
+
or "resource exhausted" in msg
|
| 998 |
+
or "failed to allocate" in msg
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def _train_phase(
|
| 1004 |
+
args,
|
| 1005 |
+
phase_name: str,
|
| 1006 |
+
core,
|
| 1007 |
+
ar_h,
|
| 1008 |
+
sat_h,
|
| 1009 |
+
opt,
|
| 1010 |
+
scaler,
|
| 1011 |
+
start_step,
|
| 1012 |
+
seen_tok,
|
| 1013 |
+
resume_wall_time,
|
| 1014 |
+
cfg,
|
| 1015 |
+
source,
|
| 1016 |
+
steps,
|
| 1017 |
+
block_size,
|
| 1018 |
+
batch_size,
|
| 1019 |
+
chat_cfg: dict,
|
| 1020 |
+
max_ckpts: Optional[int],
|
| 1021 |
+
target_tokens_override: Optional[int] = None,
|
| 1022 |
+
tie_weights: bool = False,
|
| 1023 |
+
streaming: bool = True,
|
| 1024 |
+
):
|
| 1025 |
+
BLOCK = block_size
|
| 1026 |
+
BATCH = batch_size
|
| 1027 |
+
if target_tokens_override is not None:
|
| 1028 |
+
target_tokens = target_tokens_override
|
| 1029 |
+
else:
|
| 1030 |
+
ratio = 51.2 if args.chilla_max_double else 25
|
| 1031 |
+
param_count = _count_enabled_params(core, ar_h, sat_h)
|
| 1032 |
+
target_tokens = int(ratio * param_count)
|
| 1033 |
+
|
| 1034 |
+
if steps:
|
| 1035 |
+
phase_target_tokens = steps * BLOCK * BATCH
|
| 1036 |
+
total_tokens_needed = seen_tok + phase_target_tokens
|
| 1037 |
+
else:
|
| 1038 |
+
total_tokens_needed = target_tokens
|
| 1039 |
+
if total_tokens_needed <= seen_tok:
|
| 1040 |
+
print(f"[{phase_name}] target {total_tokens_needed} already reached.")
|
| 1041 |
+
return start_step, seen_tok, resume_wall_time
|
| 1042 |
+
|
| 1043 |
+
stream = token_stream(
|
| 1044 |
+
source,
|
| 1045 |
+
total_tokens_needed,
|
| 1046 |
+
seed=42,
|
| 1047 |
+
chat=chat_cfg.get("chat", False),
|
| 1048 |
+
chat_messages_key=chat_cfg.get("key", "messages"),
|
| 1049 |
+
sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
|
| 1050 |
+
dataset_field_text=chat_cfg.get("text_field", "text"),
|
| 1051 |
+
streaming=streaming,
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
ce_tok = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
| 1055 |
+
ce_gate = nn.CrossEntropyLoss()
|
| 1056 |
+
pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
|
| 1057 |
+
grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
|
| 1058 |
+
buf: List[int] = []
|
| 1059 |
+
batch_accum: List[List[int]] = []
|
| 1060 |
+
step = start_step
|
| 1061 |
+
steps_since_last_grow = 0
|
| 1062 |
+
oom_retries = 0
|
| 1063 |
+
max_oom_retries = 2
|
| 1064 |
+
|
| 1065 |
+
now_wall = time.time()
|
| 1066 |
+
last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
|
| 1067 |
+
print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
|
| 1068 |
+
print(f"[{phase_name}] BACKEND={RUNTIME.backend} AR_ONLY={args.ar_only} TIE_WEIGHTS={tie_weights} STREAMING={streaming}")
|
| 1069 |
+
if RUNTIME.is_tt:
|
| 1070 |
+
print(
|
| 1071 |
+
f"[{phase_name}] TT dtype={str(RUNTIME.dtype).replace('torch.', '')} opt_level={args.tt_optimization_level} spmd={RUNTIME.spmd} devices={RUNTIME.num_devices}"
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
step_start_time = time.monotonic()
|
| 1075 |
+
tok_per_sec_avg = 0.0
|
| 1076 |
+
trainable_params = [p for p in list(core.parameters()) + list(ar_h.parameters()) + list(sat_h.parameters()) if p.requires_grad]
|
| 1077 |
+
|
| 1078 |
+
while seen_tok < total_tokens_needed:
|
| 1079 |
+
try:
|
| 1080 |
+
while len(buf) < BLOCK:
|
| 1081 |
+
buf.append(next(stream))
|
| 1082 |
+
except StopIteration:
|
| 1083 |
+
break
|
| 1084 |
+
|
| 1085 |
+
seq = buf[:BLOCK]
|
| 1086 |
+
buf = buf[BLOCK:]
|
| 1087 |
+
batch_accum.append(seq)
|
| 1088 |
+
if len(batch_accum) < BATCH:
|
| 1089 |
+
continue
|
| 1090 |
+
|
| 1091 |
+
ids = torch.tensor(batch_accum, device=DEV, dtype=torch.long)
|
| 1092 |
+
batch_accum = []
|
| 1093 |
+
if RUNTIME.is_tt:
|
| 1094 |
+
RUNTIME.maybe_mark_batch_sharding(ids)
|
| 1095 |
+
|
| 1096 |
+
try:
|
| 1097 |
+
opt.zero_grad(set_to_none=True)
|
| 1098 |
+
with amp(args.amp):
|
| 1099 |
+
loss = _forward_train_losses(args, core, ar_h, sat_h, ids, ce_tok, ce_gate)
|
| 1100 |
+
_run_optimizer_step(args, opt, scaler, loss, trainable_params)
|
| 1101 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1102 |
+
except RuntimeError as e:
|
| 1103 |
+
if _maybe_handle_oom(e):
|
| 1104 |
+
batch_accum = []
|
| 1105 |
+
opt.zero_grad(set_to_none=True)
|
| 1106 |
+
if DEV.type == "cuda":
|
| 1107 |
+
torch.cuda.empty_cache()
|
| 1108 |
+
torch.cuda.synchronize()
|
| 1109 |
+
oom_retries += 1
|
| 1110 |
+
if oom_retries <= max_oom_retries:
|
| 1111 |
+
print(f"\n[{phase_name} OOM] Retry {oom_retries}/{max_oom_retries} at Batch={BATCH}, clearing caches...")
|
| 1112 |
+
time.sleep(4)
|
| 1113 |
+
continue
|
| 1114 |
+
oom_retries = 0
|
| 1115 |
+
if BATCH > 1:
|
| 1116 |
+
print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1}")
|
| 1117 |
+
BATCH -= 1
|
| 1118 |
+
time.sleep(4)
|
| 1119 |
+
else:
|
| 1120 |
+
if grow_plan:
|
| 1121 |
+
smaller = [b for b in grow_plan if b < BLOCK]
|
| 1122 |
+
new_block = smaller[-1] if smaller else max(128, BLOCK // 2)
|
| 1123 |
+
else:
|
| 1124 |
+
new_block = max(128, BLOCK // 2)
|
| 1125 |
+
print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
|
| 1126 |
+
BLOCK = new_block
|
| 1127 |
+
time.sleep(4)
|
| 1128 |
+
steps_since_last_grow = 0
|
| 1129 |
+
continue
|
| 1130 |
+
raise
|
| 1131 |
+
|
| 1132 |
+
step += 1
|
| 1133 |
+
oom_retries = 0
|
| 1134 |
+
toks_processed = BLOCK * BATCH
|
| 1135 |
+
seen_tok += toks_processed
|
| 1136 |
+
pbar.update(toks_processed)
|
| 1137 |
+
loss_value = _loss_float(loss)
|
| 1138 |
+
pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)
|
| 1139 |
+
|
| 1140 |
+
step_elapsed = time.monotonic() - step_start_time
|
| 1141 |
+
tok_per_sec_now = toks_processed / step_elapsed if step_elapsed > 0 else 0.0
|
| 1142 |
+
tok_per_sec_avg = 0.9 * tok_per_sec_avg + 0.1 * tok_per_sec_now if tok_per_sec_avg > 0 else tok_per_sec_now
|
| 1143 |
+
step_start_time = time.monotonic()
|
| 1144 |
+
write_status(step, seen_tok, loss_value, BATCH, BLOCK, tok_per_sec_avg, phase_name)
|
| 1145 |
+
|
| 1146 |
+
if args.save_every_sec > 0:
|
| 1147 |
+
now_mono = time.monotonic()
|
| 1148 |
+
if now_mono - last_save_mono >= args.save_every_sec:
|
| 1149 |
+
ck_name = f"{phase_name}_step{step:08d}.pt"
|
| 1150 |
+
save_ckpt(
|
| 1151 |
+
pathlib.Path(args.save_dir) / ck_name,
|
| 1152 |
+
core,
|
| 1153 |
+
ar_h,
|
| 1154 |
+
sat_h,
|
| 1155 |
+
opt,
|
| 1156 |
+
scaler,
|
| 1157 |
+
meta={
|
| 1158 |
+
"cfg": cfg,
|
| 1159 |
+
"step": step,
|
| 1160 |
+
"seen_tok": seen_tok,
|
| 1161 |
+
"wall_time": time.time(),
|
| 1162 |
+
"tie_weights": tie_weights,
|
| 1163 |
+
"block_size": BLOCK,
|
| 1164 |
+
"batch_size": BATCH,
|
| 1165 |
+
},
|
| 1166 |
+
)
|
| 1167 |
+
_prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
|
| 1168 |
+
last_save_mono = now_mono
|
| 1169 |
+
|
| 1170 |
+
if args.auto_grow:
|
| 1171 |
+
steps_since_last_grow += 1
|
| 1172 |
+
if steps_since_last_grow >= args.grow_every_steps:
|
| 1173 |
+
steps_since_last_grow = 0
|
| 1174 |
+
try:
|
| 1175 |
+
idx = grow_plan.index(BLOCK)
|
| 1176 |
+
if idx + 1 < len(grow_plan):
|
| 1177 |
+
BLOCK = grow_plan[idx + 1]
|
| 1178 |
+
print(f"[{phase_name} Grow] Block -> {BLOCK}")
|
| 1179 |
+
if DEV.type == "cuda":
|
| 1180 |
+
torch.cuda.empty_cache()
|
| 1181 |
+
except ValueError:
|
| 1182 |
+
grow_plan = sorted(set(grow_plan + [BLOCK]))
|
| 1183 |
+
|
| 1184 |
+
pbar.close()
|
| 1185 |
+
save_ckpt(
|
| 1186 |
+
pathlib.Path(args.save_dir) / f"{phase_name}_final.pt",
|
| 1187 |
+
core,
|
| 1188 |
+
ar_h,
|
| 1189 |
+
sat_h,
|
| 1190 |
+
opt,
|
| 1191 |
+
scaler,
|
| 1192 |
+
meta={
|
| 1193 |
+
"cfg": cfg,
|
| 1194 |
+
"step": step,
|
| 1195 |
+
"seen_tok": seen_tok,
|
| 1196 |
+
"wall_time": time.time(),
|
| 1197 |
+
"tie_weights": tie_weights,
|
| 1198 |
+
"block_size": BLOCK,
|
| 1199 |
+
"batch_size": BATCH,
|
| 1200 |
+
},
|
| 1201 |
+
)
|
| 1202 |
+
return step, seen_tok, time.time()
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
# ───────────────────────── Main orchestrator ─────────────────────────
|
| 1206 |
+
def _build_models(cfg, tie_weights: bool):
|
| 1207 |
+
core = Encoder(cfg, tie_weights=tie_weights)
|
| 1208 |
+
ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None)
|
| 1209 |
+
sat_h = SATHead(cfg["d"], mode="var")
|
| 1210 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1211 |
+
return core, ar_h, sat_h
|
| 1212 |
+
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
def _maybe_cast_models_for_runtime(core, ar_h, sat_h):
|
| 1216 |
+
if RUNTIME.is_tt and RUNTIME.dtype == torch.bfloat16:
|
| 1217 |
+
core = core.to(dtype=torch.bfloat16)
|
| 1218 |
+
ar_h = ar_h.to(dtype=torch.bfloat16)
|
| 1219 |
+
sat_h = sat_h.to(dtype=torch.bfloat16)
|
| 1220 |
+
retie_weights(core, ar_h, True if getattr(core, "tie_weights", False) or getattr(ar_h, "tie_weights", False) else False)
|
| 1221 |
+
return core, ar_h, sat_h
|
| 1222 |
+
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
def _move_models_to_device(core, ar_h, sat_h, tie_weights: bool):
|
| 1226 |
+
core = core.to(DEV)
|
| 1227 |
+
ar_h = ar_h.to(DEV)
|
| 1228 |
+
sat_h = sat_h.to(DEV)
|
| 1229 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1230 |
+
return core, ar_h, sat_h
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
def _maybe_compile_models(args, core, ar_h, sat_h, tie_weights: bool):
|
| 1235 |
+
if not args.compile:
|
| 1236 |
+
return core, ar_h, sat_h
|
| 1237 |
+
if RUNTIME.is_tt:
|
| 1238 |
+
print("[tt-xla] Skipping torch.compile for training stability; TT-XLA lazy compilation is still active.")
|
| 1239 |
+
return core, ar_h, sat_h
|
| 1240 |
+
if hasattr(torch, "compile"):
|
| 1241 |
+
print("[torch.compile] Compiling model...")
|
| 1242 |
+
core = torch.compile(core, mode="reduce-overhead")
|
| 1243 |
+
ar_h = torch.compile(ar_h, mode="reduce-overhead")
|
| 1244 |
+
sat_h = torch.compile(sat_h, mode="reduce-overhead")
|
| 1245 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1246 |
+
print("[torch.compile] Done.")
|
| 1247 |
+
return core, ar_h, sat_h
|
| 1248 |
+
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
def train(args):
|
| 1252 |
+
setup_runtime(args)
|
| 1253 |
+
cfg = PRESETS[args.preset].copy()
|
| 1254 |
+
tie_weights = args.tie_weights
|
| 1255 |
+
print_expansion_info(cfg, tie_weights)
|
| 1256 |
+
|
| 1257 |
+
if not args.fresh:
|
| 1258 |
+
src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
|
| 1259 |
+
prev_cfg = infer_cfg_from_ckpt(src_probe)
|
| 1260 |
+
else:
|
| 1261 |
+
prev_cfg = None
|
| 1262 |
+
if prev_cfg:
|
| 1263 |
+
cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
|
| 1264 |
+
if args.x2 and prev_cfg.get("layers"):
|
| 1265 |
+
cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
|
| 1266 |
+
if args.rank:
|
| 1267 |
+
cfg["rank"] = args.rank
|
| 1268 |
+
if args.x2 and not prev_cfg:
|
| 1269 |
+
cfg["layers"] *= 2
|
| 1270 |
+
|
| 1271 |
+
print(f"Config: {cfg}")
|
| 1272 |
+
core, ar_h, sat_h = _build_models(cfg, tie_weights=tie_weights)
|
| 1273 |
+
|
| 1274 |
+
total_params = _count_enabled_params(core, ar_h, sat_h)
|
| 1275 |
+
print(f"Total parameters: {total_params:,}")
|
| 1276 |
+
if tie_weights:
|
| 1277 |
+
print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}")
|
| 1278 |
+
|
| 1279 |
+
if not args.fresh:
|
| 1280 |
+
src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
|
| 1281 |
+
src = _resolve_ckpt(src)
|
| 1282 |
+
if src:
|
| 1283 |
+
loaded = _safe_load_any(src, core, key="core")
|
| 1284 |
+
_safe_load_any(src, ar_h, key="ar")
|
| 1285 |
+
_safe_load_any(src, sat_h, key="sat")
|
| 1286 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1287 |
+
if loaded:
|
| 1288 |
+
print(f"Warm-start loaded from {src}")
|
| 1289 |
+
|
| 1290 |
+
core, ar_h, sat_h = _maybe_cast_models_for_runtime(core, ar_h, sat_h)
|
| 1291 |
+
core, ar_h, sat_h = _move_models_to_device(core, ar_h, sat_h, tie_weights)
|
| 1292 |
+
|
| 1293 |
+
_phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
|
| 1294 |
+
|
| 1295 |
+
opt = torch.optim.AdamW(
|
| 1296 |
+
[
|
| 1297 |
+
{"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
|
| 1298 |
+
{"params": ar_h.parameters(), "lr": args.lr_head},
|
| 1299 |
+
{"params": sat_h.parameters(), "lr": args.lr_head},
|
| 1300 |
+
]
|
| 1301 |
+
)
|
| 1302 |
+
scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))
|
| 1303 |
+
|
| 1304 |
+
start_step, seen_tok, last_wall, resumed_block = 0, 0, None, None
|
| 1305 |
+
if args.resume and not args.fresh:
|
| 1306 |
+
start_step, seen_tok, last_wall, resumed_block = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
|
| 1307 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1308 |
+
print(f"Resumed from step {start_step}" + (f", block_size={resumed_block}" if resumed_block else ""))
|
| 1309 |
+
|
| 1310 |
+
core, ar_h, sat_h = _maybe_compile_models(args, core, ar_h, sat_h, tie_weights)
|
| 1311 |
+
|
| 1312 |
+
step, seen_tok, last_wall = _train_phase(
|
| 1313 |
+
args,
|
| 1314 |
+
"pretrain",
|
| 1315 |
+
core,
|
| 1316 |
+
ar_h,
|
| 1317 |
+
sat_h,
|
| 1318 |
+
opt,
|
| 1319 |
+
scaler,
|
| 1320 |
+
start_step,
|
| 1321 |
+
seen_tok,
|
| 1322 |
+
last_wall,
|
| 1323 |
+
cfg,
|
| 1324 |
+
args.source,
|
| 1325 |
+
args.steps,
|
| 1326 |
+
(resumed_block if resumed_block and args.auto_grow else None) or args.block or DEFAULT_BLOCK,
|
| 1327 |
+
args.batch_size or DEFAULT_BATCH,
|
| 1328 |
+
chat_cfg={
|
| 1329 |
+
"chat": args.chat,
|
| 1330 |
+
"key": args.chat_messages_key,
|
| 1331 |
+
"gen_prompt": args.sft_add_generation_prompt,
|
| 1332 |
+
"text_field": args.dataset_field_text,
|
| 1333 |
+
},
|
| 1334 |
+
max_ckpts=args.max_ckpts,
|
| 1335 |
+
target_tokens_override=args.target_tokens,
|
| 1336 |
+
tie_weights=tie_weights,
|
| 1337 |
+
)
|
| 1338 |
+
|
| 1339 |
+
if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
|
| 1340 |
+
args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
|
| 1341 |
+
args.after_sft_chat = True
|
| 1342 |
+
if args.after_sft_add_generation_prompt is None:
|
| 1343 |
+
args.after_sft_add_generation_prompt = True
|
| 1344 |
+
if not args.after_sft_block:
|
| 1345 |
+
args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK
|
| 1346 |
+
|
| 1347 |
+
if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
|
| 1348 |
+
print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")
|
| 1349 |
+
_phase_freeze(
|
| 1350 |
+
core,
|
| 1351 |
+
freeze_core=args.after_sft_freeze_core,
|
| 1352 |
+
unfreeze_ln=args.after_sft_unfreeze_ln,
|
| 1353 |
+
train_emb=args.after_sft_train_emb,
|
| 1354 |
+
)
|
| 1355 |
+
opt = torch.optim.AdamW(
|
| 1356 |
+
[
|
| 1357 |
+
{"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
|
| 1358 |
+
{"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
|
| 1359 |
+
{"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
|
| 1360 |
+
]
|
| 1361 |
+
)
|
| 1362 |
+
step, seen_tok, last_wall = _train_phase(
|
| 1363 |
+
args,
|
| 1364 |
+
"sft",
|
| 1365 |
+
core,
|
| 1366 |
+
ar_h,
|
| 1367 |
+
sat_h,
|
| 1368 |
+
opt,
|
| 1369 |
+
scaler,
|
| 1370 |
+
step,
|
| 1371 |
+
seen_tok,
|
| 1372 |
+
last_wall,
|
| 1373 |
+
cfg,
|
| 1374 |
+
args.after_sft_source,
|
| 1375 |
+
args.after_sft_steps,
|
| 1376 |
+
args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
|
| 1377 |
+
args.batch_size or DEFAULT_BATCH,
|
| 1378 |
+
chat_cfg={
|
| 1379 |
+
"chat": args.after_sft_chat,
|
| 1380 |
+
"key": args.after_sft_chat_messages_key,
|
| 1381 |
+
"gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
|
| 1382 |
+
"text_field": args.after_sft_dataset_field_text,
|
| 1383 |
+
},
|
| 1384 |
+
max_ckpts=args.max_ckpts,
|
| 1385 |
+
target_tokens_override=None,
|
| 1386 |
+
tie_weights=tie_weights,
|
| 1387 |
+
streaming=False,
|
| 1388 |
+
)
|
| 1389 |
+
|
| 1390 |
+
save_ckpt(
|
| 1391 |
+
pathlib.Path(args.save_dir) / "final.pt",
|
| 1392 |
+
core,
|
| 1393 |
+
ar_h,
|
| 1394 |
+
sat_h,
|
| 1395 |
+
opt,
|
| 1396 |
+
scaler,
|
| 1397 |
+
meta={
|
| 1398 |
+
"cfg": cfg,
|
| 1399 |
+
"step": step,
|
| 1400 |
+
"seen_tok": seen_tok,
|
| 1401 |
+
"wall_time": time.time(),
|
| 1402 |
+
"tie_weights": tie_weights,
|
| 1403 |
+
"block_size": args.block or DEFAULT_BLOCK,
|
| 1404 |
+
"batch_size": args.batch_size or DEFAULT_BATCH,
|
| 1405 |
+
},
|
| 1406 |
+
)
|
| 1407 |
+
print("🎉 All Training Complete")
|
| 1408 |
+
|
| 1409 |
+
|
| 1410 |
+
# ───────────────────────── Sampling / inference ─────────────────────────
|
| 1411 |
+
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
|
| 1412 |
+
if ids.numel() == 0:
|
| 1413 |
+
return logits
|
| 1414 |
+
hist = ids[0, -n:].long() if n > 0 else ids[0].long()
|
| 1415 |
+
uniq, counts = torch.unique(hist, return_counts=True)
|
| 1416 |
+
if pres_p or freq_p:
|
| 1417 |
+
logits[..., uniq] -= pres_p + freq_p * counts.to(logits.dtype)
|
| 1418 |
+
if rep_p != 1.0:
|
| 1419 |
+
sel = logits[..., uniq]
|
| 1420 |
+
logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
|
| 1421 |
+
return logits
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
def _sample(logits, T, top_k, top_p, min_p, greedy):
|
| 1426 |
+
if greedy:
|
| 1427 |
+
return logits.argmax(-1, keepdim=True)
|
| 1428 |
+
probs = (logits / max(T, 1e-8)).softmax(-1)
|
| 1429 |
+
if top_k:
|
| 1430 |
+
v, i = torch.topk(probs, min(top_k, probs.size(-1)))
|
| 1431 |
+
probs = torch.zeros_like(probs).scatter_(-1, i, v)
|
| 1432 |
+
if top_p < 1.0:
|
| 1433 |
+
s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
|
| 1434 |
+
keep = (torch.cumsum(s_probs, -1) <= top_p).to(probs.dtype)
|
| 1435 |
+
probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * keep)
|
| 1436 |
+
if min_p > 0:
|
| 1437 |
+
probs[probs < min_p] = 0
|
| 1438 |
+
if probs.sum() == 0:
|
| 1439 |
+
return logits.argmax(-1, keepdim=True)
|
| 1440 |
+
return probs.div_(probs.sum()).multinomial(1)
|
| 1441 |
+
|
| 1442 |
+
|
| 1443 |
+
|
| 1444 |
+
def _sample_on_cpu(logits_device, ids_device, args):
|
| 1445 |
+
logits = logits_device.detach().float().cpu()
|
| 1446 |
+
ids = ids_device.detach().cpu()
|
| 1447 |
+
logits = _apply_penalties(
|
| 1448 |
+
logits,
|
| 1449 |
+
ids,
|
| 1450 |
+
args.penalty_last_n,
|
| 1451 |
+
args.repetition_penalty,
|
| 1452 |
+
args.presence_penalty,
|
| 1453 |
+
args.frequency_penalty,
|
| 1454 |
+
)
|
| 1455 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 1456 |
+
return nxt.to(DEV)
|
| 1457 |
+
|
| 1458 |
+
|
| 1459 |
+
@torch.no_grad()
|
| 1460 |
+
def _infer_tt_static(args, core, ar_h, sat_h, ids):
|
| 1461 |
+
prompt_len = ids.size(1)
|
| 1462 |
+
total_len = prompt_len + args.max_new
|
| 1463 |
+
work = torch.full((1, total_len), PAD_ID, dtype=torch.long, device=DEV)
|
| 1464 |
+
work[:, :prompt_len] = ids
|
| 1465 |
+
|
| 1466 |
+
if args.mode == "ar":
|
| 1467 |
+
for step in range(args.max_new):
|
| 1468 |
+
cur_len = prompt_len + step
|
| 1469 |
+
h = core(work, causal_padded_mask(total_len, cur_len))
|
| 1470 |
+
logits = ar_h(h)[:, cur_len - 1]
|
| 1471 |
+
nxt = _sample_on_cpu(logits, work[:, :cur_len], args)
|
| 1472 |
+
work[:, cur_len] = nxt.squeeze(-1)
|
| 1473 |
+
return work
|
| 1474 |
+
|
| 1475 |
+
added = 0
|
| 1476 |
+
while added < args.max_new:
|
| 1477 |
+
cur_len = prompt_len + added
|
| 1478 |
+
h = core(work, sat_padded_mask(total_len, cur_len))
|
| 1479 |
+
start = max(0, cur_len - SAT_BLOCK)
|
| 1480 |
+
h_last = h[:, start:cur_len]
|
| 1481 |
+
if h_last.size(1) < SAT_BLOCK:
|
| 1482 |
+
pad = torch.zeros(
|
| 1483 |
+
h_last.size(0),
|
| 1484 |
+
SAT_BLOCK - h_last.size(1),
|
| 1485 |
+
h_last.size(2),
|
| 1486 |
+
device=h_last.device,
|
| 1487 |
+
dtype=h_last.dtype,
|
| 1488 |
+
)
|
| 1489 |
+
h_last = torch.cat([pad, h_last], dim=1)
|
| 1490 |
+
logits_all, gate = sat_h(h_last)
|
| 1491 |
+
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.float().softmax(-1).cpu().multinomial(1).item() + 1)
|
| 1492 |
+
for i in range(int(stride)):
|
| 1493 |
+
if added >= args.max_new:
|
| 1494 |
+
break
|
| 1495 |
+
logits = logits_all[:, i]
|
| 1496 |
+
nxt = _sample_on_cpu(logits, work[:, :cur_len], args)
|
| 1497 |
+
work[:, cur_len] = nxt.squeeze(-1)
|
| 1498 |
+
cur_len += 1
|
| 1499 |
+
added += 1
|
| 1500 |
+
return work
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
@torch.no_grad()
|
| 1504 |
+
def infer(args):
|
| 1505 |
+
setup_runtime(args)
|
| 1506 |
+
if args.mode == "ar":
|
| 1507 |
+
if args.temperature is None:
|
| 1508 |
+
args.temperature = 0.7
|
| 1509 |
+
if args.top_k is None:
|
| 1510 |
+
args.top_k = 0
|
| 1511 |
+
if args.repetition_penalty is None:
|
| 1512 |
+
args.repetition_penalty = 1.3
|
| 1513 |
+
if args.presence_penalty is None:
|
| 1514 |
+
args.presence_penalty = 0.0
|
| 1515 |
+
if args.frequency_penalty is None:
|
| 1516 |
+
args.frequency_penalty = 0.3
|
| 1517 |
+
if args.penalty_last_n is None:
|
| 1518 |
+
args.penalty_last_n = 128
|
| 1519 |
+
if args.var is None:
|
| 1520 |
+
args.var = False
|
| 1521 |
+
else:
|
| 1522 |
+
if args.temperature is None:
|
| 1523 |
+
args.temperature = 0.5
|
| 1524 |
+
if args.top_k is None:
|
| 1525 |
+
args.top_k = 30
|
| 1526 |
+
if args.repetition_penalty is None:
|
| 1527 |
+
args.repetition_penalty = 2.0
|
| 1528 |
+
if args.presence_penalty is None:
|
| 1529 |
+
args.presence_penalty = 0.6
|
| 1530 |
+
if args.frequency_penalty is None:
|
| 1531 |
+
args.frequency_penalty = 1.0
|
| 1532 |
+
if args.penalty_last_n is None:
|
| 1533 |
+
args.penalty_last_n = 200
|
| 1534 |
+
if args.var is None:
|
| 1535 |
+
args.var = True
|
| 1536 |
+
|
| 1537 |
+
path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
|
| 1538 |
+
sd = torch.load(path, map_location="cpu")
|
| 1539 |
+
cfg = sd["cfg"]
|
| 1540 |
+
tie_weights = sd.get("tie_weights", False)
|
| 1541 |
+
uk_time = get_uk_time()
|
| 1542 |
+
ckpt_name = path.name
|
| 1543 |
+
|
| 1544 |
+
print("┌─────────────────────────────────────────────────┐")
|
| 1545 |
+
print(f"│ INFERENCE @ {uk_time:<35s} │")
|
| 1546 |
+
print("├─────────────────────────────────────────────────┤")
|
| 1547 |
+
print(f"│ Checkpoint: {ckpt_name:<35s} │")
|
| 1548 |
+
print("└─────────────────────────────────────────────────┘")
|
| 1549 |
+
print_expansion_info(cfg, tie_weights)
|
| 1550 |
+
|
| 1551 |
+
core, ar_h, sat_h = _build_models(cfg, tie_weights=tie_weights)
|
| 1552 |
+
core.load_state_dict(sd["core"])
|
| 1553 |
+
ar_h.load_state_dict(sd["ar"])
|
| 1554 |
+
sat_h.load_state_dict(sd["sat"])
|
| 1555 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1556 |
+
|
| 1557 |
+
if RUNTIME.is_tt and args.tt_dtype == "bf16":
|
| 1558 |
+
core = core.to(dtype=torch.bfloat16)
|
| 1559 |
+
ar_h = ar_h.to(dtype=torch.bfloat16)
|
| 1560 |
+
sat_h = sat_h.to(dtype=torch.bfloat16)
|
| 1561 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1562 |
+
elif getattr(args, "fp16", False):
|
| 1563 |
+
core.half()
|
| 1564 |
+
ar_h.half()
|
| 1565 |
+
sat_h.half()
|
| 1566 |
+
retie_weights(core, ar_h, tie_weights)
|
| 1567 |
+
print(f"{Colors.INFO}Using fp16 inference{Colors.RESET}")
|
| 1568 |
+
|
| 1569 |
+
core, ar_h, sat_h = _move_models_to_device(core, ar_h, sat_h, tie_weights)
|
| 1570 |
+
core.eval()
|
| 1571 |
+
ar_h.eval()
|
| 1572 |
+
sat_h.eval()
|
| 1573 |
+
|
| 1574 |
+
total_params = _count_enabled_params(core, ar_h, sat_h)
|
| 1575 |
+
if total_params >= 1_000_000_000:
|
| 1576 |
+
param_str = f"{total_params / 1_000_000_000:.2f}B"
|
| 1577 |
+
elif total_params >= 1_000_000:
|
| 1578 |
+
param_str = f"{total_params / 1_000_000:.2f}M"
|
| 1579 |
+
elif total_params >= 1_000:
|
| 1580 |
+
param_str = f"{total_params / 1_000:.2f}K"
|
| 1581 |
+
else:
|
| 1582 |
+
param_str = f"{total_params}"
|
| 1583 |
+
print(f"Model size: {param_str} parameters ({total_params:,})")
|
| 1584 |
+
|
| 1585 |
+
prompt_tokens = tok.encode(args.prompt)
|
| 1586 |
+
prompt_len = len(prompt_tokens)
|
| 1587 |
+
ids = torch.tensor([prompt_tokens], device=DEV, dtype=torch.long)
|
| 1588 |
+
if ids.size(1) == 0:
|
| 1589 |
+
ids = torch.tensor([[EOS]], device=DEV, dtype=torch.long)
|
| 1590 |
+
prompt_len = 1
|
| 1591 |
+
|
| 1592 |
+
mode_str = args.mode if args.mode == "ar" else f"sat-{'var' if args.var else 'fixed'}"
|
| 1593 |
+
print(f"{Colors.INFO}Generating ({mode_str}) on backend={RUNTIME.backend}...{Colors.RESET}")
|
| 1594 |
+
|
| 1595 |
+
start = time.time()
|
| 1596 |
+
if RUNTIME.is_tt:
|
| 1597 |
+
ids = _infer_tt_static(args, core, ar_h, sat_h, ids)
|
| 1598 |
+
elif args.mode == "ar":
|
| 1599 |
+
h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
|
| 1600 |
+
for _ in range(args.max_new):
|
| 1601 |
+
logits = ar_h(h)[:, -1]
|
| 1602 |
+
logits = _apply_penalties(
|
| 1603 |
+
logits,
|
| 1604 |
+
ids,
|
| 1605 |
+
args.penalty_last_n,
|
| 1606 |
+
args.repetition_penalty,
|
| 1607 |
+
args.presence_penalty,
|
| 1608 |
+
args.frequency_penalty,
|
| 1609 |
+
)
|
| 1610 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 1611 |
+
ids = torch.cat([ids, nxt], 1)
|
| 1612 |
+
h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 1613 |
+
else:
|
| 1614 |
+
cached_len = ids.size(1)
|
| 1615 |
+
h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
|
| 1616 |
+
added = 0
|
| 1617 |
+
while added < args.max_new:
|
| 1618 |
+
logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
|
| 1619 |
+
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
|
| 1620 |
+
new_tokens = []
|
| 1621 |
+
for i in range(int(stride)):
|
| 1622 |
+
logits = logits_all[:, i]
|
| 1623 |
+
logits = _apply_penalties(
|
| 1624 |
+
logits,
|
| 1625 |
+
ids,
|
| 1626 |
+
args.penalty_last_n,
|
| 1627 |
+
args.repetition_penalty,
|
| 1628 |
+
args.presence_penalty,
|
| 1629 |
+
args.frequency_penalty,
|
| 1630 |
+
)
|
| 1631 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 1632 |
+
new_tokens.append(nxt)
|
| 1633 |
+
ids = torch.cat([ids, nxt], 1)
|
| 1634 |
+
added += 1
|
| 1635 |
+
if added >= args.max_new:
|
| 1636 |
+
break
|
| 1637 |
+
if added >= args.max_new:
|
| 1638 |
+
break
|
| 1639 |
+
new_ids = torch.cat(new_tokens, dim=1)
|
| 1640 |
+
mask = sat_mask_cached(new_ids.size(1), cached_len)
|
| 1641 |
+
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 1642 |
+
cached_len = ids.size(1)
|
| 1643 |
+
|
| 1644 |
+
if RUNTIME.is_tt:
|
| 1645 |
+
RUNTIME.sync(wait=True)
|
| 1646 |
+
elapsed = time.time() - start
|
| 1647 |
+
|
| 1648 |
+
all_tokens = ids[0].detach().cpu().tolist()
|
| 1649 |
+
gen_tokens = len(all_tokens) - prompt_len
|
| 1650 |
+
tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0.0
|
| 1651 |
+
prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
|
| 1652 |
+
gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
|
| 1653 |
+
print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}")
|
| 1654 |
+
print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}")
|
| 1655 |
+
|
| 1656 |
+
|
| 1657 |
+
# ───────────────────────── CLI ─────────────────────────
|
| 1658 |
+
def main():
|
| 1659 |
+
ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing (CUDA / Tenstorrent / CPU)")
|
| 1660 |
+
sub = ap.add_subparsers(dest="cmd", required=True)
|
| 1661 |
+
|
| 1662 |
+
tr = sub.add_parser("train")
|
| 1663 |
+
tr.add_argument("--backend", choices=["auto", "cuda", "tt", "cpu"], default="auto")
|
| 1664 |
+
tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x")
|
| 1665 |
+
tr.add_argument("--rank", type=int)
|
| 1666 |
+
tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
|
| 1667 |
+
tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
|
| 1668 |
+
tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
|
| 1669 |
+
tr.add_argument("--target_tokens", type=int)
|
| 1670 |
+
tr.add_argument("--steps", type=int)
|
| 1671 |
+
tr.add_argument("--amp", action="store_true")
|
| 1672 |
+
tr.add_argument("--compile", action="store_true", help="Use torch.compile on CUDA. TT path skips this for stability.")
|
| 1673 |
+
tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
|
| 1674 |
+
tr.add_argument("--save_dir", default=str(CKDIR))
|
| 1675 |
+
tr.add_argument("--resume", type=str)
|
| 1676 |
+
tr.add_argument("--x2", action="store_true")
|
| 1677 |
+
tr.add_argument("--warmstart_from", type=str)
|
| 1678 |
+
tr.add_argument("--fresh", action="store_true")
|
| 1679 |
+
tr.add_argument("--max_ckpts", type=int, default=None)
|
| 1680 |
+
tr.add_argument("--chilla_max_double", action="store_true")
|
| 1681 |
+
tr.add_argument("--tie_weights", action="store_true")
|
| 1682 |
+
tr.add_argument("--ar_only", action="store_true")
|
| 1683 |
+
tr.add_argument("--freeze_core", action="store_true")
|
| 1684 |
+
tr.add_argument("--unfreeze_ln", action="store_true")
|
| 1685 |
+
tr.add_argument("--train_emb", action="store_true")
|
| 1686 |
+
tr.add_argument("--lr_core", type=float, default=LR_CORE)
|
| 1687 |
+
tr.add_argument("--lr_head", type=float, default=LR_HEAD)
|
| 1688 |
+
tr.add_argument("--label_smoothing", type=float, default=0.1)
|
| 1689 |
+
tr.add_argument("--chat", action="store_true")
|
| 1690 |
+
tr.add_argument("--chat_messages_key", default="messages")
|
| 1691 |
+
tr.add_argument("--dataset_field_text", default="text")
|
| 1692 |
+
tr.add_argument("--sft_add_generation_prompt", action="store_true")
|
| 1693 |
+
tr.add_argument("--auto_grow", action="store_true")
|
| 1694 |
+
tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
|
| 1695 |
+
tr.add_argument("--grow_every_steps", type=int, default=50000)
|
| 1696 |
+
tr.add_argument("--after_sft_source", default="")
|
| 1697 |
+
tr.add_argument("--after_sft_steps", type=int, default=0)
|
| 1698 |
+
tr.add_argument("--after_sft_chat", action="store_true")
|
| 1699 |
+
tr.add_argument("--after_sft_chat_messages_key", default="messages")
|
| 1700 |
+
tr.add_argument("--after_sft_dataset_field_text", default="text")
|
| 1701 |
+
tr.add_argument("--after_sft_add_generation_prompt", type=bool, default=None)
|
| 1702 |
+
tr.add_argument("--after_sft_block", type=int, default=0)
|
| 1703 |
+
tr.add_argument("--after_sft_freeze_core", action="store_true")
|
| 1704 |
+
tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
|
| 1705 |
+
tr.add_argument("--after_sft_train_emb", action="store_true")
|
| 1706 |
+
tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
|
| 1707 |
+
tr.add_argument("--after_sft_lr_head", type=float, default=0.0)
|
| 1708 |
+
tr.add_argument("--tt_dtype", choices=["fp32", "bf16"], default="bf16")
|
| 1709 |
+
tr.add_argument("--tt_bfp8", action="store_true")
|
| 1710 |
+
tr.add_argument("--tt_weight_bfp8", action="store_true")
|
| 1711 |
+
tr.add_argument("--tt_optimization_level", type=int, default=1)
|
| 1712 |
+
tr.add_argument("--tt_trace", action="store_true")
|
| 1713 |
+
tr.add_argument("--tt_trace_region_size", type=int, default=10_000_000)
|
| 1714 |
+
tr.add_argument("--tt_spmd", action="store_true", help="Experimental: shard batch across visible TT chips.")
|
| 1715 |
+
|
| 1716 |
+
inf = sub.add_parser("infer")
|
| 1717 |
+
inf.add_argument("--backend", choices=["auto", "cuda", "tt", "cpu"], default="auto")
|
| 1718 |
+
inf.add_argument("--mode", choices=["ar", "sat"], required=True)
|
| 1719 |
+
inf.add_argument("--ckpt", required=True)
|
| 1720 |
+
inf.add_argument("--prompt", required=True)
|
| 1721 |
+
inf.add_argument("--max_new", type=int, default=120)
|
| 1722 |
+
inf.add_argument("--temperature", type=float, default=None)
|
| 1723 |
+
inf.add_argument("--greedy", action="store_true")
|
| 1724 |
+
inf.add_argument("--top_k", type=int, default=None)
|
| 1725 |
+
inf.add_argument("--top_p", type=float, default=0.9)
|
| 1726 |
+
inf.add_argument("--min_p", type=float, default=0.0)
|
| 1727 |
+
inf.add_argument("--repetition_penalty", type=float, default=None)
|
| 1728 |
+
inf.add_argument("--presence_penalty", type=float, default=None)
|
| 1729 |
+
inf.add_argument("--frequency_penalty", type=float, default=None)
|
| 1730 |
+
inf.add_argument("--penalty_last_n", type=int, default=None)
|
| 1731 |
+
inf.add_argument("--var", action="store_true", default=None)
|
| 1732 |
+
inf.add_argument("--no-var", dest="var", action="store_false")
|
| 1733 |
+
inf.add_argument("--fp16", action="store_true", help="Use fp16 inference on CUDA/CPU-like backends.")
|
| 1734 |
+
inf.add_argument("--tt_dtype", choices=["fp32", "bf16"], default="bf16")
|
| 1735 |
+
inf.add_argument("--tt_bfp8", action="store_true")
|
| 1736 |
+
inf.add_argument("--tt_weight_bfp8", action="store_true")
|
| 1737 |
+
inf.add_argument("--tt_optimization_level", type=int, default=1)
|
| 1738 |
+
inf.add_argument("--tt_trace", action="store_true")
|
| 1739 |
+
inf.add_argument("--tt_trace_region_size", type=int, default=10_000_000)
|
| 1740 |
+
inf.add_argument("--tt_spmd", action="store_true")
|
| 1741 |
+
|
| 1742 |
+
sub.add_parser("status")
|
| 1743 |
+
|
| 1744 |
+
args = ap.parse_args()
|
| 1745 |
+
if args.cmd == "train":
|
| 1746 |
+
train(args)
|
| 1747 |
+
elif args.cmd == "status":
|
| 1748 |
+
show_status()
|
| 1749 |
+
else:
|
| 1750 |
+
infer(args)
|
| 1751 |
+
|
| 1752 |
+
|
| 1753 |
+
if __name__ == "__main__":
|
| 1754 |
+
main()
|