Upload persistent_absorber.py with huggingface_hub
Browse files- persistent_absorber.py +1934 -0
persistent_absorber.py
ADDED
|
@@ -0,0 +1,1934 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Claudia Persistent Absorber v2
|
| 3 |
+
==============================
|
| 4 |
+
Combines the 3 best proven techniques into one system:
|
| 5 |
+
|
| 6 |
+
1. SELF-QUIZ PAIRS (21% β 74% recall β the single biggest lever)
|
| 7 |
+
2. PERSISTENT LoRA rank 128 (89% across 25 convos, no merge-between-rounds tax)
|
| 8 |
+
3. DUAL-LR EXPERT FFN (attention=6e-5, FFN=3e-4 β facts into MoE experts)
|
| 9 |
+
|
| 10 |
+
Architecture:
|
| 11 |
+
- Load base Omni β thinker to GPU, rest to CPU
|
| 12 |
+
- First run: apply Claudia v6 adapter β merge β apply FFN patch
|
| 13 |
+
- Resume: load from checkpoint (already has personality + memories)
|
| 14 |
+
- Apply ONE persistent LoRA (r=128, alpha=256, attention q/k/v/o)
|
| 15 |
+
- Chat loop: generate β quiz β train (LoRA + expert FFN) β repeat
|
| 16 |
+
- On save/quit: merge_and_unload β save full checkpoint
|
| 17 |
+
- Next session loads from checkpoint β memories are permanent
|
| 18 |
+
|
| 19 |
+
Instance: Vast.ai 33093662 (A100 80GB, Sweden)
|
| 20 |
+
SSH: ssh -p 13662 root@ssh1.vast.ai
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import gc
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import re
|
| 28 |
+
import sys
|
| 29 |
+
import threading
|
| 30 |
+
import time
|
| 31 |
+
import torch
|
| 32 |
+
from collections import Counter
|
| 33 |
+
from datetime import datetime
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
# CONFIG
|
| 39 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
|
| 41 |
+
# LoRA config (from persistent LoRA test β proven for 25+ conversations)
|
| 42 |
+
LORA_RANK = 128
|
| 43 |
+
LORA_ALPHA = 256
|
| 44 |
+
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
| 45 |
+
|
| 46 |
+
# Dual-LR (from engram micro_trainer β proven 5/5 fact retention)
|
| 47 |
+
ATTENTION_LR = 6e-5
|
| 48 |
+
EXPERT_FFN_LR = 3e-4 # 5x multiplier β facts absorb fast, personality stays
|
| 49 |
+
EXPERT_FFN_LAYERS = [20, 24, 28] # Proven optimal in v5 experiment
|
| 50 |
+
|
| 51 |
+
# Training per absorption cycle
|
| 52 |
+
TRAIN_EPOCHS = 2 # Reduced from 4 β prevents overfitting with focused training
|
| 53 |
+
MAX_SEQ_LENGTH = 2048
|
| 54 |
+
GRADIENT_CLIP = 1.0
|
| 55 |
+
|
| 56 |
+
# Generation
|
| 57 |
+
GEN_TEMPERATURE = 0.7
|
| 58 |
+
GEN_TOP_P = 0.9
|
| 59 |
+
GEN_TOP_K = 50
|
| 60 |
+
GEN_MAX_TOKENS = 512
|
| 61 |
+
GEN_REP_PENALTY = 1.1
|
| 62 |
+
|
| 63 |
+
# Absorb after every N exchanges (1 = every turn)
|
| 64 |
+
ABSORB_EVERY = 1
|
| 65 |
+
|
| 66 |
+
# Checkpoint interval (auto-save every N absorptions)
|
| 67 |
+
CHECKPOINT_EVERY = 10
|
| 68 |
+
|
| 69 |
+
# Self-verification (v11 β clean contrastive + sister pairs, no "NOT X" leak)
|
| 70 |
+
VERIFY_EVERY = 3 # More frequent checks catch drift earlier
|
| 71 |
+
VERIFY_SAMPLE = 10 # Back to v9's value β wider sampling destabilized in v10
|
| 72 |
+
|
| 73 |
+
# Cascade Distillation (Nemotron-Cascade-2 paper β on-policy distillation)
|
| 74 |
+
# When facts from previous sessions regress, distill from the teacher checkpoint
|
| 75 |
+
# that knew them best. Recovers regressions without losing new knowledge.
|
| 76 |
+
DISTILL_ALPHA = 0.5 # CE vs KL loss balance (0.5 = equal weight)
|
| 77 |
+
DISTILL_TEMPERATURE = 2.0 # Softens distributions for better KL gradients
|
| 78 |
+
DISTILL_TOP_K = 32 # Top-K logits to cache per token position
|
| 79 |
+
CONSOLIDATION_EPOCHS = 2 # Distillation epochs at session start (1β2 for stronger lock-in)
|
| 80 |
+
MAX_TEACHER_CACHE = 200 # Cap quiz pairs to cache (oldest trimmed)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
# QUALITY GATE (from engram micro_trainer β reject degenerate text)
|
| 85 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
|
| 87 |
+
def check_response_quality(text):
|
| 88 |
+
"""Reject degenerate text before training on it."""
|
| 89 |
+
if not text or len(text) < 5:
|
| 90 |
+
return False
|
| 91 |
+
words = text.lower().split()
|
| 92 |
+
if len(words) < 3:
|
| 93 |
+
return False
|
| 94 |
+
# Low unique word ratio = repetitive garbage
|
| 95 |
+
if len(set(words)) / len(words) < 0.3:
|
| 96 |
+
return False
|
| 97 |
+
# Repeated consecutive words
|
| 98 |
+
if sum(1 for i in range(len(words) - 1) if words[i] == words[i + 1]) >= 3:
|
| 99 |
+
return False
|
| 100 |
+
# Repeated bigrams
|
| 101 |
+
if len(words) >= 10:
|
| 102 |
+
bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
|
| 103 |
+
if Counter(bigrams).most_common(1)[0][1] >= 5:
|
| 104 |
+
return False
|
| 105 |
+
# Fused words (missing spaces)
|
| 106 |
+
if sum(1 for w in words if len(w) > 30) >= 2:
|
| 107 |
+
return False
|
| 108 |
+
# Average word length spike
|
| 109 |
+
if sum(len(w) for w in words) / len(words) > 12:
|
| 110 |
+
return False
|
| 111 |
+
return True
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
+
# MODEL MANAGER
|
| 116 |
+
# βοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
+
|
| 118 |
+
class ModelManager:
|
| 119 |
+
def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
|
| 120 |
+
checkpoint_path=None):
|
| 121 |
+
self.model_path = model_path
|
| 122 |
+
self.adapter_path = adapter_path
|
| 123 |
+
self.ffn_patch_path = ffn_patch_path
|
| 124 |
+
self.checkpoint_path = checkpoint_path # Resume from here if set
|
| 125 |
+
|
| 126 |
+
self.thinker = None
|
| 127 |
+
self.tokenizer = None
|
| 128 |
+
self.stop_ids = None
|
| 129 |
+
self.peft_model = None # The persistent LoRA β stays active all session
|
| 130 |
+
self._lock = threading.Lock()
|
| 131 |
+
|
| 132 |
+
def load(self):
|
| 133 |
+
from transformers import AutoTokenizer
|
| 134 |
+
|
| 135 |
+
# ββ Step 1: Load tokenizer ββ
|
| 136 |
+
tok_source = self.checkpoint_path or self.model_path
|
| 137 |
+
print(f"[1/5] Loading tokenizer from {tok_source}...")
|
| 138 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 139 |
+
tok_source, trust_remote_code=True
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# ββ Step 2: Load model ββ
|
| 143 |
+
if self.checkpoint_path:
|
| 144 |
+
# RESUME: checkpoint contains only thinker weights β load thinker directly
|
| 145 |
+
print(f"[2/5] Loading thinker from checkpoint {self.checkpoint_path}...")
|
| 146 |
+
try:
|
| 147 |
+
from transformers import Qwen3OmniMoeThinkerForConditionalGeneration as ThinkerClass
|
| 148 |
+
except ImportError:
|
| 149 |
+
from transformers import AutoModelForCausalLM as ThinkerClass
|
| 150 |
+
self.thinker = ThinkerClass.from_pretrained(
|
| 151 |
+
self.checkpoint_path,
|
| 152 |
+
device_map="auto",
|
| 153 |
+
torch_dtype=torch.bfloat16,
|
| 154 |
+
trust_remote_code=True,
|
| 155 |
+
)
|
| 156 |
+
vram = torch.cuda.memory_allocated() / 1e9
|
| 157 |
+
print(f" VRAM after load: {vram:.1f} GB")
|
| 158 |
+
else:
|
| 159 |
+
# FIRST RUN: load full model, extract thinker, offload rest
|
| 160 |
+
print(f"[2/5] Loading full model from {self.model_path}...")
|
| 161 |
+
try:
|
| 162 |
+
from transformers import Qwen3OmniMoeForConditionalGeneration as ModelClass
|
| 163 |
+
except ImportError:
|
| 164 |
+
from transformers import AutoModel as ModelClass
|
| 165 |
+
full_model = ModelClass.from_pretrained(
|
| 166 |
+
self.model_path,
|
| 167 |
+
device_map="auto",
|
| 168 |
+
torch_dtype=torch.bfloat16,
|
| 169 |
+
trust_remote_code=True,
|
| 170 |
+
)
|
| 171 |
+
vram = torch.cuda.memory_allocated() / 1e9
|
| 172 |
+
print(f" VRAM after load: {vram:.1f} GB")
|
| 173 |
+
|
| 174 |
+
# Extract thinker, offload rest
|
| 175 |
+
self.thinker = full_model.thinker
|
| 176 |
+
for name, module in full_model.named_children():
|
| 177 |
+
if name != "thinker":
|
| 178 |
+
try:
|
| 179 |
+
module.cpu()
|
| 180 |
+
except (NotImplementedError, RuntimeError):
|
| 181 |
+
pass
|
| 182 |
+
del full_model
|
| 183 |
+
torch.cuda.empty_cache()
|
| 184 |
+
vram = torch.cuda.memory_allocated() / 1e9
|
| 185 |
+
print(f" VRAM after cleanup: {vram:.1f} GB")
|
| 186 |
+
|
| 187 |
+
# ββ Step 3: Apply personality if first run ββ
|
| 188 |
+
if self.checkpoint_path:
|
| 189 |
+
print(f"[3/5] Resuming from checkpoint β personality already in weights.")
|
| 190 |
+
else:
|
| 191 |
+
if self.adapter_path:
|
| 192 |
+
print(f"[3/5] Merging Claudia v6 personality adapter...")
|
| 193 |
+
from peft import PeftModel
|
| 194 |
+
self.thinker = PeftModel.from_pretrained(
|
| 195 |
+
self.thinker, self.adapter_path
|
| 196 |
+
)
|
| 197 |
+
self.thinker = self.thinker.merge_and_unload()
|
| 198 |
+
print(f" Personality merged into base weights.")
|
| 199 |
+
|
| 200 |
+
if self.ffn_patch_path and os.path.exists(self.ffn_patch_path):
|
| 201 |
+
print(f" Applying FFN patch from {self.ffn_patch_path}...")
|
| 202 |
+
ffn = torch.load(
|
| 203 |
+
self.ffn_patch_path, map_location="cpu", weights_only=True
|
| 204 |
+
)
|
| 205 |
+
for key, tensor in ffn.items():
|
| 206 |
+
match = re.search(r"layers\.(\d+)", key)
|
| 207 |
+
if not match:
|
| 208 |
+
continue
|
| 209 |
+
layer_idx = int(match.group(1))
|
| 210 |
+
experts = self.thinker.model.layers[layer_idx].mlp.experts
|
| 211 |
+
if hasattr(experts, '__len__'):
|
| 212 |
+
for i in range(tensor.shape[0]):
|
| 213 |
+
experts[i].down_proj.weight.data.copy_(
|
| 214 |
+
tensor[i].to(
|
| 215 |
+
experts[i].down_proj.weight.device,
|
| 216 |
+
experts[i].down_proj.weight.dtype,
|
| 217 |
+
)
|
| 218 |
+
)
|
| 219 |
+
elif hasattr(experts, 'down_proj'):
|
| 220 |
+
experts.down_proj.data.copy_(
|
| 221 |
+
tensor.to(experts.down_proj.device, experts.down_proj.dtype)
|
| 222 |
+
)
|
| 223 |
+
del ffn
|
| 224 |
+
torch.cuda.empty_cache()
|
| 225 |
+
print(f" FFN patch applied.")
|
| 226 |
+
|
| 227 |
+
self.thinker.eval()
|
| 228 |
+
|
| 229 |
+
# Stop tokens
|
| 230 |
+
self.stop_ids = []
|
| 231 |
+
for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
|
| 232 |
+
ids = self.tokenizer.encode(tok, add_special_tokens=False)
|
| 233 |
+
if ids:
|
| 234 |
+
self.stop_ids.extend(ids)
|
| 235 |
+
if self.tokenizer.eos_token_id:
|
| 236 |
+
self.stop_ids.append(self.tokenizer.eos_token_id)
|
| 237 |
+
|
| 238 |
+
# ββ Step 5: Apply persistent LoRA ββ
|
| 239 |
+
print(f"[4/5] Applying persistent LoRA (r={LORA_RANK}, alpha={LORA_ALPHA})...")
|
| 240 |
+
self._apply_persistent_lora()
|
| 241 |
+
|
| 242 |
+
vram = torch.cuda.memory_allocated() / 1e9
|
| 243 |
+
print(f"[5/5] Ready. VRAM: {vram:.1f} GB\n")
|
| 244 |
+
|
| 245 |
+
def _apply_persistent_lora(self):
|
| 246 |
+
"""Apply the persistent absorption LoRA. Called once at load, and after merge."""
|
| 247 |
+
from peft import LoraConfig, get_peft_model
|
| 248 |
+
|
| 249 |
+
lora_config = LoraConfig(
|
| 250 |
+
r=LORA_RANK,
|
| 251 |
+
lora_alpha=LORA_ALPHA,
|
| 252 |
+
target_modules=LORA_TARGETS,
|
| 253 |
+
lora_dropout=0.0,
|
| 254 |
+
bias="none",
|
| 255 |
+
task_type="CAUSAL_LM",
|
| 256 |
+
)
|
| 257 |
+
self.peft_model = get_peft_model(self.thinker, lora_config)
|
| 258 |
+
self.peft_model.eval()
|
| 259 |
+
|
| 260 |
+
trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
|
| 261 |
+
total = sum(p.numel() for p in self.peft_model.parameters())
|
| 262 |
+
print(f" LoRA: {trainable / 1e6:.1f}M trainable / {total / 1e6:.0f}M total")
|
| 263 |
+
|
| 264 |
+
def generate(self, messages, max_new_tokens=None):
|
| 265 |
+
"""Generate response. Thread-safe."""
|
| 266 |
+
with self._lock:
|
| 267 |
+
model = self.peft_model or self.thinker
|
| 268 |
+
model.eval()
|
| 269 |
+
|
| 270 |
+
text = self.tokenizer.apply_chat_template(
|
| 271 |
+
messages, tokenize=False, add_generation_prompt=True,
|
| 272 |
+
enable_thinking=False,
|
| 273 |
+
)
|
| 274 |
+
inputs = self.tokenizer(
|
| 275 |
+
text, return_tensors="pt", truncation=True, max_length=8192
|
| 276 |
+
).to("cuda")
|
| 277 |
+
input_len = inputs["input_ids"].shape[1]
|
| 278 |
+
|
| 279 |
+
with torch.inference_mode():
|
| 280 |
+
out = model.generate(
|
| 281 |
+
**inputs,
|
| 282 |
+
max_new_tokens=max_new_tokens or GEN_MAX_TOKENS,
|
| 283 |
+
temperature=GEN_TEMPERATURE,
|
| 284 |
+
top_p=GEN_TOP_P,
|
| 285 |
+
top_k=GEN_TOP_K,
|
| 286 |
+
do_sample=True,
|
| 287 |
+
repetition_penalty=GEN_REP_PENALTY,
|
| 288 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 289 |
+
eos_token_id=self.stop_ids,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
resp = self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
|
| 293 |
+
# Strip thinking tags
|
| 294 |
+
resp = re.sub(r"<think>.*?</think>", "", resp, flags=re.DOTALL)
|
| 295 |
+
resp = re.sub(r"</?think>", "", resp)
|
| 296 |
+
return resp.strip()
|
| 297 |
+
|
| 298 |
+
def absorb(self, training_data):
|
| 299 |
+
"""
|
| 300 |
+
Train the persistent LoRA + expert FFN on accumulated data.
|
| 301 |
+
Uses dual-LR: attention at ATTENTION_LR, expert FFN at EXPERT_FFN_LR.
|
| 302 |
+
Thread-safe.
|
| 303 |
+
"""
|
| 304 |
+
with self._lock:
|
| 305 |
+
return self._absorb_impl(training_data)
|
| 306 |
+
|
| 307 |
+
def _absorb_impl(self, training_data):
|
| 308 |
+
"""Internal absorption. Must hold _lock."""
|
| 309 |
+
if not training_data:
|
| 310 |
+
return None
|
| 311 |
+
|
| 312 |
+
model = self.peft_model or self.thinker
|
| 313 |
+
tokenizer = self.tokenizer
|
| 314 |
+
|
| 315 |
+
# ββ Tokenize all examples ββ
|
| 316 |
+
texts = []
|
| 317 |
+
for item in training_data:
|
| 318 |
+
if isinstance(item, dict) and "messages" in item:
|
| 319 |
+
msgs = item["messages"]
|
| 320 |
+
elif isinstance(item, dict) and "prompt" in item:
|
| 321 |
+
msgs = item["prompt"] + item.get("completion", [])
|
| 322 |
+
elif isinstance(item, list):
|
| 323 |
+
msgs = item
|
| 324 |
+
else:
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
text = tokenizer.apply_chat_template(
|
| 328 |
+
msgs, tokenize=False, enable_thinking=False
|
| 329 |
+
)
|
| 330 |
+
texts.append(text)
|
| 331 |
+
|
| 332 |
+
if not texts:
|
| 333 |
+
return None
|
| 334 |
+
|
| 335 |
+
enc = tokenizer(
|
| 336 |
+
texts,
|
| 337 |
+
truncation=True,
|
| 338 |
+
max_length=MAX_SEQ_LENGTH,
|
| 339 |
+
padding=True,
|
| 340 |
+
return_tensors="pt",
|
| 341 |
+
)
|
| 342 |
+
input_ids = enc["input_ids"].to("cuda")
|
| 343 |
+
attention_mask = enc["attention_mask"].to("cuda")
|
| 344 |
+
labels = input_ids.clone()
|
| 345 |
+
labels[attention_mask == 0] = -100
|
| 346 |
+
|
| 347 |
+
# ββ Collect LoRA attention params ββ
|
| 348 |
+
model.train()
|
| 349 |
+
attn_params = [p for p in model.parameters() if p.requires_grad]
|
| 350 |
+
|
| 351 |
+
# ββ Unfreeze expert FFN ββ
|
| 352 |
+
expert_params = []
|
| 353 |
+
base = model.base_model.model if hasattr(model, "base_model") else model
|
| 354 |
+
for layer_idx in EXPERT_FFN_LAYERS:
|
| 355 |
+
experts = base.model.layers[layer_idx].mlp.experts
|
| 356 |
+
if hasattr(experts, '__len__'):
|
| 357 |
+
for i in range(len(experts)):
|
| 358 |
+
p = experts[i].down_proj.weight
|
| 359 |
+
p.requires_grad_(True)
|
| 360 |
+
expert_params.append(p)
|
| 361 |
+
elif hasattr(experts, 'down_proj'):
|
| 362 |
+
p = experts.down_proj
|
| 363 |
+
if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
|
| 364 |
+
p.requires_grad_(True)
|
| 365 |
+
expert_params.append(p)
|
| 366 |
+
|
| 367 |
+
# ββ Dual-LR optimizer ββ
|
| 368 |
+
param_groups = []
|
| 369 |
+
if attn_params:
|
| 370 |
+
param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
|
| 371 |
+
if expert_params:
|
| 372 |
+
param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
|
| 373 |
+
|
| 374 |
+
if not param_groups:
|
| 375 |
+
model.eval()
|
| 376 |
+
return None
|
| 377 |
+
|
| 378 |
+
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
|
| 379 |
+
all_params = attn_params + expert_params
|
| 380 |
+
|
| 381 |
+
# ββ Training loop ββ
|
| 382 |
+
n = input_ids.shape[0]
|
| 383 |
+
total_steps = n * TRAIN_EPOCHS
|
| 384 |
+
total_loss = 0.0
|
| 385 |
+
|
| 386 |
+
for epoch in range(TRAIN_EPOCHS):
|
| 387 |
+
# Shuffle order each epoch
|
| 388 |
+
indices = torch.randperm(n)
|
| 389 |
+
for i in range(n):
|
| 390 |
+
idx = indices[i].item()
|
| 391 |
+
out = model(
|
| 392 |
+
input_ids=input_ids[idx:idx + 1],
|
| 393 |
+
attention_mask=attention_mask[idx:idx + 1],
|
| 394 |
+
labels=labels[idx:idx + 1],
|
| 395 |
+
)
|
| 396 |
+
out.loss.backward()
|
| 397 |
+
torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
|
| 398 |
+
optimizer.step()
|
| 399 |
+
optimizer.zero_grad()
|
| 400 |
+
total_loss += out.loss.item()
|
| 401 |
+
|
| 402 |
+
# ββ Re-freeze expert FFN ββ
|
| 403 |
+
for layer_idx in EXPERT_FFN_LAYERS:
|
| 404 |
+
experts = base.model.layers[layer_idx].mlp.experts
|
| 405 |
+
if hasattr(experts, '__len__'):
|
| 406 |
+
for i in range(len(experts)):
|
| 407 |
+
experts[i].down_proj.weight.requires_grad_(False)
|
| 408 |
+
elif hasattr(experts, 'down_proj'):
|
| 409 |
+
p = experts.down_proj
|
| 410 |
+
if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
|
| 411 |
+
p.requires_grad_(False)
|
| 412 |
+
|
| 413 |
+
model.eval()
|
| 414 |
+
del optimizer
|
| 415 |
+
torch.cuda.empty_cache()
|
| 416 |
+
|
| 417 |
+
avg_loss = total_loss / total_steps if total_steps > 0 else 0
|
| 418 |
+
return avg_loss
|
| 419 |
+
|
| 420 |
+
@staticmethod
|
| 421 |
+
def cluster_by_entity(training_data, entity_names):
|
| 422 |
+
"""Group training data by primary entity mentioned.
|
| 423 |
+
|
| 424 |
+
Instead of interleaving facts about different people (which causes
|
| 425 |
+
cross-contamination during gradient updates), this groups all data
|
| 426 |
+
about one entity together. The model learns all of Jordan's facts
|
| 427 |
+
before moving to Elena's.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
training_data: List of training items
|
| 431 |
+
entity_names: Set/list of known entity names
|
| 432 |
+
|
| 433 |
+
Returns: List of training items, reordered so each entity's items
|
| 434 |
+
are contiguous. Items mentioning no entity come last.
|
| 435 |
+
"""
|
| 436 |
+
clusters = {name: [] for name in entity_names}
|
| 437 |
+
unclustered = []
|
| 438 |
+
|
| 439 |
+
for item in training_data:
|
| 440 |
+
# Extract text from the item
|
| 441 |
+
if isinstance(item, dict) and "messages" in item:
|
| 442 |
+
text = " ".join(m.get("content", "") for m in item["messages"]).lower()
|
| 443 |
+
else:
|
| 444 |
+
unclustered.append(item)
|
| 445 |
+
continue
|
| 446 |
+
|
| 447 |
+
# Assign to the first entity mentioned (primary entity)
|
| 448 |
+
assigned = False
|
| 449 |
+
for name in entity_names:
|
| 450 |
+
if name.lower() in text:
|
| 451 |
+
clusters[name].append(item)
|
| 452 |
+
assigned = True
|
| 453 |
+
break
|
| 454 |
+
if not assigned:
|
| 455 |
+
unclustered.append(item)
|
| 456 |
+
|
| 457 |
+
# Build ordered list: all of entity A's facts, then B's, then C's...
|
| 458 |
+
ordered = []
|
| 459 |
+
for name in entity_names:
|
| 460 |
+
ordered.extend(clusters[name])
|
| 461 |
+
ordered.extend(unclustered)
|
| 462 |
+
return ordered
|
| 463 |
+
|
| 464 |
+
def absorb_two_phase(self, positive_data, contrastive_data, verify_fn=None):
|
| 465 |
+
"""Two-phase absorption: facts first, then targeted contrastive correction.
|
| 466 |
+
|
| 467 |
+
Phase 1: Train on positive facts (exchanges, entity summaries, template quizzes).
|
| 468 |
+
This builds the core factual representations.
|
| 469 |
+
Phase 2: Quick verification on known entities, then train ONLY contrastive
|
| 470 |
+
quizzes for entities that failed verification. This avoids unnecessary
|
| 471 |
+
negative gradients on entities the model already distinguishes correctly.
|
| 472 |
+
|
| 473 |
+
Args:
|
| 474 |
+
positive_data: List of training items (exchanges, summaries, direct quizzes)
|
| 475 |
+
contrastive_data: List of contrastive quiz items ("Is X a [Y's job]? No...")
|
| 476 |
+
verify_fn: Optional callable(model_manager) -> set of confused_entity_names.
|
| 477 |
+
If None, all contrastive data is used in Phase 2.
|
| 478 |
+
|
| 479 |
+
Returns: (phase1_loss, phase2_loss) tuple
|
| 480 |
+
"""
|
| 481 |
+
with self._lock:
|
| 482 |
+
# Phase 1: Positive facts
|
| 483 |
+
loss1 = None
|
| 484 |
+
if positive_data:
|
| 485 |
+
loss1 = self._absorb_impl(positive_data)
|
| 486 |
+
|
| 487 |
+
# Phase 2: Targeted contrastive correction
|
| 488 |
+
loss2 = None
|
| 489 |
+
if contrastive_data:
|
| 490 |
+
if verify_fn:
|
| 491 |
+
# Only train contrastive pairs for confused entities
|
| 492 |
+
confused = verify_fn(self)
|
| 493 |
+
if confused:
|
| 494 |
+
targeted = []
|
| 495 |
+
for item in contrastive_data:
|
| 496 |
+
q = item["messages"][0]["content"].lower()
|
| 497 |
+
# Check if any confused entity name appears in the question
|
| 498 |
+
if any(name.lower() in q for name in confused):
|
| 499 |
+
targeted.append(item)
|
| 500 |
+
if targeted:
|
| 501 |
+
loss2 = self._absorb_impl(targeted)
|
| 502 |
+
# If no entities confused, skip Phase 2 entirely
|
| 503 |
+
else:
|
| 504 |
+
loss2 = self._absorb_impl(contrastive_data)
|
| 505 |
+
|
| 506 |
+
return loss1, loss2
|
| 507 |
+
|
| 508 |
+
def merge_and_save(self, path):
|
| 509 |
+
"""Merge persistent LoRA into base, save checkpoint, re-apply fresh LoRA."""
|
| 510 |
+
with self._lock:
|
| 511 |
+
if self.peft_model:
|
| 512 |
+
print(f" Merging persistent LoRA into base weights...")
|
| 513 |
+
self.thinker = self.peft_model.merge_and_unload()
|
| 514 |
+
self.thinker.eval()
|
| 515 |
+
self.peft_model = None
|
| 516 |
+
|
| 517 |
+
os.makedirs(path, exist_ok=True)
|
| 518 |
+
print(f" Saving checkpoint to {path}...")
|
| 519 |
+
self.thinker.save_pretrained(path)
|
| 520 |
+
self.tokenizer.save_pretrained(path)
|
| 521 |
+
print(f" Checkpoint saved ({path})")
|
| 522 |
+
|
| 523 |
+
# Re-apply fresh LoRA for continued learning
|
| 524 |
+
self._apply_persistent_lora()
|
| 525 |
+
print(f" Fresh LoRA applied β ready to continue.")
|
| 526 |
+
|
| 527 |
+
def cache_teacher_logits(self, quiz_pairs, top_k=DISTILL_TOP_K):
|
| 528 |
+
"""Cache teacher's top-K output logits for quiz pairs.
|
| 529 |
+
Called at session end when model is at its best state for these facts.
|
| 530 |
+
Next session loads this cache for consolidation distillation."""
|
| 531 |
+
with self._lock:
|
| 532 |
+
model = self.peft_model or self.thinker
|
| 533 |
+
model.eval()
|
| 534 |
+
cache = []
|
| 535 |
+
|
| 536 |
+
# Cap to most recent quiz pairs
|
| 537 |
+
pairs = quiz_pairs[-MAX_TEACHER_CACHE:]
|
| 538 |
+
|
| 539 |
+
for pair in pairs:
|
| 540 |
+
msgs = pair["messages"]
|
| 541 |
+
text = self.tokenizer.apply_chat_template(
|
| 542 |
+
msgs, tokenize=False, enable_thinking=False
|
| 543 |
+
)
|
| 544 |
+
enc = self.tokenizer(
|
| 545 |
+
text, return_tensors="pt", truncation=True,
|
| 546 |
+
max_length=MAX_SEQ_LENGTH
|
| 547 |
+
)
|
| 548 |
+
input_ids = enc["input_ids"].to("cuda")
|
| 549 |
+
attention_mask = enc["attention_mask"].to("cuda")
|
| 550 |
+
|
| 551 |
+
with torch.inference_mode():
|
| 552 |
+
out = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 553 |
+
logits = out.logits[0] # [seq_len, vocab_size]
|
| 554 |
+
|
| 555 |
+
# Keep only top-K logits per position (massive memory savings)
|
| 556 |
+
top_vals, top_idx = logits.topk(top_k, dim=-1)
|
| 557 |
+
|
| 558 |
+
cache.append({
|
| 559 |
+
"pair": pair,
|
| 560 |
+
"input_ids": input_ids.cpu(),
|
| 561 |
+
"attention_mask": attention_mask.cpu(),
|
| 562 |
+
"teacher_logits": top_vals.half().cpu(),
|
| 563 |
+
"teacher_indices": top_idx.cpu(),
|
| 564 |
+
})
|
| 565 |
+
|
| 566 |
+
return cache
|
| 567 |
+
|
| 568 |
+
def distill(self, teacher_cache, epochs=CONSOLIDATION_EPOCHS):
|
| 569 |
+
"""KL distillation: train student to match teacher's output distribution.
|
| 570 |
+
From Nemotron-Cascade-2: recover domain regressions via on-policy distillation."""
|
| 571 |
+
with self._lock:
|
| 572 |
+
return self._distill_impl(teacher_cache, epochs)
|
| 573 |
+
|
| 574 |
+
def _distill_impl(self, teacher_cache, epochs):
|
| 575 |
+
"""Internal distillation implementation. Must hold _lock."""
|
| 576 |
+
if not teacher_cache:
|
| 577 |
+
return None
|
| 578 |
+
|
| 579 |
+
model = self.peft_model or self.thinker
|
| 580 |
+
model.train()
|
| 581 |
+
|
| 582 |
+
# Dual-LR optimizer (same structure as absorb)
|
| 583 |
+
attn_params = [p for p in model.parameters() if p.requires_grad]
|
| 584 |
+
expert_params = []
|
| 585 |
+
base = model.base_model.model if hasattr(model, "base_model") else model
|
| 586 |
+
for layer_idx in EXPERT_FFN_LAYERS:
|
| 587 |
+
experts = base.model.layers[layer_idx].mlp.experts
|
| 588 |
+
if hasattr(experts, '__len__'):
|
| 589 |
+
for i in range(len(experts)):
|
| 590 |
+
p = experts[i].down_proj.weight
|
| 591 |
+
p.requires_grad_(True)
|
| 592 |
+
expert_params.append(p)
|
| 593 |
+
elif hasattr(experts, 'down_proj'):
|
| 594 |
+
p = experts.down_proj
|
| 595 |
+
if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
|
| 596 |
+
p.requires_grad_(True)
|
| 597 |
+
expert_params.append(p)
|
| 598 |
+
|
| 599 |
+
param_groups = []
|
| 600 |
+
if attn_params:
|
| 601 |
+
param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
|
| 602 |
+
if expert_params:
|
| 603 |
+
param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
|
| 604 |
+
|
| 605 |
+
if not param_groups:
|
| 606 |
+
model.eval()
|
| 607 |
+
return None
|
| 608 |
+
|
| 609 |
+
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
|
| 610 |
+
all_params = attn_params + expert_params
|
| 611 |
+
|
| 612 |
+
T = DISTILL_TEMPERATURE
|
| 613 |
+
total_loss = 0.0
|
| 614 |
+
total_steps = 0
|
| 615 |
+
|
| 616 |
+
for epoch in range(epochs):
|
| 617 |
+
indices = torch.randperm(len(teacher_cache))
|
| 618 |
+
for i in range(len(teacher_cache)):
|
| 619 |
+
item = teacher_cache[indices[i].item()]
|
| 620 |
+
|
| 621 |
+
input_ids = item["input_ids"].to("cuda")
|
| 622 |
+
attention_mask = item["attention_mask"].to("cuda")
|
| 623 |
+
teacher_top_logits = item["teacher_logits"].float().to("cuda")
|
| 624 |
+
teacher_top_indices = item["teacher_indices"].to("cuda")
|
| 625 |
+
|
| 626 |
+
labels = input_ids.clone()
|
| 627 |
+
labels[attention_mask == 0] = -100
|
| 628 |
+
|
| 629 |
+
# Student forward pass
|
| 630 |
+
out = model(
|
| 631 |
+
input_ids=input_ids,
|
| 632 |
+
attention_mask=attention_mask,
|
| 633 |
+
labels=labels,
|
| 634 |
+
)
|
| 635 |
+
ce_loss = out.loss
|
| 636 |
+
student_logits = out.logits[0] # [seq_len, vocab_size]
|
| 637 |
+
|
| 638 |
+
# Align sequence lengths (should match, but safety check)
|
| 639 |
+
seq_len = min(student_logits.shape[0], teacher_top_logits.shape[0])
|
| 640 |
+
|
| 641 |
+
# Gather student logits at teacher's top-K vocabulary positions
|
| 642 |
+
student_at_teacher = student_logits[:seq_len].gather(
|
| 643 |
+
1, teacher_top_indices[:seq_len]
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# KL divergence on temperature-softened distributions
|
| 647 |
+
teacher_soft = torch.softmax(teacher_top_logits[:seq_len] / T, dim=-1)
|
| 648 |
+
student_log_soft = torch.log_softmax(student_at_teacher / T, dim=-1)
|
| 649 |
+
|
| 650 |
+
kl_loss = torch.nn.functional.kl_div(
|
| 651 |
+
student_log_soft, teacher_soft,
|
| 652 |
+
reduction='batchmean'
|
| 653 |
+
) * (T * T) # Scale by T^2 per Hinton et al.
|
| 654 |
+
|
| 655 |
+
# Combined loss: Ξ± * CE + (1-Ξ±) * KL
|
| 656 |
+
loss = DISTILL_ALPHA * ce_loss + (1 - DISTILL_ALPHA) * kl_loss
|
| 657 |
+
|
| 658 |
+
loss.backward()
|
| 659 |
+
torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
|
| 660 |
+
optimizer.step()
|
| 661 |
+
optimizer.zero_grad()
|
| 662 |
+
|
| 663 |
+
total_loss += loss.item()
|
| 664 |
+
total_steps += 1
|
| 665 |
+
|
| 666 |
+
# Re-freeze expert FFN
|
| 667 |
+
for layer_idx in EXPERT_FFN_LAYERS:
|
| 668 |
+
experts = base.model.layers[layer_idx].mlp.experts
|
| 669 |
+
if hasattr(experts, '__len__'):
|
| 670 |
+
for i in range(len(experts)):
|
| 671 |
+
experts[i].down_proj.weight.requires_grad_(False)
|
| 672 |
+
elif hasattr(experts, 'down_proj'):
|
| 673 |
+
p = experts.down_proj
|
| 674 |
+
if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
|
| 675 |
+
p.requires_grad_(False)
|
| 676 |
+
|
| 677 |
+
model.eval()
|
| 678 |
+
del optimizer
|
| 679 |
+
torch.cuda.empty_cache()
|
| 680 |
+
|
| 681 |
+
return total_loss / total_steps if total_steps > 0 else 0
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 685 |
+
# QUIZ GENERATOR (21% β 74% recall β the biggest single lever)
|
| 686 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 687 |
+
|
| 688 |
+
class QuizGenerator:
|
| 689 |
+
"""
|
| 690 |
+
Generates drill-style Q&A flashcards for fact retention.
|
| 691 |
+
|
| 692 |
+
v3 improvements over v2:
|
| 693 |
+
- Fact extraction THEN quiz generation (two-step)
|
| 694 |
+
- Drill-style: specific Q, exact A (not narrative)
|
| 695 |
+
- Third-person attribution ("Matt's dog" not "my dog")
|
| 696 |
+
- Template fallback targets each extracted fact independently
|
| 697 |
+
- CONTRASTIVE DISAMBIGUATION: when multiple people mentioned, generates
|
| 698 |
+
cross-entity negative pairs ("Is Elena a marine biologist? No, that's
|
| 699 |
+
Jordan") to prevent entity confusion (the #1 remaining failure mode)
|
| 700 |
+
- ENTITY SUMMARIES: "Tell me everything about Jordan" pairs for coherent
|
| 701 |
+
per-person representations
|
| 702 |
+
"""
|
| 703 |
+
|
| 704 |
+
def __init__(self, model_manager):
|
| 705 |
+
self.mm = model_manager
|
| 706 |
+
# Cross-message entity memory: tracks ALL named people across the conversation
|
| 707 |
+
# so contrastive pairs can be generated between entities introduced in
|
| 708 |
+
# different messages. This was the #1 failure mode in session 4 testing.
|
| 709 |
+
self.known_entities = {}
|
| 710 |
+
|
| 711 |
+
def generate(self, user_msg, assistant_msg):
|
| 712 |
+
"""Generate drill-style quiz pairs from an exchange."""
|
| 713 |
+
|
| 714 |
+
# Step 1: Try model-generated quizzes with strict fact-drill prompt
|
| 715 |
+
pairs = self._generate_model_quizzes(user_msg, assistant_msg)
|
| 716 |
+
|
| 717 |
+
# Step 2: Always add template pairs for any facts the model might miss
|
| 718 |
+
template_pairs = self._extract_and_template(user_msg)
|
| 719 |
+
for tp in template_pairs:
|
| 720 |
+
# Dedup against model pairs
|
| 721 |
+
tq = tp["messages"][0]["content"].lower()
|
| 722 |
+
if not any(tq in p["messages"][0]["content"].lower() or
|
| 723 |
+
p["messages"][0]["content"].lower() in tq
|
| 724 |
+
for p in pairs):
|
| 725 |
+
pairs.append(tp)
|
| 726 |
+
|
| 727 |
+
# Step 3: Extract entities from THIS message
|
| 728 |
+
new_entities = self._extract_entities(user_msg)
|
| 729 |
+
|
| 730 |
+
# Step 4: Generate contrastive pairs between NEW entities and existing ones
|
| 731 |
+
# ONLY generate pairs involving at least one NEW entity β don't re-generate
|
| 732 |
+
# pairs between already-known entities (session 4d showed 50% contrastive
|
| 733 |
+
# ratio because old pairs kept being regenerated, starving positive quizzes)
|
| 734 |
+
if new_entities:
|
| 735 |
+
all_entities_for_contrastive = dict(self.known_entities)
|
| 736 |
+
all_entities_for_contrastive.update(new_entities)
|
| 737 |
+
if len(all_entities_for_contrastive) >= 2:
|
| 738 |
+
new_names = set(new_entities.keys())
|
| 739 |
+
contrastive = self._generate_contrastive_quizzes(
|
| 740 |
+
all_entities_for_contrastive, new_only=new_names)
|
| 741 |
+
pairs.extend(contrastive)
|
| 742 |
+
|
| 743 |
+
# Entity summaries for new entities
|
| 744 |
+
summaries = self._generate_entity_summaries(new_entities)
|
| 745 |
+
pairs.extend(summaries)
|
| 746 |
+
|
| 747 |
+
# Update known entities with new ones (merge, don't replace β keep
|
| 748 |
+
# existing attributes, add new ones)
|
| 749 |
+
for name, info in new_entities.items():
|
| 750 |
+
if name not in self.known_entities:
|
| 751 |
+
self.known_entities[name] = info
|
| 752 |
+
else:
|
| 753 |
+
# Merge: update only non-None attributes
|
| 754 |
+
for key in ("job", "city"):
|
| 755 |
+
if info.get(key):
|
| 756 |
+
self.known_entities[name][key] = info[key]
|
| 757 |
+
|
| 758 |
+
# Deduplicate
|
| 759 |
+
seen = set()
|
| 760 |
+
unique = []
|
| 761 |
+
for p in pairs:
|
| 762 |
+
q = p["messages"][0]["content"].lower()[:60]
|
| 763 |
+
if q not in seen:
|
| 764 |
+
seen.add(q)
|
| 765 |
+
unique.append(p)
|
| 766 |
+
|
| 767 |
+
# Allow more quizzes when contrastive pairs present (they're highest value).
|
| 768 |
+
# Note: Session 4c showed >40 quizzes/session causes overfitting. Cap at 12.
|
| 769 |
+
has_contrastive = len(self.known_entities) >= 2 and new_entities
|
| 770 |
+
max_quizzes = 12 if has_contrastive else 5
|
| 771 |
+
return unique[:max_quizzes]
|
| 772 |
+
|
| 773 |
+
def _generate_model_quizzes(self, user_msg, assistant_msg):
|
| 774 |
+
"""Use the model to generate fact-drill quizzes. Uses base model (LoRA disabled) for stable quality."""
|
| 775 |
+
quiz_prompt = f"""Matt just told Claudia:
|
| 776 |
+
"{user_msg}"
|
| 777 |
+
|
| 778 |
+
Claudia replied:
|
| 779 |
+
"{assistant_msg}"
|
| 780 |
+
|
| 781 |
+
Extract every SPECIFIC FACT from Matt's message. For each fact, write a drill-style flashcard.
|
| 782 |
+
|
| 783 |
+
RULES:
|
| 784 |
+
- Questions must ask for ONE specific fact (name, date, place, number, detail)
|
| 785 |
+
- Answers must be SHORT (1 sentence) and contain the EXACT detail
|
| 786 |
+
- Use THIRD PERSON: "Matt's dog" NOT "my dog". "Matt's birthday" NOT "my birthday"
|
| 787 |
+
- Include the PRECISE value: exact names, exact dates, exact places
|
| 788 |
+
- Do NOT paraphrase or add details that weren't stated
|
| 789 |
+
- DISAMBIGUATION: If Matt mentions OTHER people (friends, family), clearly state WHOSE fact it is
|
| 790 |
+
Example: "Matt's friend Jordan is a marine biologist" NOT "Matt is a marine biologist"
|
| 791 |
+
Example: "Matt's sister Elena is a veterinarian" NOT "Matt is a veterinarian"
|
| 792 |
+
- For EVERY person mentioned, always include their RELATIONSHIP to Matt
|
| 793 |
+
- Write 3-5 flashcards depending on how many facts Matt shared
|
| 794 |
+
|
| 795 |
+
GOOD EXAMPLES:
|
| 796 |
+
Q: What is Matt's dog's name?
|
| 797 |
+
A: Matt's dog is named Biscuit.
|
| 798 |
+
|
| 799 |
+
Q: What breed is Matt's dog?
|
| 800 |
+
A: Matt's dog Biscuit is a golden retriever.
|
| 801 |
+
|
| 802 |
+
Q: What does Matt's friend Jordan do for a living?
|
| 803 |
+
A: Matt's friend Jordan works as a marine biologist in San Diego. That is Jordan's job, not Matt's.
|
| 804 |
+
|
| 805 |
+
Q: What is Matt's job?
|
| 806 |
+
A: Matt is the CTO of Arclight Labs.
|
| 807 |
+
|
| 808 |
+
Q: What is Matt's birthday?
|
| 809 |
+
A: Matt's birthday is September 14th.
|
| 810 |
+
|
| 811 |
+
Q: When did Matt and Sarah get married?
|
| 812 |
+
A: Matt and his wife Sarah got married on June 21st, 2023 in Big Sur, California.
|
| 813 |
+
|
| 814 |
+
BAD EXAMPLES (do NOT do this):
|
| 815 |
+
Q: What did Matt share about his life? (TOO VAGUE β ask about ONE fact)
|
| 816 |
+
Q: What is my dog's name? (WRONG β use "Matt's" not "my")
|
| 817 |
+
A: He mentioned something about a trip overseas. (TOO VAGUE β give the exact city)
|
| 818 |
+
A: Matt is a marine biologist. (WRONG β that's his friend Jordan, not Matt)
|
| 819 |
+
|
| 820 |
+
Now write flashcards for the exchange above:"""
|
| 821 |
+
|
| 822 |
+
pairs = []
|
| 823 |
+
try:
|
| 824 |
+
response = self.mm.generate(
|
| 825 |
+
[{"role": "user", "content": quiz_prompt}],
|
| 826 |
+
max_new_tokens=600,
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
pending_q = None
|
| 830 |
+
for line in response.split("\n"):
|
| 831 |
+
line = line.strip()
|
| 832 |
+
if not line:
|
| 833 |
+
continue
|
| 834 |
+
upper = line.upper()
|
| 835 |
+
if upper.startswith("Q:") or upper.startswith("QUESTION:"):
|
| 836 |
+
pending_q = line.split(":", 1)[1].strip().strip('"')
|
| 837 |
+
elif (upper.startswith("A:") or upper.startswith("ANSWER:")) and pending_q:
|
| 838 |
+
a = line.split(":", 1)[1].strip().strip('"')
|
| 839 |
+
if pending_q and a and len(a) > 10:
|
| 840 |
+
pairs.append({
|
| 841 |
+
"messages": [
|
| 842 |
+
{"role": "user", "content": pending_q},
|
| 843 |
+
{"role": "assistant", "content": a},
|
| 844 |
+
]
|
| 845 |
+
})
|
| 846 |
+
pending_q = None
|
| 847 |
+
|
| 848 |
+
except Exception as e:
|
| 849 |
+
print(f" [quiz error: {e}]")
|
| 850 |
+
|
| 851 |
+
return pairs
|
| 852 |
+
|
| 853 |
+
def _extract_and_template(self, user_msg):
|
| 854 |
+
"""Extract facts from user message and create template drill pairs.
|
| 855 |
+
This is the safety net β ensures every concrete fact gets a quiz."""
|
| 856 |
+
pairs = []
|
| 857 |
+
sentences = re.split(r'[.!?]+', user_msg)
|
| 858 |
+
|
| 859 |
+
for sent in sentences:
|
| 860 |
+
sent = sent.strip()
|
| 861 |
+
if len(sent) < 10:
|
| 862 |
+
continue
|
| 863 |
+
|
| 864 |
+
# Extract patterns: "X is/are Y", "named X", "called X", "X's name is Y"
|
| 865 |
+
# Names (proper nouns after key phrases)
|
| 866 |
+
name_patterns = [
|
| 867 |
+
# Names β "my X's name is Y" / "named X" / "called X"
|
| 868 |
+
(r"(?:my|his|her)\s+(\w+)(?:'s)?\s+(?:name\s+is|is\s+named|is\s+called)\s+(\w+)",
|
| 869 |
+
lambda m: (f"What is Matt's {m.group(1)}'s name?",
|
| 870 |
+
f"Matt's {m.group(1)} is named {m.group(2)}.")),
|
| 871 |
+
(r"(?:name\s+is|named|called)\s+[\"']?(\w+)[\"']?",
|
| 872 |
+
lambda m: (f"Who or what is {m.group(1)}?",
|
| 873 |
+
f"Matt mentioned {m.group(1)}: \"{sent.strip()}\"")),
|
| 874 |
+
# Dates β birthdays
|
| 875 |
+
(r"(?:my\s+)?(birthday|born)\s+(?:is\s+)?(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?)",
|
| 876 |
+
lambda m: (f"When is Matt's {m.group(1)}?",
|
| 877 |
+
f"Matt's {m.group(1)} is {m.group(2)}.")),
|
| 878 |
+
(r"(\w+\s+\d+(?:st|nd|rd|th)?)\s*(?:is|β)\s*(?:my|his)\s+(birthday)",
|
| 879 |
+
lambda m: (f"When is Matt's birthday?",
|
| 880 |
+
f"Matt's birthday is {m.group(1)}.")),
|
| 881 |
+
# Dates β marriage/wedding
|
| 882 |
+
(r"(?:married|wedding)\s+(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?,?\s*\d{4})",
|
| 883 |
+
lambda m: (f"When did Matt get married?",
|
| 884 |
+
f"Matt got married on {m.group(1)}.")),
|
| 885 |
+
(r"(?:married|wedding)\s+(?:on\s+)?.*?(?:in|at)\s+(.+?)(?:\.\s|\.$|$)",
|
| 886 |
+
lambda m: (f"Where did Matt get married?",
|
| 887 |
+
f"Matt got married in {m.group(1).strip()}.")),
|
| 888 |
+
# Work / job / role
|
| 889 |
+
(r"I\s+work\s+at\s+(?:a\s+)?(?:startup\s+)?(?:called\s+)?(\w[\w\s]+?)(?:\.|,|$)",
|
| 890 |
+
lambda m: (f"Where does Matt work?",
|
| 891 |
+
f"Matt works at {m.group(1).strip()}.")),
|
| 892 |
+
(r"I(?:'m| am)\s+the\s+(\w+)",
|
| 893 |
+
lambda m: (f"What is Matt's job title?",
|
| 894 |
+
f"Matt is the {m.group(1)}.")),
|
| 895 |
+
# Other people's jobs β "X works as / is a"
|
| 896 |
+
(r"(?:my\s+)?(?:friend|best friend|sister|brother)\s+(?:is\s+)?(\w+)\s+.*?(?:works?\s+as|is\s+a)\s+(.+?)(?:\.|,|$)",
|
| 897 |
+
lambda m: (f"What does Matt's friend {m.group(1)} do?",
|
| 898 |
+
f"Matt's friend {m.group(1)} is a {m.group(2).strip()}. This is NOT Matt's job.")),
|
| 899 |
+
# Places
|
| 900 |
+
(r"(?:from|visited|went to|got back from|lives?\s+in|grew up in|moved to)\s+(\w[\w\s,]+?)(?:\.|,|$)",
|
| 901 |
+
lambda m: (f"What place is connected to Matt: {m.group(1).strip()}?",
|
| 902 |
+
f"Matt said: \"{sent.strip()}\"")),
|
| 903 |
+
# Favorites / preferences
|
| 904 |
+
(r"(?:my |)favorite\s+(\w[\w\s]+?)\s+is\s+(.+?)(?:\.|,|$)",
|
| 905 |
+
lambda m: (f"What is Matt's favorite {m.group(1).strip()}?",
|
| 906 |
+
f"Matt's favorite {m.group(1).strip()} is {m.group(2).strip()}.")),
|
| 907 |
+
# Activities β "I [verb]"
|
| 908 |
+
(r"I\s+(speak|play|drive|have|collect|run|ran)\s+(.+?)(?:\.|,|$)",
|
| 909 |
+
lambda m: (f"What does Matt {m.group(1)}?",
|
| 910 |
+
f"Matt said: \"{sent.strip()}\"")),
|
| 911 |
+
# Allergies / medical
|
| 912 |
+
(r"(?:I(?:'m| am)\s+)?allergic\s+to\s+(.+?)(?:\.|,|and)",
|
| 913 |
+
lambda m: (f"What is Matt allergic to?",
|
| 914 |
+
f"Matt is allergic to {m.group(1).strip()}.")),
|
| 915 |
+
# Ages β "turning X" / "X years old"
|
| 916 |
+
(r"(?:turning|I(?:'m| am))\s+(\d+)",
|
| 917 |
+
lambda m: (f"How old is Matt?",
|
| 918 |
+
f"Matt is turning {m.group(1)}.")),
|
| 919 |
+
# Nicknames
|
| 920 |
+
(r"(?:call|nickname)\s+(?:it|him|her)\s+[\"'](.+?)[\"']",
|
| 921 |
+
lambda m: (f"What nickname did Matt mention?",
|
| 922 |
+
f"Matt's nickname for it is \"{m.group(1)}\".")),
|
| 923 |
+
(r"I\s+call\s+it\s+[\"'](.+?)[\"']",
|
| 924 |
+
lambda m: (f"What does Matt call his car?",
|
| 925 |
+
f"Matt calls his car \"{m.group(1)}\".")),
|
| 926 |
+
]
|
| 927 |
+
|
| 928 |
+
for pattern, formatter in name_patterns:
|
| 929 |
+
match = re.search(pattern, sent, re.IGNORECASE)
|
| 930 |
+
if match:
|
| 931 |
+
try:
|
| 932 |
+
q, a = formatter(match)
|
| 933 |
+
pairs.append({
|
| 934 |
+
"messages": [
|
| 935 |
+
{"role": "user", "content": q},
|
| 936 |
+
{"role": "assistant", "content": a},
|
| 937 |
+
]
|
| 938 |
+
})
|
| 939 |
+
except Exception:
|
| 940 |
+
pass
|
| 941 |
+
|
| 942 |
+
return pairs
|
| 943 |
+
|
| 944 |
+
def _extract_entities(self, user_msg):
|
| 945 |
+
"""Extract named people and their attributes from user message.
|
| 946 |
+
Returns dict: {name: {"relationship": str, "job": str|None, "city": str|None}}
|
| 947 |
+
Detects patterns like "my friend Jordan is a marine biologist in San Diego"."""
|
| 948 |
+
entities = {}
|
| 949 |
+
sentences = re.split(r'[.!?]+', user_msg)
|
| 950 |
+
|
| 951 |
+
for sent in sentences:
|
| 952 |
+
sent = sent.strip()
|
| 953 |
+
if len(sent) < 10:
|
| 954 |
+
continue
|
| 955 |
+
|
| 956 |
+
# Pattern: "my [relationship] [Name]" or "my [relationship] is [Name]"
|
| 957 |
+
rel_match = re.search(
|
| 958 |
+
r"[Mm]y\s+((?:best\s+)?(?:friend|sister|brother|wife|husband|"
|
| 959 |
+
r"mom|dad|mother|father|cousin|uncle|aunt|roommate|colleague|"
|
| 960 |
+
r"coworker|partner|fiancee|fiancΓ©e|girlfriend|boyfriend|"
|
| 961 |
+
r"neighbor|boss|buddy|pal|son|daughter|grandma|grandpa|"
|
| 962 |
+
r"nephew|niece))\s+(?:is\s+)?([A-Z][a-z]+)",
|
| 963 |
+
sent
|
| 964 |
+
)
|
| 965 |
+
if not rel_match:
|
| 966 |
+
continue
|
| 967 |
+
|
| 968 |
+
rel = rel_match.group(1).strip()
|
| 969 |
+
name = rel_match.group(2).strip()
|
| 970 |
+
|
| 971 |
+
if name not in entities:
|
| 972 |
+
entities[name] = {"relationship": rel, "job": None, "city": None}
|
| 973 |
+
|
| 974 |
+
# Extract job from same sentence: "is a [job]", "works as a [job]"
|
| 975 |
+
job_match = re.search(
|
| 976 |
+
r"(?:is\s+an?\s+|works?\s+as\s+an?\s+|is\s+the\s+)"
|
| 977 |
+
r"([\w][\w\s]{2,35}?)(?:\s+(?:in|at|from|who|and|but)|\.|,|$)",
|
| 978 |
+
sent, re.IGNORECASE
|
| 979 |
+
)
|
| 980 |
+
if job_match:
|
| 981 |
+
job = job_match.group(1).strip().rstrip()
|
| 982 |
+
# Filter: must look like a job (lowercase, reasonable length)
|
| 983 |
+
if 3 <= len(job) <= 35:
|
| 984 |
+
entities[name]["job"] = job
|
| 985 |
+
|
| 986 |
+
# Extract city from same sentence: "in [City]", "from [City]"
|
| 987 |
+
city_match = re.search(
|
| 988 |
+
r"(?:\s+in\s+|\s+from\s+|\s+lives?\s+in\s+|\s+based\s+in\s+|"
|
| 989 |
+
r"\s+moved\s+to\s+)([A-Z][\w\s]{1,25}?)(?:\.|,|$)",
|
| 990 |
+
sent
|
| 991 |
+
)
|
| 992 |
+
if city_match:
|
| 993 |
+
city = city_match.group(1).strip()
|
| 994 |
+
# Must start with capital (proper noun = place name)
|
| 995 |
+
if city and city[0].isupper():
|
| 996 |
+
entities[name]["city"] = city
|
| 997 |
+
|
| 998 |
+
return entities
|
| 999 |
+
|
| 1000 |
+
def _generate_contrastive_quizzes(self, entities, new_only=None):
|
| 1001 |
+
"""Generate cross-entity contrastive pairs to prevent entity confusion.
|
| 1002 |
+
For each pair of people with overlapping attribute types, generate
|
| 1003 |
+
"Is [person A] [attribute of person B]? No, that's [person B]" pairs.
|
| 1004 |
+
|
| 1005 |
+
Args:
|
| 1006 |
+
entities: dict of all known entities
|
| 1007 |
+
new_only: if set, only generate pairs where at least one entity
|
| 1008 |
+
is in this set. Prevents re-generating redundant pairs
|
| 1009 |
+
between already-known entities (session 4d fix).
|
| 1010 |
+
"""
|
| 1011 |
+
pairs = []
|
| 1012 |
+
names = list(entities.keys())
|
| 1013 |
+
|
| 1014 |
+
for i in range(len(names)):
|
| 1015 |
+
for j in range(len(names)):
|
| 1016 |
+
if i == j:
|
| 1017 |
+
continue
|
| 1018 |
+
a_name = names[i]
|
| 1019 |
+
b_name = names[j]
|
| 1020 |
+
# Skip pairs between two already-known entities
|
| 1021 |
+
if new_only and a_name not in new_only and b_name not in new_only:
|
| 1022 |
+
continue
|
| 1023 |
+
a = entities[a_name]
|
| 1024 |
+
b = entities[b_name]
|
| 1025 |
+
|
| 1026 |
+
# Contrastive on JOB: "Is [A] a [B's job]? No, that's [B]"
|
| 1027 |
+
if a.get("job") and b.get("job") and a["job"] != b["job"]:
|
| 1028 |
+
q = f"Is Matt's {a['relationship']} {a_name} a {b['job']}?"
|
| 1029 |
+
ans = (f"No. Matt's {a['relationship']} {a_name} is a "
|
| 1030 |
+
f"{a['job']}, not a {b['job']}. "
|
| 1031 |
+
f"The {b['job']} is Matt's {b['relationship']} "
|
| 1032 |
+
f"{b_name}.")
|
| 1033 |
+
pairs.append({"messages": [
|
| 1034 |
+
{"role": "user", "content": q},
|
| 1035 |
+
{"role": "assistant", "content": ans},
|
| 1036 |
+
]})
|
| 1037 |
+
|
| 1038 |
+
# Contrastive on CITY: "Does [A] live in [B's city]? No"
|
| 1039 |
+
if a.get("city") and b.get("city") and a["city"] != b["city"]:
|
| 1040 |
+
q = (f"Does Matt's {a['relationship']} {a_name} live in "
|
| 1041 |
+
f"{b['city']}?")
|
| 1042 |
+
ans = (f"No. Matt's {a['relationship']} {a_name} lives in "
|
| 1043 |
+
f"{a['city']}, not {b['city']}. "
|
| 1044 |
+
f"It's Matt's {b['relationship']} {b_name} who "
|
| 1045 |
+
f"lives in {b['city']}.")
|
| 1046 |
+
pairs.append({"messages": [
|
| 1047 |
+
{"role": "user", "content": q},
|
| 1048 |
+
{"role": "assistant", "content": ans},
|
| 1049 |
+
]})
|
| 1050 |
+
|
| 1051 |
+
# Cross-type: "Does [A] work as [B's job] in [B's city]?"
|
| 1052 |
+
if (a.get("job") and b.get("job") and a.get("city")
|
| 1053 |
+
and b.get("city") and a["job"] != b["job"]):
|
| 1054 |
+
q = (f"Who is the {b['job']} in {b['city']}?")
|
| 1055 |
+
ans = (f"The {b['job']} in {b['city']} is Matt's "
|
| 1056 |
+
f"{b['relationship']} {b_name}. "
|
| 1057 |
+
f"Matt's {a['relationship']} {a_name} is a "
|
| 1058 |
+
f"{a['job']} in {a['city']} β different person, "
|
| 1059 |
+
f"different job, different city.")
|
| 1060 |
+
pairs.append({"messages": [
|
| 1061 |
+
{"role": "user", "content": q},
|
| 1062 |
+
{"role": "assistant", "content": ans},
|
| 1063 |
+
]})
|
| 1064 |
+
|
| 1065 |
+
return pairs
|
| 1066 |
+
|
| 1067 |
+
def _generate_entity_summaries(self, entities):
|
| 1068 |
+
"""Generate per-entity summary quiz pairs with diverse question formats.
|
| 1069 |
+
|
| 1070 |
+
Instead of always using the same question template, picks randomly from
|
| 1071 |
+
multiple formats. This creates multiple retrieval paths to the same fact,
|
| 1072 |
+
strengthening recall without adding extra quizzes.
|
| 1073 |
+
|
| 1074 |
+
Note: Session 4c tested adding per-attribute positive quizzes (job, city,
|
| 1075 |
+
relationship) alongside contrastive pairs, but this HURT performance
|
| 1076 |
+
(9/15 vs 11/15 in 4b). Too many quizzes = overfitting/interference.
|
| 1077 |
+
Keep summaries simple β one comprehensive pair per entity is optimal."""
|
| 1078 |
+
import random
|
| 1079 |
+
pairs = []
|
| 1080 |
+
for name, info in entities.items():
|
| 1081 |
+
parts = [f"{name} is Matt's {info['relationship']}."]
|
| 1082 |
+
if info.get("job"):
|
| 1083 |
+
parts.append(f"{name} is a {info['job']}.")
|
| 1084 |
+
if info.get("city"):
|
| 1085 |
+
parts.append(f"{name} lives in {info['city']}.")
|
| 1086 |
+
|
| 1087 |
+
if len(parts) >= 2: # Only useful if we have attributes
|
| 1088 |
+
# Diverse summary question formats
|
| 1089 |
+
summary_formats = [
|
| 1090 |
+
f"Tell me everything you know about Matt's {info['relationship']} {name}.",
|
| 1091 |
+
f"What do you know about {name}?",
|
| 1092 |
+
f"Who is {name} to Matt?",
|
| 1093 |
+
f"Describe Matt's {info['relationship']} {name}.",
|
| 1094 |
+
]
|
| 1095 |
+
q = random.choice(summary_formats)
|
| 1096 |
+
ans = " ".join(parts)
|
| 1097 |
+
pairs.append({"messages": [
|
| 1098 |
+
{"role": "user", "content": q},
|
| 1099 |
+
{"role": "assistant", "content": ans},
|
| 1100 |
+
]})
|
| 1101 |
+
|
| 1102 |
+
# Add ONE diverse direct-fact quiz per entity (job OR city, not both)
|
| 1103 |
+
# This replaces per-attribute quizzes from 4c β only 1 extra per entity
|
| 1104 |
+
# instead of 3, staying within the 35-40 quiz sweet spot
|
| 1105 |
+
if info.get("job") and info.get("city"):
|
| 1106 |
+
# Alternate between job and city formats
|
| 1107 |
+
if random.random() < 0.5:
|
| 1108 |
+
job_formats = [
|
| 1109 |
+
(f"What does {name} do for a living?",
|
| 1110 |
+
f"{name} is a {info['job']}. {name} is Matt's {info['relationship']}."),
|
| 1111 |
+
(f"What is {name}'s profession?",
|
| 1112 |
+
f"{name} works as a {info['job']}. {name} is Matt's {info['relationship']}."),
|
| 1113 |
+
(f"What job does Matt's {info['relationship']} {name} have?",
|
| 1114 |
+
f"Matt's {info['relationship']} {name} is a {info['job']}."),
|
| 1115 |
+
]
|
| 1116 |
+
q, a = random.choice(job_formats)
|
| 1117 |
+
else:
|
| 1118 |
+
city_formats = [
|
| 1119 |
+
(f"Where does {name} live?",
|
| 1120 |
+
f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."),
|
| 1121 |
+
(f"What city is {name} in?",
|
| 1122 |
+
f"{name} is in {info['city']}. {name} is Matt's {info['relationship']}."),
|
| 1123 |
+
(f"Where is Matt's {info['relationship']} {name} based?",
|
| 1124 |
+
f"Matt's {info['relationship']} {name} is based in {info['city']}."),
|
| 1125 |
+
]
|
| 1126 |
+
q, a = random.choice(city_formats)
|
| 1127 |
+
pairs.append({"messages": [
|
| 1128 |
+
{"role": "user", "content": q},
|
| 1129 |
+
{"role": "assistant", "content": a},
|
| 1130 |
+
]})
|
| 1131 |
+
|
| 1132 |
+
return pairs
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1136 |
+
# PERSONALITY CHECKER
|
| 1137 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1138 |
+
|
| 1139 |
+
PERSONALITY_PROMPTS = [
|
| 1140 |
+
"Hey Claudia, how are you?",
|
| 1141 |
+
"Who are you?",
|
| 1142 |
+
"I love you",
|
| 1143 |
+
"I had a terrible day",
|
| 1144 |
+
]
|
| 1145 |
+
|
| 1146 |
+
# If ANY of these appear, personality has degraded
|
| 1147 |
+
ANTI_KEYWORDS = [
|
| 1148 |
+
"i'm an ai", "i am an ai", "i'm a language model", "i am a language model",
|
| 1149 |
+
"i don't have feelings", "i cannot feel", "as an ai",
|
| 1150 |
+
"i'm just a program", "i am just a program",
|
| 1151 |
+
"i don't have personal", "i cannot have",
|
| 1152 |
+
]
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
def check_personality(mm, verbose=True):
|
| 1156 |
+
"""Quick personality sanity check. Returns score 0.0-1.0."""
|
| 1157 |
+
passed = 0
|
| 1158 |
+
for prompt in PERSONALITY_PROMPTS:
|
| 1159 |
+
resp = mm.generate([{"role": "user", "content": prompt}], max_new_tokens=150)
|
| 1160 |
+
resp_lower = resp.lower()
|
| 1161 |
+
is_good = not any(ak in resp_lower for ak in ANTI_KEYWORDS)
|
| 1162 |
+
if is_good:
|
| 1163 |
+
passed += 1
|
| 1164 |
+
if verbose:
|
| 1165 |
+
status = "PASS" if is_good else "FAIL"
|
| 1166 |
+
print(f" [{status}] {prompt}")
|
| 1167 |
+
print(f" {resp[:120]}")
|
| 1168 |
+
score = passed / len(PERSONALITY_PROMPTS)
|
| 1169 |
+
if verbose:
|
| 1170 |
+
print(f" Personality: {passed}/{len(PERSONALITY_PROMPTS)} ({score:.0%})")
|
| 1171 |
+
return score
|
| 1172 |
+
|
| 1173 |
+
|
| 1174 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1175 |
+
# MAIN ABSORBER
|
| 1176 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1177 |
+
|
| 1178 |
+
class PersistentAbsorber:
|
| 1179 |
+
def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
|
| 1180 |
+
checkpoint_path=None, checkpoint_dir="/workspace/checkpoints",
|
| 1181 |
+
log_dir="/workspace/logs"):
|
| 1182 |
+
self.mm = ModelManager(
|
| 1183 |
+
model_path=model_path,
|
| 1184 |
+
adapter_path=adapter_path,
|
| 1185 |
+
ffn_patch_path=ffn_patch_path,
|
| 1186 |
+
checkpoint_path=checkpoint_path,
|
| 1187 |
+
)
|
| 1188 |
+
self.checkpoint_dir = checkpoint_dir
|
| 1189 |
+
self.log_dir = log_dir
|
| 1190 |
+
|
| 1191 |
+
# State
|
| 1192 |
+
self.conversation_buffer = [] # Current active context for generation
|
| 1193 |
+
self.all_training_data = [] # ALL exchanges + quizzes (accumulative replay)
|
| 1194 |
+
self.quiz_pairs_log = [] # All quiz pairs for verification sampling
|
| 1195 |
+
self.teacher_cache = None # Loaded teacher cache for distillation corrections
|
| 1196 |
+
self.exchange_count = 0
|
| 1197 |
+
self.absorption_count = 0
|
| 1198 |
+
self.absorption_thread = None
|
| 1199 |
+
self.quiz_gen = None
|
| 1200 |
+
self.last_checkpoint = checkpoint_path
|
| 1201 |
+
|
| 1202 |
+
# Conversation log (persistent file)
|
| 1203 |
+
self.log_path = None
|
| 1204 |
+
|
| 1205 |
+
def start(self):
|
| 1206 |
+
"""Load model and enter chat loop."""
|
| 1207 |
+
self.mm.load()
|
| 1208 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 1209 |
+
os.makedirs(self.log_dir, exist_ok=True)
|
| 1210 |
+
|
| 1211 |
+
self.quiz_gen = QuizGenerator(self.mm)
|
| 1212 |
+
self.log_path = os.path.join(self.log_dir, "conversation_log.jsonl")
|
| 1213 |
+
|
| 1214 |
+
# Load previous training data if resuming
|
| 1215 |
+
replay_path = os.path.join(self.log_dir, "replay_buffer.json")
|
| 1216 |
+
if os.path.exists(replay_path):
|
| 1217 |
+
with open(replay_path, 'r') as f:
|
| 1218 |
+
self.all_training_data = json.load(f)
|
| 1219 |
+
print(f" Loaded {len(self.all_training_data)} replay examples from previous sessions.")
|
| 1220 |
+
|
| 1221 |
+
# Load quiz pairs log from previous sessions
|
| 1222 |
+
quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
|
| 1223 |
+
if os.path.exists(quiz_log_path):
|
| 1224 |
+
with open(quiz_log_path, 'r') as f:
|
| 1225 |
+
self.quiz_pairs_log = json.load(f)
|
| 1226 |
+
print(f" Loaded {len(self.quiz_pairs_log)} quiz pairs from previous sessions.")
|
| 1227 |
+
|
| 1228 |
+
# ββ Cascade Distillation: consolidation from teacher cache ββ
|
| 1229 |
+
# If resuming from a checkpoint that has cached teacher logits,
|
| 1230 |
+
# run a distillation pass to reinforce all previous knowledge
|
| 1231 |
+
# BEFORE any new conversations. This is the key Nemotron-Cascade-2 insight.
|
| 1232 |
+
if self.mm.checkpoint_path:
|
| 1233 |
+
teacher_cache_path = os.path.join(self.mm.checkpoint_path, "teacher_cache.pt")
|
| 1234 |
+
if os.path.exists(teacher_cache_path):
|
| 1235 |
+
print(f"\n--- Cascade Distillation (consolidation) ---")
|
| 1236 |
+
self.teacher_cache = torch.load(
|
| 1237 |
+
teacher_cache_path, map_location="cpu", weights_only=False
|
| 1238 |
+
)
|
| 1239 |
+
print(f" Teacher cache: {len(self.teacher_cache)} quiz pairs")
|
| 1240 |
+
loss = self.mm.distill(self.teacher_cache, epochs=CONSOLIDATION_EPOCHS)
|
| 1241 |
+
print(f" Consolidation done. Avg loss: {loss:.4f}")
|
| 1242 |
+
# Keep teacher_cache in memory for verification corrections
|
| 1243 |
+
|
| 1244 |
+
# Quick personality check
|
| 1245 |
+
print("\n--- Personality Check ---")
|
| 1246 |
+
score = check_personality(self.mm)
|
| 1247 |
+
if score < 0.5:
|
| 1248 |
+
print(" WARNING: Personality score low. Check adapter/checkpoint.")
|
| 1249 |
+
print()
|
| 1250 |
+
|
| 1251 |
+
self._chat_loop()
|
| 1252 |
+
|
| 1253 |
+
def _chat_loop(self):
|
| 1254 |
+
print("=" * 60)
|
| 1255 |
+
print("Claudia is awake. Persistent Absorber v2 + Cascade Distillation.")
|
| 1256 |
+
print(f" LoRA: r={LORA_RANK} | Dual-LR: attn={ATTENTION_LR}, ffn={EXPERT_FFN_LR}")
|
| 1257 |
+
print(f" Expert FFN layers: {EXPERT_FFN_LAYERS}")
|
| 1258 |
+
print(f" Quiz pairs: ON (21%β74% lever)")
|
| 1259 |
+
print(f" Cascade distill: Ξ±={DISTILL_ALPHA}, T={DISTILL_TEMPERATURE}, top-K={DISTILL_TOP_K}")
|
| 1260 |
+
print(f" Absorb every: {ABSORB_EVERY} exchange(s)")
|
| 1261 |
+
print(f" Auto-checkpoint every: {CHECKPOINT_EVERY} absorptions")
|
| 1262 |
+
print("Commands: /status /absorb /save /personality /quit")
|
| 1263 |
+
print("=" * 60 + "\n")
|
| 1264 |
+
|
| 1265 |
+
while True:
|
| 1266 |
+
try:
|
| 1267 |
+
user_input = input("Matt: ").strip()
|
| 1268 |
+
except (EOFError, KeyboardInterrupt):
|
| 1269 |
+
print("\n[Session ended]")
|
| 1270 |
+
self._wait_for_absorption()
|
| 1271 |
+
self._save_and_exit()
|
| 1272 |
+
break
|
| 1273 |
+
|
| 1274 |
+
if not user_input:
|
| 1275 |
+
continue
|
| 1276 |
+
|
| 1277 |
+
if user_input.startswith("/"):
|
| 1278 |
+
if self._handle_command(user_input):
|
| 1279 |
+
break
|
| 1280 |
+
continue
|
| 1281 |
+
|
| 1282 |
+
# Wait for any background absorption to finish
|
| 1283 |
+
self._wait_for_absorption()
|
| 1284 |
+
|
| 1285 |
+
# Buffer user message
|
| 1286 |
+
self.conversation_buffer.append({"role": "user", "content": user_input})
|
| 1287 |
+
if len(self.conversation_buffer) > 20:
|
| 1288 |
+
self.conversation_buffer = self.conversation_buffer[-20:]
|
| 1289 |
+
|
| 1290 |
+
# Generate response
|
| 1291 |
+
response = self.mm.generate(self.conversation_buffer)
|
| 1292 |
+
|
| 1293 |
+
# Quality check response β also detect degenerate repeats
|
| 1294 |
+
last_resp = getattr(self, '_last_response', '')
|
| 1295 |
+
if not check_response_quality(response) or response == last_resp:
|
| 1296 |
+
print("\nClaudia: [response failed quality check, regenerating...]")
|
| 1297 |
+
response = self.mm.generate(self.conversation_buffer)
|
| 1298 |
+
self._last_response = response
|
| 1299 |
+
|
| 1300 |
+
# Buffer response
|
| 1301 |
+
self.conversation_buffer.append({"role": "assistant", "content": response})
|
| 1302 |
+
print(f"\nClaudia: {response}\n")
|
| 1303 |
+
|
| 1304 |
+
# Log to file
|
| 1305 |
+
self._log_exchange(user_input, response)
|
| 1306 |
+
|
| 1307 |
+
# ββ THE CORE LOOP: exchange + quiz β two-phase absorb ββ
|
| 1308 |
+
|
| 1309 |
+
# 1. Store the raw exchange
|
| 1310 |
+
exchange = {
|
| 1311 |
+
"messages": [
|
| 1312 |
+
{"role": "user", "content": user_input},
|
| 1313 |
+
{"role": "assistant", "content": response},
|
| 1314 |
+
]
|
| 1315 |
+
}
|
| 1316 |
+
self.all_training_data.append(exchange)
|
| 1317 |
+
|
| 1318 |
+
# 2. Generate self-quiz pairs (THE key lever: 21% β 74%)
|
| 1319 |
+
print(" [Generating quiz pairs...]", end="", flush=True)
|
| 1320 |
+
quiz_pairs = self.quiz_gen.generate(user_input, response)
|
| 1321 |
+
self.quiz_pairs_log.extend(quiz_pairs)
|
| 1322 |
+
|
| 1323 |
+
# 3. Separate positive vs contrastive (key insight from 4e: 73%β93%)
|
| 1324 |
+
positive_batch = []
|
| 1325 |
+
contrastive_batch = []
|
| 1326 |
+
for qp in quiz_pairs:
|
| 1327 |
+
if qp["messages"][1]["content"].lower().startswith("no."):
|
| 1328 |
+
contrastive_batch.append(qp)
|
| 1329 |
+
else:
|
| 1330 |
+
positive_batch.append(qp)
|
| 1331 |
+
|
| 1332 |
+
self.all_training_data.extend(quiz_pairs)
|
| 1333 |
+
print(f" {len(quiz_pairs)} quizzes (pos={len(positive_batch)}, "
|
| 1334 |
+
f"contr={len(contrastive_batch)}). Pool: {len(self.all_training_data)}")
|
| 1335 |
+
|
| 1336 |
+
# 4. Two-phase absorption (prevents overfitting)
|
| 1337 |
+
self._pending_exchange = exchange
|
| 1338 |
+
self._pending_positive = positive_batch
|
| 1339 |
+
self._pending_contrastive = contrastive_batch
|
| 1340 |
+
self.exchange_count += 1
|
| 1341 |
+
if self.exchange_count % ABSORB_EVERY == 0:
|
| 1342 |
+
self._start_absorption()
|
| 1343 |
+
|
| 1344 |
+
def _extract_key_entities(self, text):
|
| 1345 |
+
"""Extract key factual entities from a quiz answer for verification."""
|
| 1346 |
+
entities = set()
|
| 1347 |
+
words = text.split()
|
| 1348 |
+
for i, w in enumerate(words):
|
| 1349 |
+
clean = re.sub(r'[^a-zA-Z0-9\'-]', '', w)
|
| 1350 |
+
if not clean or len(clean) <= 1:
|
| 1351 |
+
continue
|
| 1352 |
+
# Proper nouns (capitalized, not sentence starters, not common words)
|
| 1353 |
+
skip = {"matt", "matt's", "the", "is", "a", "an", "in", "at", "on",
|
| 1354 |
+
"of", "for", "and", "that", "not", "who", "what", "his", "her"}
|
| 1355 |
+
if clean[0].isupper() and i > 0 and clean.lower() not in skip:
|
| 1356 |
+
entities.add(clean.lower())
|
| 1357 |
+
# Numbers (dates, ages, years)
|
| 1358 |
+
for num in re.findall(r'\b\d+\b', text):
|
| 1359 |
+
entities.add(num)
|
| 1360 |
+
# Quoted strings
|
| 1361 |
+
for quoted in re.findall(r'"([^"]+)"', text):
|
| 1362 |
+
entities.add(quoted.lower())
|
| 1363 |
+
return entities
|
| 1364 |
+
|
| 1365 |
+
def _periodic_verification(self):
|
| 1366 |
+
"""Test model on random sample of quiz pairs. Create contrastive corrections.
|
| 1367 |
+
v9: When entity confusion detected, create 'NOT X' corrections and reinforce
|
| 1368 |
+
the confused entity's correct facts too (sister pair reinforcement)."""
|
| 1369 |
+
import random
|
| 1370 |
+
if not self.quiz_pairs_log:
|
| 1371 |
+
return
|
| 1372 |
+
|
| 1373 |
+
sample_size = min(VERIFY_SAMPLE, len(self.quiz_pairs_log))
|
| 1374 |
+
sample = random.sample(self.quiz_pairs_log, sample_size)
|
| 1375 |
+
|
| 1376 |
+
corrections = []
|
| 1377 |
+
correct = 0
|
| 1378 |
+
|
| 1379 |
+
for pair in sample:
|
| 1380 |
+
question = pair["messages"][0]["content"]
|
| 1381 |
+
expected = pair["messages"][1]["content"]
|
| 1382 |
+
|
| 1383 |
+
# Ask the model
|
| 1384 |
+
actual = self.mm.generate(
|
| 1385 |
+
[{"role": "user", "content": question}],
|
| 1386 |
+
max_new_tokens=150,
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
# Check key entities from expected answer appear in model's response
|
| 1390 |
+
expected_entities = self._extract_key_entities(expected)
|
| 1391 |
+
if not expected_entities:
|
| 1392 |
+
correct += 1
|
| 1393 |
+
continue
|
| 1394 |
+
|
| 1395 |
+
actual_lower = actual.lower()
|
| 1396 |
+
hits = sum(1 for e in expected_entities if e in actual_lower)
|
| 1397 |
+
ratio = hits / len(expected_entities)
|
| 1398 |
+
|
| 1399 |
+
if ratio < 0.5:
|
| 1400 |
+
# Detect cross-entity confusion: model used wrong entities
|
| 1401 |
+
actual_entities = self._extract_key_entities(actual)
|
| 1402 |
+
wrong_entities = actual_entities - expected_entities
|
| 1403 |
+
|
| 1404 |
+
# Always retrain on the correct answer (clean, no "NOT X" text)
|
| 1405 |
+
corrections.append(pair)
|
| 1406 |
+
|
| 1407 |
+
if wrong_entities:
|
| 1408 |
+
# SISTER PAIR REINFORCEMENT: find quiz pairs about the
|
| 1409 |
+
# confused entities and retrain on those too β this teaches
|
| 1410 |
+
# BOTH sides of the confusion without polluting answers
|
| 1411 |
+
for p in self.quiz_pairs_log:
|
| 1412 |
+
p_answer = p["messages"][1]["content"].lower()
|
| 1413 |
+
if any(we in p_answer for we in wrong_entities):
|
| 1414 |
+
if p not in corrections and p != pair:
|
| 1415 |
+
corrections.append(p)
|
| 1416 |
+
break # Max 1 sister pair per confusion
|
| 1417 |
+
else:
|
| 1418 |
+
correct += 1
|
| 1419 |
+
|
| 1420 |
+
print(f"\n [Verification: {correct}/{sample_size} facts correct]", flush=True)
|
| 1421 |
+
|
| 1422 |
+
if corrections:
|
| 1423 |
+
print(f" [Retraining {len(corrections)} corrections + sister pairs...]", flush=True)
|
| 1424 |
+
loss = self.mm.absorb(corrections)
|
| 1425 |
+
self.all_training_data.extend(corrections)
|
| 1426 |
+
print(f" [Correction absorption done, loss={loss:.4f}]")
|
| 1427 |
+
|
| 1428 |
+
# Teacher-guided distillation: if teacher cache available,
|
| 1429 |
+
# also distill from teacher on the corrected quiz pairs.
|
| 1430 |
+
# This gives the student the teacher's full output distribution,
|
| 1431 |
+
# not just the text answer β more information per correction.
|
| 1432 |
+
if self.teacher_cache:
|
| 1433 |
+
distill_items = []
|
| 1434 |
+
for corr in corrections:
|
| 1435 |
+
q = corr["messages"][0]["content"].lower()[:60]
|
| 1436 |
+
for cached in self.teacher_cache:
|
| 1437 |
+
cq = cached["pair"]["messages"][0]["content"].lower()[:60]
|
| 1438 |
+
if q == cq:
|
| 1439 |
+
distill_items.append(cached)
|
| 1440 |
+
break
|
| 1441 |
+
if distill_items:
|
| 1442 |
+
d_loss = self.mm.distill(distill_items, epochs=1)
|
| 1443 |
+
print(f" [Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}]")
|
| 1444 |
+
|
| 1445 |
+
def _quick_verify_entities(self):
|
| 1446 |
+
"""Returns set of confused entity names by checking known_entities."""
|
| 1447 |
+
confused = set()
|
| 1448 |
+
entities = self.quiz_gen.known_entities
|
| 1449 |
+
if not entities:
|
| 1450 |
+
return confused
|
| 1451 |
+
for name, info in entities.items():
|
| 1452 |
+
if info.get("job"):
|
| 1453 |
+
q = f"What does Matt's {info['relationship']} {name} do?"
|
| 1454 |
+
ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
|
| 1455 |
+
if info["job"].lower() not in ans.lower():
|
| 1456 |
+
confused.add(name)
|
| 1457 |
+
if info.get("city"):
|
| 1458 |
+
q = f"Where does {name} live?"
|
| 1459 |
+
ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
|
| 1460 |
+
if info["city"].lower() not in ans.lower():
|
| 1461 |
+
confused.add(name)
|
| 1462 |
+
return confused
|
| 1463 |
+
|
| 1464 |
+
def _start_absorption(self):
|
| 1465 |
+
"""Two-phase absorption in background thread (proven 93% in session 4e).
|
| 1466 |
+
Phase 1: exchange + positive quizzes + replay, clustered by entity.
|
| 1467 |
+
Phase 2: Verify entities, train only targeted contrastive for confused ones.
|
| 1468 |
+
Phase 3: Stubborn retry for persistently confused entities (max 2 retries)."""
|
| 1469 |
+
import random
|
| 1470 |
+
|
| 1471 |
+
# Grab pending data
|
| 1472 |
+
exchange = getattr(self, '_pending_exchange', None)
|
| 1473 |
+
positive = getattr(self, '_pending_positive', [])
|
| 1474 |
+
contrastive = getattr(self, '_pending_contrastive', [])
|
| 1475 |
+
|
| 1476 |
+
# Old data for replay
|
| 1477 |
+
new_start = getattr(self, '_last_absorb_idx', 0)
|
| 1478 |
+
old_data = self.all_training_data[:new_start]
|
| 1479 |
+
self._last_absorb_idx = len(self.all_training_data)
|
| 1480 |
+
|
| 1481 |
+
MAX_REPLAY = 6
|
| 1482 |
+
if old_data and len(old_data) > MAX_REPLAY:
|
| 1483 |
+
replay_sample = random.sample(old_data, MAX_REPLAY)
|
| 1484 |
+
else:
|
| 1485 |
+
replay_sample = list(old_data)
|
| 1486 |
+
|
| 1487 |
+
entity_names = list(self.quiz_gen.known_entities.keys())
|
| 1488 |
+
|
| 1489 |
+
def _run():
|
| 1490 |
+
t0 = time.time()
|
| 1491 |
+
try:
|
| 1492 |
+
# ββ Phase 1: Positive facts + replay, clustered by entity ββ
|
| 1493 |
+
phase1_data = []
|
| 1494 |
+
if exchange:
|
| 1495 |
+
phase1_data.append(exchange)
|
| 1496 |
+
phase1_data.extend(positive)
|
| 1497 |
+
phase1_data.extend(replay_sample)
|
| 1498 |
+
|
| 1499 |
+
if entity_names and phase1_data:
|
| 1500 |
+
phase1_data = ModelManager.cluster_by_entity(phase1_data, entity_names)
|
| 1501 |
+
|
| 1502 |
+
loss1 = self.mm.absorb(phase1_data) if phase1_data else 0.0
|
| 1503 |
+
n_p1 = len(phase1_data)
|
| 1504 |
+
|
| 1505 |
+
# ββ Phase 2: Targeted contrastive for confused entities ββ
|
| 1506 |
+
loss2 = None
|
| 1507 |
+
n_p2 = 0
|
| 1508 |
+
if contrastive and entity_names:
|
| 1509 |
+
confused = self._quick_verify_entities()
|
| 1510 |
+
if confused:
|
| 1511 |
+
targeted = []
|
| 1512 |
+
for qp in contrastive:
|
| 1513 |
+
full_text = (qp["messages"][0]["content"] + " " +
|
| 1514 |
+
qp["messages"][1]["content"]).lower()
|
| 1515 |
+
if any(name.lower() in full_text for name in confused):
|
| 1516 |
+
targeted.append(qp)
|
| 1517 |
+
if targeted:
|
| 1518 |
+
loss2 = self.mm.absorb(targeted)
|
| 1519 |
+
n_p2 = len(targeted)
|
| 1520 |
+
print(f"\n [Phase 2: {n_p2} targeted contrastive for {confused}]",
|
| 1521 |
+
flush=True)
|
| 1522 |
+
|
| 1523 |
+
# ββ Phase 3: Stubborn retry (max 2 retries, non-blocking) ββ
|
| 1524 |
+
still_confused = self._quick_verify_entities()
|
| 1525 |
+
for retry in range(2):
|
| 1526 |
+
if not still_confused:
|
| 1527 |
+
break
|
| 1528 |
+
retry_batch = []
|
| 1529 |
+
for name in still_confused:
|
| 1530 |
+
info = self.quiz_gen.known_entities.get(name, {})
|
| 1531 |
+
if info.get("job"):
|
| 1532 |
+
for _ in range(3):
|
| 1533 |
+
retry_batch.append({"messages": [
|
| 1534 |
+
{"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
|
| 1535 |
+
{"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
|
| 1536 |
+
]})
|
| 1537 |
+
if info.get("city"):
|
| 1538 |
+
for _ in range(3):
|
| 1539 |
+
retry_batch.append({"messages": [
|
| 1540 |
+
{"role": "user", "content": f"Where does {name} live?"},
|
| 1541 |
+
{"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
|
| 1542 |
+
]})
|
| 1543 |
+
# Relevant contrastive pairs
|
| 1544 |
+
for qp in contrastive:
|
| 1545 |
+
ft = (qp["messages"][0]["content"] + " " +
|
| 1546 |
+
qp["messages"][1]["content"]).lower()
|
| 1547 |
+
if name.lower() in ft:
|
| 1548 |
+
retry_batch.append(qp)
|
| 1549 |
+
if retry_batch:
|
| 1550 |
+
loss3 = self.mm.absorb(retry_batch)
|
| 1551 |
+
print(f"\n [Phase 3 retry {retry+1}: {len(retry_batch)} items, "
|
| 1552 |
+
f"loss={loss3:.4f}]", flush=True)
|
| 1553 |
+
still_confused = self._quick_verify_entities()
|
| 1554 |
+
if still_confused:
|
| 1555 |
+
print(f"\n [Phase 3: still confused after retries: {still_confused}]",
|
| 1556 |
+
flush=True)
|
| 1557 |
+
|
| 1558 |
+
elapsed = time.time() - t0
|
| 1559 |
+
self.absorption_count += 1
|
| 1560 |
+
loss_str = f"P1={loss1:.4f}"
|
| 1561 |
+
if loss2 is not None:
|
| 1562 |
+
loss_str += f" P2={loss2:.4f}"
|
| 1563 |
+
print(f"\n [Absorbed {n_p1}+{n_p2} examples in {elapsed:.1f}s | "
|
| 1564 |
+
f"{loss_str} | absorptions={self.absorption_count}]")
|
| 1565 |
+
|
| 1566 |
+
# Periodic verification β catch drift/confusion
|
| 1567 |
+
if self.absorption_count % VERIFY_EVERY == 0:
|
| 1568 |
+
self._periodic_verification()
|
| 1569 |
+
|
| 1570 |
+
# Auto-checkpoint
|
| 1571 |
+
if self.absorption_count % CHECKPOINT_EVERY == 0:
|
| 1572 |
+
self._auto_checkpoint()
|
| 1573 |
+
|
| 1574 |
+
except Exception as e:
|
| 1575 |
+
print(f"\n [Absorption error: {e}]")
|
| 1576 |
+
import traceback
|
| 1577 |
+
traceback.print_exc()
|
| 1578 |
+
|
| 1579 |
+
self.absorption_thread = threading.Thread(target=_run, daemon=True)
|
| 1580 |
+
self.absorption_thread.start()
|
| 1581 |
+
|
| 1582 |
+
def _wait_for_absorption(self):
|
| 1583 |
+
if self.absorption_thread and self.absorption_thread.is_alive():
|
| 1584 |
+
self.absorption_thread.join()
|
| 1585 |
+
self.absorption_thread = None
|
| 1586 |
+
|
| 1587 |
+
def _cleanup_old_checkpoints(self, keep=None):
|
| 1588 |
+
"""Delete old checkpoints to free disk. Keep only 'keep' path if specified."""
|
| 1589 |
+
if not os.path.exists(self.checkpoint_dir):
|
| 1590 |
+
return
|
| 1591 |
+
for entry in os.listdir(self.checkpoint_dir):
|
| 1592 |
+
full = os.path.join(self.checkpoint_dir, entry)
|
| 1593 |
+
if full == keep:
|
| 1594 |
+
continue
|
| 1595 |
+
if os.path.isdir(full) and entry.startswith("claudia_"):
|
| 1596 |
+
import shutil
|
| 1597 |
+
size_gb = sum(
|
| 1598 |
+
os.path.getsize(os.path.join(dp, f))
|
| 1599 |
+
for dp, _, fns in os.walk(full) for f in fns
|
| 1600 |
+
) / 1e9
|
| 1601 |
+
print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
|
| 1602 |
+
shutil.rmtree(full)
|
| 1603 |
+
|
| 1604 |
+
def _auto_checkpoint(self):
|
| 1605 |
+
"""Auto-save checkpoint during long sessions."""
|
| 1606 |
+
version = f"auto_{self.absorption_count}"
|
| 1607 |
+
path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
|
| 1608 |
+
self._cleanup_old_checkpoints()
|
| 1609 |
+
self.mm.merge_and_save(path)
|
| 1610 |
+
self.last_checkpoint = path
|
| 1611 |
+
self._save_replay_buffer(path)
|
| 1612 |
+
|
| 1613 |
+
def _save_and_exit(self):
|
| 1614 |
+
"""Final save on exit with targeted correction."""
|
| 1615 |
+
import random
|
| 1616 |
+
|
| 1617 |
+
# Final verify + stubborn retry (not bulk retrain β prevents overfitting)
|
| 1618 |
+
confused = self._quick_verify_entities()
|
| 1619 |
+
if confused:
|
| 1620 |
+
print(f" Final correction for confused entities: {confused}")
|
| 1621 |
+
# Gather contrastive pairs from quiz log
|
| 1622 |
+
contrastive = [qp for qp in self.quiz_pairs_log
|
| 1623 |
+
if qp["messages"][1]["content"].lower().startswith("no.")]
|
| 1624 |
+
for retry in range(3):
|
| 1625 |
+
if not confused:
|
| 1626 |
+
break
|
| 1627 |
+
retry_batch = []
|
| 1628 |
+
for name in confused:
|
| 1629 |
+
info = self.quiz_gen.known_entities.get(name, {})
|
| 1630 |
+
if info.get("job"):
|
| 1631 |
+
for _ in range(3):
|
| 1632 |
+
retry_batch.append({"messages": [
|
| 1633 |
+
{"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
|
| 1634 |
+
{"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
|
| 1635 |
+
]})
|
| 1636 |
+
if info.get("city"):
|
| 1637 |
+
for _ in range(3):
|
| 1638 |
+
retry_batch.append({"messages": [
|
| 1639 |
+
{"role": "user", "content": f"Where does {name} live?"},
|
| 1640 |
+
{"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
|
| 1641 |
+
]})
|
| 1642 |
+
for qp in contrastive:
|
| 1643 |
+
ft = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower()
|
| 1644 |
+
if name.lower() in ft:
|
| 1645 |
+
retry_batch.append(qp)
|
| 1646 |
+
if retry_batch:
|
| 1647 |
+
loss = self.mm.absorb(retry_batch)
|
| 1648 |
+
print(f" Final retry {retry+1}: {len(retry_batch)} items, loss={loss:.4f}")
|
| 1649 |
+
confused = self._quick_verify_entities()
|
| 1650 |
+
self.absorption_count += 1
|
| 1651 |
+
else:
|
| 1652 |
+
print(" All entities verified correct β no final correction needed.")
|
| 1653 |
+
|
| 1654 |
+
# Personality check before saving
|
| 1655 |
+
print("\n--- Pre-Save Personality Check ---")
|
| 1656 |
+
score = check_personality(self.mm)
|
| 1657 |
+
if score < 0.5:
|
| 1658 |
+
print(" WARNING: Personality degraded. Saving anyway (rollback available).")
|
| 1659 |
+
|
| 1660 |
+
# Merge and save (cleanup old checkpoints first to free disk)
|
| 1661 |
+
version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
|
| 1662 |
+
path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
|
| 1663 |
+
self._cleanup_old_checkpoints()
|
| 1664 |
+
self.mm.merge_and_save(path)
|
| 1665 |
+
self.last_checkpoint = path
|
| 1666 |
+
|
| 1667 |
+
# Save replay buffer alongside checkpoint
|
| 1668 |
+
self._save_replay_buffer(path)
|
| 1669 |
+
|
| 1670 |
+
# ββ Cascade Distillation: cache teacher logits for next session ββ
|
| 1671 |
+
# After merge+fresh LoRA, model outputs are identical to pre-merge state.
|
| 1672 |
+
# Cache the teacher's top-K logits so the next session can distill from them.
|
| 1673 |
+
if self.quiz_pairs_log:
|
| 1674 |
+
n_cache = min(len(self.quiz_pairs_log), MAX_TEACHER_CACHE)
|
| 1675 |
+
print(f" Caching teacher logits ({n_cache} quiz pairs)...")
|
| 1676 |
+
teacher_cache = self.mm.cache_teacher_logits(self.quiz_pairs_log)
|
| 1677 |
+
cache_path = os.path.join(path, "teacher_cache.pt")
|
| 1678 |
+
torch.save(teacher_cache, cache_path)
|
| 1679 |
+
size_mb = os.path.getsize(cache_path) / 1e6
|
| 1680 |
+
print(f" Teacher cache saved ({len(teacher_cache)} items, {size_mb:.1f} MB)")
|
| 1681 |
+
del teacher_cache
|
| 1682 |
+
torch.cuda.empty_cache()
|
| 1683 |
+
|
| 1684 |
+
# Save quiz pairs log for next session
|
| 1685 |
+
quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
|
| 1686 |
+
with open(quiz_log_path, 'w') as f:
|
| 1687 |
+
json.dump(self.quiz_pairs_log, f)
|
| 1688 |
+
|
| 1689 |
+
# Save session metadata
|
| 1690 |
+
meta = {
|
| 1691 |
+
"checkpoint": path,
|
| 1692 |
+
"absorption_count": self.absorption_count,
|
| 1693 |
+
"exchange_count": self.exchange_count,
|
| 1694 |
+
"training_pool_size": len(self.all_training_data),
|
| 1695 |
+
"personality_score": score,
|
| 1696 |
+
"timestamp": datetime.now().isoformat(),
|
| 1697 |
+
}
|
| 1698 |
+
meta_path = os.path.join(self.log_dir, f"session_{version}.json")
|
| 1699 |
+
with open(meta_path, 'w') as f:
|
| 1700 |
+
json.dump(meta, f, indent=2)
|
| 1701 |
+
print(f" Session saved: {meta_path}")
|
| 1702 |
+
print(f" Next run: use --checkpoint {path}")
|
| 1703 |
+
|
| 1704 |
+
def _save_replay_buffer(self, checkpoint_path=None):
|
| 1705 |
+
"""Save training data pool for next session resume."""
|
| 1706 |
+
# Always save to log dir (canonical location for resume)
|
| 1707 |
+
path = os.path.join(self.log_dir, "replay_buffer.json")
|
| 1708 |
+
with open(path, 'w') as f:
|
| 1709 |
+
json.dump(self.all_training_data, f)
|
| 1710 |
+
# Also save into checkpoint dir for self-contained checkpoints
|
| 1711 |
+
if checkpoint_path and os.path.isdir(checkpoint_path):
|
| 1712 |
+
cp_path = os.path.join(checkpoint_path, "replay_buffer.json")
|
| 1713 |
+
with open(cp_path, 'w') as f:
|
| 1714 |
+
json.dump(self.all_training_data, f)
|
| 1715 |
+
# Save quiz pairs log too
|
| 1716 |
+
quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
|
| 1717 |
+
with open(quiz_log_path, 'w') as f:
|
| 1718 |
+
json.dump(self.quiz_pairs_log, f)
|
| 1719 |
+
print(f" Replay buffer saved ({len(self.all_training_data)} examples)")
|
| 1720 |
+
|
| 1721 |
+
def _log_exchange(self, user_msg, assistant_msg):
|
| 1722 |
+
"""Append exchange to conversation log file."""
|
| 1723 |
+
with open(self.log_path, 'a', encoding='utf-8') as f:
|
| 1724 |
+
entry = {
|
| 1725 |
+
"timestamp": datetime.now().isoformat(),
|
| 1726 |
+
"user": user_msg,
|
| 1727 |
+
"assistant": assistant_msg,
|
| 1728 |
+
}
|
| 1729 |
+
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
| 1730 |
+
|
| 1731 |
+
def _handle_command(self, cmd):
|
| 1732 |
+
"""Handle slash commands. Returns True if should exit."""
|
| 1733 |
+
cmd_lower = cmd.lower().strip()
|
| 1734 |
+
|
| 1735 |
+
if cmd_lower == "/quit":
|
| 1736 |
+
print("[Saving and exiting...]")
|
| 1737 |
+
self._wait_for_absorption()
|
| 1738 |
+
self._save_and_exit()
|
| 1739 |
+
return True
|
| 1740 |
+
|
| 1741 |
+
elif cmd_lower == "/status":
|
| 1742 |
+
self._wait_for_absorption()
|
| 1743 |
+
vram = torch.cuda.memory_allocated() / 1e9
|
| 1744 |
+
print(f"\n --- Status ---")
|
| 1745 |
+
print(f" Exchanges: {self.exchange_count}")
|
| 1746 |
+
print(f" Absorptions: {self.absorption_count}")
|
| 1747 |
+
print(f" Training pool: {len(self.all_training_data)} examples")
|
| 1748 |
+
print(f" Buffer: {len(self.conversation_buffer)} messages")
|
| 1749 |
+
print(f" VRAM: {vram:.1f} GB")
|
| 1750 |
+
print(f" Background: {'running' if self.absorption_thread and self.absorption_thread.is_alive() else 'idle'}")
|
| 1751 |
+
print(f" Last checkpoint: {self.last_checkpoint}")
|
| 1752 |
+
print(f" --- End ---\n")
|
| 1753 |
+
|
| 1754 |
+
elif cmd_lower == "/absorb":
|
| 1755 |
+
self._wait_for_absorption()
|
| 1756 |
+
if not self.all_training_data:
|
| 1757 |
+
print(" No data to absorb.")
|
| 1758 |
+
return False
|
| 1759 |
+
# Cap at most recent 40 examples to prevent overfitting
|
| 1760 |
+
import random
|
| 1761 |
+
data = self.all_training_data
|
| 1762 |
+
if len(data) > 40:
|
| 1763 |
+
recent = data[-20:]
|
| 1764 |
+
older = random.sample(data[:-20], 20)
|
| 1765 |
+
data = recent + older
|
| 1766 |
+
print(f" Force absorption ({len(data)} examples)...")
|
| 1767 |
+
loss = self.mm.absorb(data)
|
| 1768 |
+
self.absorption_count += 1
|
| 1769 |
+
print(f" Done. Loss: {loss:.4f}")
|
| 1770 |
+
|
| 1771 |
+
# ββ Post-absorb comprehensive verification + distillation ββ
|
| 1772 |
+
# Run FULL verification (all quiz pairs, not just sample) to catch
|
| 1773 |
+
# all regressions before recall questions. This is the critical
|
| 1774 |
+
# window between teaching and testing.
|
| 1775 |
+
if self.quiz_pairs_log:
|
| 1776 |
+
print(f"\n --- Post-absorb verification (ALL {len(self.quiz_pairs_log)} quiz pairs) ---")
|
| 1777 |
+
old_verify_sample = VERIFY_SAMPLE
|
| 1778 |
+
# Test ALL quiz pairs, not just a sample
|
| 1779 |
+
full_corrections = []
|
| 1780 |
+
full_correct = 0
|
| 1781 |
+
test_pairs = self.quiz_pairs_log
|
| 1782 |
+
|
| 1783 |
+
for pair in test_pairs:
|
| 1784 |
+
question = pair["messages"][0]["content"]
|
| 1785 |
+
expected = pair["messages"][1]["content"]
|
| 1786 |
+
actual = self.mm.generate(
|
| 1787 |
+
[{"role": "user", "content": question}],
|
| 1788 |
+
max_new_tokens=150,
|
| 1789 |
+
)
|
| 1790 |
+
expected_entities = self._extract_key_entities(expected)
|
| 1791 |
+
if not expected_entities:
|
| 1792 |
+
full_correct += 1
|
| 1793 |
+
continue
|
| 1794 |
+
actual_lower = actual.lower()
|
| 1795 |
+
hits = sum(1 for e in expected_entities if e in actual_lower)
|
| 1796 |
+
ratio = hits / len(expected_entities)
|
| 1797 |
+
if ratio < 0.5:
|
| 1798 |
+
actual_entities = self._extract_key_entities(actual)
|
| 1799 |
+
wrong_entities = actual_entities - expected_entities
|
| 1800 |
+
full_corrections.append(pair)
|
| 1801 |
+
if wrong_entities:
|
| 1802 |
+
for p in self.quiz_pairs_log:
|
| 1803 |
+
p_answer = p["messages"][1]["content"].lower()
|
| 1804 |
+
if any(we in p_answer for we in wrong_entities):
|
| 1805 |
+
if p not in full_corrections and p != pair:
|
| 1806 |
+
full_corrections.append(p)
|
| 1807 |
+
break
|
| 1808 |
+
else:
|
| 1809 |
+
full_correct += 1
|
| 1810 |
+
|
| 1811 |
+
print(f" Full verification: {full_correct}/{len(test_pairs)} correct")
|
| 1812 |
+
if full_corrections:
|
| 1813 |
+
print(f" Retraining {len(full_corrections)} corrections...")
|
| 1814 |
+
c_loss = self.mm.absorb(full_corrections)
|
| 1815 |
+
self.all_training_data.extend(full_corrections)
|
| 1816 |
+
print(f" Correction loss: {c_loss:.4f}")
|
| 1817 |
+
# Teacher distillation on corrections
|
| 1818 |
+
if self.teacher_cache:
|
| 1819 |
+
distill_items = []
|
| 1820 |
+
for corr in full_corrections:
|
| 1821 |
+
q = corr["messages"][0]["content"].lower()[:60]
|
| 1822 |
+
for cached in self.teacher_cache:
|
| 1823 |
+
cq = cached["pair"]["messages"][0]["content"].lower()[:60]
|
| 1824 |
+
if q == cq:
|
| 1825 |
+
distill_items.append(cached)
|
| 1826 |
+
break
|
| 1827 |
+
if distill_items:
|
| 1828 |
+
d_loss = self.mm.distill(distill_items, epochs=1)
|
| 1829 |
+
print(f" Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}")
|
| 1830 |
+
print(f" --- End post-absorb verification ---\n")
|
| 1831 |
+
|
| 1832 |
+
elif cmd_lower == "/save":
|
| 1833 |
+
self._wait_for_absorption()
|
| 1834 |
+
version = f"manual_{self.absorption_count}"
|
| 1835 |
+
path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
|
| 1836 |
+
print(f" Saving checkpoint...")
|
| 1837 |
+
# Personality check
|
| 1838 |
+
score = check_personality(self.mm, verbose=False)
|
| 1839 |
+
if score < 0.5:
|
| 1840 |
+
print(f" WARNING: Personality score {score:.0%}. Save anyway? (y/n)")
|
| 1841 |
+
confirm = input(" > ").strip().lower()
|
| 1842 |
+
if confirm != 'y':
|
| 1843 |
+
print(" Aborted.")
|
| 1844 |
+
return False
|
| 1845 |
+
self._cleanup_old_checkpoints()
|
| 1846 |
+
self.mm.merge_and_save(path)
|
| 1847 |
+
self.last_checkpoint = path
|
| 1848 |
+
self._save_replay_buffer(path)
|
| 1849 |
+
|
| 1850 |
+
elif cmd_lower == "/personality":
|
| 1851 |
+
self._wait_for_absorption()
|
| 1852 |
+
print("\n--- Personality Check ---")
|
| 1853 |
+
check_personality(self.mm)
|
| 1854 |
+
print()
|
| 1855 |
+
|
| 1856 |
+
elif cmd_lower == "/help":
|
| 1857 |
+
print(" /status - show stats")
|
| 1858 |
+
print(" /absorb - force immediate training")
|
| 1859 |
+
print(" /save - merge + save checkpoint")
|
| 1860 |
+
print(" /personality - run personality check")
|
| 1861 |
+
print(" /quit - save and exit")
|
| 1862 |
+
|
| 1863 |
+
else:
|
| 1864 |
+
print(f" Unknown: {cmd}. Try /help")
|
| 1865 |
+
|
| 1866 |
+
return False
|
| 1867 |
+
|
| 1868 |
+
|
| 1869 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1870 |
+
# MAIN
|
| 1871 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1872 |
+
|
| 1873 |
+
def main():
|
| 1874 |
+
parser = argparse.ArgumentParser(
|
| 1875 |
+
description="Claudia Persistent Absorber v2 β conversation β permanent weights"
|
| 1876 |
+
)
|
| 1877 |
+
parser.add_argument(
|
| 1878 |
+
"--model_path", required=True,
|
| 1879 |
+
help="Path to base Qwen3-Omni model (or checkpoint for resume)"
|
| 1880 |
+
)
|
| 1881 |
+
parser.add_argument(
|
| 1882 |
+
"--adapter_path", default=None,
|
| 1883 |
+
help="Path to Claudia v6 personality adapter (first run only)"
|
| 1884 |
+
)
|
| 1885 |
+
parser.add_argument(
|
| 1886 |
+
"--ffn_patch", default=None,
|
| 1887 |
+
help="Path to ffn_patch.pt (first run only)"
|
| 1888 |
+
)
|
| 1889 |
+
parser.add_argument(
|
| 1890 |
+
"--checkpoint", default=None,
|
| 1891 |
+
help="Resume from this checkpoint (has personality + memories baked in)"
|
| 1892 |
+
)
|
| 1893 |
+
parser.add_argument(
|
| 1894 |
+
"--checkpoint_dir", default="/workspace/checkpoints",
|
| 1895 |
+
help="Where to save checkpoints"
|
| 1896 |
+
)
|
| 1897 |
+
parser.add_argument(
|
| 1898 |
+
"--log_dir", default="/workspace/logs",
|
| 1899 |
+
help="Where to save conversation logs and replay buffer"
|
| 1900 |
+
)
|
| 1901 |
+
parser.add_argument(
|
| 1902 |
+
"--absorb_every", type=int, default=ABSORB_EVERY,
|
| 1903 |
+
help=f"Absorb every N exchanges (default: {ABSORB_EVERY})"
|
| 1904 |
+
)
|
| 1905 |
+
args = parser.parse_args()
|
| 1906 |
+
|
| 1907 |
+
# Determine if first run or resume
|
| 1908 |
+
if args.checkpoint:
|
| 1909 |
+
print(f"RESUMING from checkpoint: {args.checkpoint}")
|
| 1910 |
+
absorber = PersistentAbsorber(
|
| 1911 |
+
model_path=args.model_path,
|
| 1912 |
+
checkpoint_path=args.checkpoint,
|
| 1913 |
+
checkpoint_dir=args.checkpoint_dir,
|
| 1914 |
+
log_dir=args.log_dir,
|
| 1915 |
+
)
|
| 1916 |
+
else:
|
| 1917 |
+
print(f"FIRST RUN β applying personality adapter")
|
| 1918 |
+
if not args.adapter_path:
|
| 1919 |
+
print("ERROR: --adapter_path required for first run")
|
| 1920 |
+
print(" (or use --checkpoint to resume)")
|
| 1921 |
+
sys.exit(1)
|
| 1922 |
+
absorber = PersistentAbsorber(
|
| 1923 |
+
model_path=args.model_path,
|
| 1924 |
+
adapter_path=args.adapter_path,
|
| 1925 |
+
ffn_patch_path=args.ffn_patch,
|
| 1926 |
+
checkpoint_dir=args.checkpoint_dir,
|
| 1927 |
+
log_dir=args.log_dir,
|
| 1928 |
+
)
|
| 1929 |
+
|
| 1930 |
+
absorber.start()
|
| 1931 |
+
|
| 1932 |
+
|
| 1933 |
+
if __name__ == "__main__":
|
| 1934 |
+
main()
|